Skip to content

Commit

Permalink
Used named_parameters in Menagerie model_test.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671398089
Change-Id: I92fe35b9c4d441da92cfde9a9bdc8541db19ab77
  • Loading branch information
nimrod-gileadi authored and copybara-github committed Sep 5, 2024
1 parent 3966424 commit 8229e08
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions test/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@

def _get_xmls(pattern: str) -> List[pathlib.Path]:
for d in _MODEL_DIRS:
yield from d.glob(pattern)
# Produce tuples of test name and XML path.
for f in d.glob(pattern):
test_name = str(f).removeprefix(str(f.parent.parent))
yield (test_name, f)

_MODEL_XMLS = list(_get_xmls('scene*.xml'))
_MJX_MODEL_XMLS = list(_get_xmls('scene*mjx.xml'))
Expand Down Expand Up @@ -65,7 +68,7 @@ def _pseudorandom_ctrlnoise(
class ModelsTest(parameterized.TestCase):
"""Tests that MuJoCo models load and do not emit warnings."""

@parameterized.parameters(_MODEL_XMLS)
@parameterized.named_parameters(_MODEL_XMLS)
def test_compiles_and_steps(self, xml_path: pathlib.Path) -> None:
model = mujoco.MjModel.from_xml_path(str(xml_path))
data = mujoco.MjData(model)
Expand All @@ -86,7 +89,7 @@ def test_compiles_and_steps(self, xml_path: pathlib.Path) -> None:
class MjxModelsTest(parameterized.TestCase):
"""Tests that MJX models load and do not return NaNs."""

@parameterized.parameters(_MJX_MODEL_XMLS)
@parameterized.named_parameters(_MJX_MODEL_XMLS)
def test_compiles_and_steps(self, xml_path: pathlib.Path) -> None:
model = mujoco.MjModel.from_xml_path(str(xml_path))
model = mjx.put_model(model)
Expand Down

0 comments on commit 8229e08

Please sign in to comment.