diff --git a/config/RTC_configs/bert-mt5-zero-shot.yaml b/config/RTC_configs/roberta-mt5-zero-shot.yaml similarity index 100% rename from config/RTC_configs/bert-mt5-zero-shot.yaml rename to config/RTC_configs/roberta-mt5-zero-shot.yaml diff --git a/tests/test_inference.py b/tests/test_inference.py index 1498bcd..d032d38 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -14,6 +14,8 @@ CONFIG_ROOT = f"{os.path.dirname(os.path.abspath(__file__))}/../config/" +PIPELINE_PATH = f"{CONFIG_ROOT}/RTC_configs/roberta-mt5-zero-shot.yaml" + @pytest.fixture() def dummy_data(): @@ -38,8 +40,7 @@ def dummy_metadata(): def test_pipeline_inputs(dummy_data, dummy_metadata): - pipeline_config_path = f"{CONFIG_ROOT}/RTC_configs/bert-mt5-bert.yaml" - pipeline_config = open_yaml_path(pipeline_config_path) + pipeline_config = open_yaml_path(PIPELINE_PATH) with patch( # noqa: SIM117 "arc_spice.variational_pipelines.RTC_variational_pipeline.pipeline", @@ -74,8 +75,7 @@ def test_pipeline_inputs(dummy_data, dummy_metadata): def test_single_component_inputs(dummy_data, dummy_metadata): - pipeline_config_path = f"{CONFIG_ROOT}/RTC_configs/bert-mt5-bert.yaml" - pipeline_config = open_yaml_path(pipeline_config_path) + pipeline_config = open_yaml_path(PIPELINE_PATH) dummy_recognise_output = {"outputs": "rec text"} dummy_translate_output = {"outputs": ["translate text"]} dummy_classification_output = {"outputs": "classification"}