diff --git a/py_trees/behaviours.py b/py_trees/behaviours.py index 0c283004..6fbd3df1 100644 --- a/py_trees/behaviours.py +++ b/py_trees/behaviours.py @@ -16,6 +16,7 @@ import copy import functools import operator +import random import typing from . import behaviour, blackboard, common, meta @@ -736,3 +737,47 @@ def update(self) -> common.Status: "|".join(["T" if result else "F" for result in results]) ) return common.Status.FAILURE + + +class ProbabilisticBehaviour(behaviour.Behaviour): + """ + Return a status based on a probability distribution. If unspecified - a uniform distribution will be used. + + Args: + name: name of the behaviour + weights: 3 probabilities that correspond to returning :data:`~py_trees.common.Status.SUCCESS`, + :data:`~py_trees.common.Status.FAILURE` and :data:`~py_trees.common.Status.RUNNING` respectively. + + .. note:: Probability distribution does not need to be normalised, it will be normalised internally. + + Raises: + ValueError if only some probabilities are specified + + """ + + def __init__(self, name: str, weights: typing.Optional[typing.List[float]] = None): + if weights is not None and (type(weights) is not list or len(weights) != 3): + raise ValueError( + "Either all or none of the probabilities must be specified" + ) + + super(ProbabilisticBehaviour, self).__init__(name=name) + + self._population = [ + common.Status.SUCCESS, + common.Status.FAILURE, + common.Status.RUNNING, + ] + self._weights = weights if weights is not None else [1.0, 1.0, 1.0] + + def update(self) -> common.Status: + """ + Return a status based on a probability distribution. + + Returns: + :data:`~py_trees.common.Status.SUCCESS` with probability weights[0], + :data:`~py_trees.common.Status.FAILURE` with probability weights[1] and + :data:`~py_trees.common.Status.RUNNING` with probability weights[2]. + """ + self.logger.debug("%s.update()" % self.__class__.__name__) + return random.choices(self._population, self._weights, k=1)[0] diff --git a/tests/test_probabilistic_behaviour.py b/tests/test_probabilistic_behaviour.py new file mode 100644 index 00000000..5fe0714b --- /dev/null +++ b/tests/test_probabilistic_behaviour.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# +# License: BSD +# https://raw.githubusercontent.com/splintered-reality/py_trees/devel/LICENSE +# + +############################################################################## +# Imports +############################################################################## + +import py_trees +import py_trees.console as console +import py_trees.tests +import pytest + +############################################################################## +# Logging Level +############################################################################## + +py_trees.logging.level = py_trees.logging.Level.DEBUG +logger = py_trees.logging.Logger("Tests") + +############################################################################## +# Tests +############################################################################## + + +def test_probabilistic_behaviour_workflow() -> None: + console.banner("Probabilistic Behaviour") + + with pytest.raises(ValueError) as context: # if raised, context survives + # intentional error -> silence mypy + unused_root = py_trees.behaviours.ProbabilisticBehaviour( # noqa: F841 [unused] + name="ProbabilisticBehaviour", weights="invalid_type" # type: ignore[arg-type] + ) + py_trees.tests.print_assert_details("ValueError raised", "raised", "not raised") + py_trees.tests.print_assert_details("ValueError raised", "yes", "yes") + assert "ValueError" == context.typename + + root = py_trees.behaviours.ProbabilisticBehaviour( + name="ProbabilisticBehaviour", weights=[0.0, 0.0, 1.0] + ) + + py_trees.tests.print_assert_details( + text="task not yet ticked", + expected=py_trees.common.Status.INVALID, + result=root.status, + ) + assert root.status == py_trees.common.Status.INVALID + + root.tick_once() + py_trees.tests.print_assert_details( + text="task ticked once", + expected=py_trees.common.Status.RUNNING, + result=root.status, + ) + assert root.status == py_trees.common.Status.RUNNING