Skip to content

Commit

Permalink
Merge chained context managers
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Oct 23, 2024
1 parent c271f01 commit 84a4801
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 28 deletions.
8 changes: 5 additions & 3 deletions tests/test_api_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,11 @@ def test_data_switch_velocity_representation(
)

# The following instead should result to an updated `data` object.
with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
with data.mutable_context(mutability=Mutability.MUTABLE):
data.state.physics_model.base_linear_velocity = new_base_linear_velocity
with (
data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
data.mutable_context(mutability=Mutability.MUTABLE),
):
data.state.physics_model.base_linear_velocity = new_base_linear_velocity

assert data.state.physics_model.base_linear_velocity == pytest.approx(
new_base_linear_velocity
Expand Down
14 changes: 8 additions & 6 deletions tests/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,14 +358,16 @@ def test_model_jacobian(

# Get the J.T @ f product in inertial-fixed input/output representation.
# We use doubly right-trivialized jacobian with inertial-fixed 6D forces.
with references.switch_velocity_representation(VelRepr.Inertial):
with data.switch_velocity_representation(VelRepr.Inertial):
with (
references.switch_velocity_representation(VelRepr.Inertial),
data.switch_velocity_representation(VelRepr.Inertial),
):

f = references.link_forces(model=model, data=data)
assert f == pytest.approx(references.input.physics_model.f_ext)
f = references.link_forces(model=model, data=data)
assert f == pytest.approx(references.input.physics_model.f_ext)

J = js.model.generalized_free_floating_jacobian(model=model, data=data)
JTf_inertial = jnp.einsum("l6g,l6->g", J, f)
J = js.model.generalized_free_floating_jacobian(model=model, data=data)
JTf_inertial = jnp.einsum("l6g,l6->g", J, f)

for vel_repr in [VelRepr.Body, VelRepr.Mixed]:
with references.switch_velocity_representation(vel_repr):
Expand Down
22 changes: 10 additions & 12 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,23 @@ def jit_compiled_function(data: jax.Array) -> jax.Array:
return data

# In the first call, the function will be compiled and print the message.
with jax.log_compiles():
with io.StringIO() as buf, redirect_stdout(buf):
with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf):

data = 40
out = jit_compiled_function(data=data)
stdout = buf.getvalue()
assert out == data
data = 40
out = jit_compiled_function(data=data)
stdout = buf.getvalue()
assert out == data

assert msg_during_jit in stdout
assert jit_compiled_function._cache_size() == 1

# In the second call, the function won't be compiled and won't print the message.
with jax.log_compiles():
with io.StringIO() as buf, redirect_stdout(buf):
with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf):

data = 41
out = jit_compiled_function(data=data)
stdout = buf.getvalue()
assert out == data
data = 41
out = jit_compiled_function(data=data)
stdout = buf.getvalue()
assert out == data

assert msg_during_jit not in stdout
assert jit_compiled_function._cache_size() == 1
Expand Down
13 changes: 6 additions & 7 deletions tests/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,12 @@ def test_call_jit_compiled_function_passing_different_objects(
_ = js.contact.estimate_good_contact_parameters(model=model1)

# Now JAX should not compile it again.
with jax.log_compiles():
with io.StringIO() as buf, redirect_stdout(buf):
# Beyond running without any JIT recompilations, the following function
# should work on different JaxSimModel objects without raising any errors
# related to the comparison of Static fields.
_ = js.contact.estimate_good_contact_parameters(model=model2)
stdout = buf.getvalue()
with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf):
# Beyond running without any JIT recompilations, the following function
# should work on different JaxSimModel objects without raising any errors
# related to the comparison of Static fields.
_ = js.contact.estimate_good_contact_parameters(model=model2)
stdout = buf.getvalue()

assert (
f"Compiling {js.contact.estimate_good_contact_parameters.__name__}"
Expand Down

0 comments on commit 84a4801

Please sign in to comment.