diff --git a/tests/test_api_data.py b/tests/test_api_data.py index 9db5e6561..c2def5bb9 100644 --- a/tests/test_api_data.py +++ b/tests/test_api_data.py @@ -82,9 +82,11 @@ def test_data_switch_velocity_representation( ) # The following instead should result to an updated `data` object. - with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): - with data.mutable_context(mutability=Mutability.MUTABLE): - data.state.physics_model.base_linear_velocity = new_base_linear_velocity + with ( + data.switch_velocity_representation(velocity_representation=VelRepr.Inertial), + data.mutable_context(mutability=Mutability.MUTABLE), + ): + data.state.physics_model.base_linear_velocity = new_base_linear_velocity assert data.state.physics_model.base_linear_velocity == pytest.approx( new_base_linear_velocity diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 1c57f283f..e0e428968 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -358,14 +358,16 @@ def test_model_jacobian( # Get the J.T @ f product in inertial-fixed input/output representation. # We use doubly right-trivialized jacobian with inertial-fixed 6D forces. - with references.switch_velocity_representation(VelRepr.Inertial): - with data.switch_velocity_representation(VelRepr.Inertial): + with ( + references.switch_velocity_representation(VelRepr.Inertial), + data.switch_velocity_representation(VelRepr.Inertial), + ): - f = references.link_forces(model=model, data=data) - assert f == pytest.approx(references.input.physics_model.f_ext) + f = references.link_forces(model=model, data=data) + assert f == pytest.approx(references.input.physics_model.f_ext) - J = js.model.generalized_free_floating_jacobian(model=model, data=data) - JTf_inertial = jnp.einsum("l6g,l6->g", J, f) + J = js.model.generalized_free_floating_jacobian(model=model, data=data) + JTf_inertial = jnp.einsum("l6g,l6->g", J, f) for vel_repr in [VelRepr.Body, VelRepr.Mixed]: with references.switch_velocity_representation(vel_repr): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 7a0eeda36..c46166987 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -44,25 +44,23 @@ def jit_compiled_function(data: jax.Array) -> jax.Array: return data # In the first call, the function will be compiled and print the message. - with jax.log_compiles(): - with io.StringIO() as buf, redirect_stdout(buf): + with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf): - data = 40 - out = jit_compiled_function(data=data) - stdout = buf.getvalue() - assert out == data + data = 40 + out = jit_compiled_function(data=data) + stdout = buf.getvalue() + assert out == data assert msg_during_jit in stdout assert jit_compiled_function._cache_size() == 1 # In the second call, the function won't be compiled and won't print the message. - with jax.log_compiles(): - with io.StringIO() as buf, redirect_stdout(buf): + with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf): - data = 41 - out = jit_compiled_function(data=data) - stdout = buf.getvalue() - assert out == data + data = 41 + out = jit_compiled_function(data=data) + stdout = buf.getvalue() + assert out == data assert msg_during_jit not in stdout assert jit_compiled_function._cache_size() == 1 diff --git a/tests/test_pytree.py b/tests/test_pytree.py index c2fcc0149..c27179254 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -32,13 +32,12 @@ def test_call_jit_compiled_function_passing_different_objects( _ = js.contact.estimate_good_contact_parameters(model=model1) # Now JAX should not compile it again. - with jax.log_compiles(): - with io.StringIO() as buf, redirect_stdout(buf): - # Beyond running without any JIT recompilations, the following function - # should work on different JaxSimModel objects without raising any errors - # related to the comparison of Static fields. - _ = js.contact.estimate_good_contact_parameters(model=model2) - stdout = buf.getvalue() + with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf): + # Beyond running without any JIT recompilations, the following function + # should work on different JaxSimModel objects without raising any errors + # related to the comparison of Static fields. + _ = js.contact.estimate_good_contact_parameters(model=model2) + stdout = buf.getvalue() assert ( f"Compiling {js.contact.estimate_good_contact_parameters.__name__}"