diff --git a/tests/test_inference.py b/tests/test_inference.py index 969db39..5e63d2a 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -47,17 +47,20 @@ def test_pipeline_inputs(dummy_data, dummy_metadata): "arc_spice.variational_pipelines.RTC_variational_pipeline.pipeline", ): with patch( - ( - "arc_spice.variational_pipelines.RTC_variational_pipeline." - "RTCVariationalPipeline._init_semantic_density" - ), - return_value=None, + "arc_spice.variational_pipelines.utils.pipeline", ): - pipeline = RTCVariationalPipeline( - model_pars=pipeline_config, - data_pars=dummy_metadata, - translation_batch_size=1, - ) + with patch( + ( + "arc_spice.variational_pipelines.RTC_variational_pipeline." + "RTCVariationalPipeline._init_semantic_density" + ), + return_value=None, + ): + pipeline = RTCVariationalPipeline( + model_pars=pipeline_config, + data_pars=dummy_metadata, + translation_batch_size=1, + ) dummy_recognise_output = {"outputs": "rec text"} dummy_translate_output = {"outputs": ["translate text"]} @@ -86,23 +89,26 @@ def test_single_component_inputs(dummy_data, dummy_metadata): "arc_spice.variational_pipelines.RTC_single_component_pipeline.pipeline" ): with patch( - ( - "arc_spice.variational_pipelines.RTC_single_component_pipeline." - "RTCSingleComponentPipeline._init_semantic_density" - ), - return_value=None, + "arc_spice.variational_pipelines.utils.pipeline", ): - recognise_pipeline = RecognitionVariationalPipeline( - model_pars=pipeline_config, - ) - translate_pipeline = TranslationVariationalPipeline( - model_pars=pipeline_config, - translation_batch_size=1, - ) - classify_pipeline = ClassificationVariationalPipeline( - model_pars=pipeline_config, - data_pars=dummy_metadata, - ) + with patch( + ( + "arc_spice.variational_pipelines.RTC_single_component_pipeline." + "RTCSingleComponentPipeline._init_semantic_density" + ), + return_value=None, + ): + recognise_pipeline = RecognitionVariationalPipeline( + model_pars=pipeline_config, + ) + translate_pipeline = TranslationVariationalPipeline( + model_pars=pipeline_config, + translation_batch_size=1, + ) + classify_pipeline = ClassificationVariationalPipeline( + model_pars=pipeline_config, + data_pars=dummy_metadata, + ) recognise_pipeline.forward_function = MagicMock(return_value=dummy_recognise_output) translate_pipeline.forward_function = MagicMock(return_value=dummy_translate_output)