Skip to content

Commit

Permalink
Add dtype choice in step type/functions (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomashirtz authored Nov 15, 2024
1 parent a8a4d3b commit 7979c36
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions jumanji/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def restart(
observation: Observation,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.FIRST`.
Expand All @@ -107,15 +108,17 @@ def restart(
shape: optional parameter to specify the shape of the rewards and discounts.
Allows multi-agent environment compatibility. Defaults to () for
scalar reward and discount.
dtype: Optional parameter to specify the data type of the rewards and discounts.
Defaults to `float`.
Returns:
TimeStep identified as a reset.
"""
extras = extras or {}
return TimeStep(
step_type=StepType.FIRST,
reward=jnp.zeros(shape, dtype=float),
discount=jnp.ones(shape, dtype=float),
reward=jnp.zeros(shape, dtype=dtype),
discount=jnp.ones(shape, dtype=dtype),
observation=observation,
extras=extras,
)
Expand All @@ -127,6 +130,7 @@ def transition(
discount: Optional[Array] = None,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.MID`.
Expand All @@ -141,11 +145,13 @@ def transition(
shape: optional parameter to specify the shape of the rewards and discounts.
Allows multi-agent environment compatibility. Defaults to () for
scalar reward and discount.
dtype: Optional parameter to specify the data type of the discounts. Defaults
to `float`.
Returns:
TimeStep identified as a transition.
"""
discount = discount if discount is not None else jnp.ones(shape, dtype=float)
discount = discount if discount is not None else jnp.ones(shape, dtype=dtype)
extras = extras or {}
return TimeStep(
step_type=StepType.MID,
Expand All @@ -161,6 +167,7 @@ def termination(
observation: Observation,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.LAST`.
Expand All @@ -174,6 +181,8 @@ def termination(
shape: optional parameter to specify the shape of the rewards and discounts.
Allows multi-agent environment compatibility. Defaults to () for
scalar reward and discount.
dtype: Optional parameter to specify the data type of the discounts. Defaults
to `float`.
Returns:
TimeStep identified as the termination of an episode.
Expand All @@ -182,7 +191,7 @@ def termination(
return TimeStep(
step_type=StepType.LAST,
reward=reward,
discount=jnp.zeros(shape, dtype=float),
discount=jnp.zeros(shape, dtype=dtype),
observation=observation,
extras=extras,
)
Expand All @@ -194,6 +203,7 @@ def truncation(
discount: Optional[Array] = None,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.LAST`.
Expand All @@ -208,10 +218,13 @@ def truncation(
shape: optional parameter to specify the shape of the rewards and discounts.
Allows multi-agent environment compatibility. Defaults to () for
scalar reward and discount.
dtype: Optional parameter to specify the data type of the discounts. Defaults
to `float`.
Returns:
TimeStep identified as the truncation of an episode.
"""
discount = discount if discount is not None else jnp.ones(shape, dtype=float)
discount = discount if discount is not None else jnp.ones(shape, dtype=dtype)
extras = extras or {}
return TimeStep(
step_type=StepType.LAST,
Expand Down

0 comments on commit 7979c36

Please sign in to comment.