diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py index 3774d87..9a13646 100644 --- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py +++ b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py @@ -23,7 +23,6 @@ class RTCSingleComponentPipeline(RTCVariationalPipelineBase): def __init__( self, - model, step_name, input_key, forward_function, @@ -48,7 +47,6 @@ def __init__( "scores", ], } - self.model = model self.step_name = step_name self.input_key = input_key self.forward_function = forward_function @@ -73,13 +71,14 @@ def variational_inference(self, x): } # variational stage is the same as the full pipeline model, with different input # turn on dropout for this model - set_dropout(model=self.model, dropout_flag=True) + # model will be defined in the subclass + set_dropout(model=self.model, dropout_flag=True) # type: ignore[attr-defined] torch.nn.functional.dropout = dropout_on # do n runs of the inference for run_idx in range(self.n_variational_runs): var_output[self.step_name][run_idx] = self.forward_function(inp) # turn off dropout for this model - set_dropout(model=self.model, dropout_flag=False) + set_dropout(model=self.model, dropout_flag=False) # type: ignore[attr-defined] torch.nn.functional.dropout = dropout_off var_output = self.stack_variational_outputs(var_output) # For confidence function we need to pass both outputs in all cases @@ -104,8 +103,8 @@ def __init__( device=self.device, **kwargs, ) + self.model = self.ocr.model super().__init__( - model=self.ocr.model, step_name="recognition", input_key="ocr_data", forward_function=self.recognise, @@ -113,6 +112,7 @@ def __init__( n_variational_runs=n_variational_runs, **kwargs, ) + self._init_pipeline_map() class TranslationVariationalPipeline(RTCSingleComponentPipeline): @@ -124,17 +124,9 @@ def __init__( **kwargs, ): self.set_device() - self.translator = pipeline( - task=model_pars["translator"]["specific_task"], - model=model_pars["translator"]["model"], - max_length=512, - pipeline_class=CustomTranslationPipeline, - device=self.device, - ) # need to initialise the NLI models in this case self._init_semantic_density() super().__init__( - model=self.translator.model, step_name="translation", input_key="source_text", forward_function=self.translate, @@ -142,6 +134,15 @@ def __init__( n_variational_runs=n_variational_runs, translation_batch_size=translation_batch_size, ) + self.translator = pipeline( + task=model_pars["translator"]["specific_task"], + model=model_pars["translator"]["model"], + max_length=512, + pipeline_class=CustomTranslationPipeline, + device=self.device, + ) + self.model = self.translator.model + self._init_pipeline_map() class ClassificationVariationalPipeline(RTCSingleComponentPipeline): @@ -160,14 +161,7 @@ def __init__( **kwargs, ): self.set_device() - self.classifier = pipeline( - task=model_pars["classifier"]["specific_task"], - model=model_pars["classifier"]["model"], - multi_label=True, - device=self.device, - ) super().__init__( - model=self.classifier.model, step_name="classification", input_key="target_text", forward_function=self.classify_topic, @@ -175,8 +169,16 @@ def __init__( n_variational_runs=n_variational_runs, **kwargs, ) + self.classifier = pipeline( + task=model_pars["classifier"]["specific_task"], + model=model_pars["classifier"]["model"], + multi_label=True, + device=self.device, + ) + self.model = self.classifier.model # topic description labels for the classifier self.topic_labels = [ class_names_dict["en"] for class_names_dict in data_pars["class_descriptors"] ] + self._init_pipeline_map() diff --git a/src/arc_spice/variational_pipelines/utils.py b/src/arc_spice/variational_pipelines/utils.py index 0e7f9aa..c9e2692 100644 --- a/src/arc_spice/variational_pipelines/utils.py +++ b/src/arc_spice/variational_pipelines/utils.py @@ -1,4 +1,5 @@ import logging +from abc import ABC, abstractmethod from functools import partial from typing import Any @@ -9,34 +10,6 @@ logger = logging.Logger("RTC_variational_pipeline") -class DummyPipeline: - """ - For initialising a base pipeline which needs to be overwritten by a subclass - """ - - def __init__(self, model_name): - """ - Gives the dummy pipeline the required attributes for the method definitions - - Args: - model_name: name of the pipeline that is being given a dummy - """ - self.model = model_name - - def __call__(self, *args, **kwargs): - """ - Needs to be defined in subclass - - Raises: - NotImplementedError: when called to prevent base class being used - """ - error_msg = ( - f"{self.model} cannot be called directly and needs to be" - " defined within a subclass." - ) - raise NotImplementedError(error_msg) - - def set_dropout(model: torch.nn.Module, dropout_flag: bool) -> None: """ Turn on or turn off dropout layers of a model. @@ -110,7 +83,20 @@ def dropout_w_training_override( dropout_off = partial(dropout_w_training_override, training_override=False) -class RTCVariationalPipelineBase: +class RTCVariationalPipelineBase(ABC): + """ + Base class for the RTC variational pipelines, cannot be instantiated directly, needs + to have `clean_inference` and `variational_inference` defined by subclass. + """ + + @abstractmethod + def clean_inference(self, x): + pass + + @abstractmethod + def variational_inference(self, x): + pass + def __init__(self, n_variational_runs=5, translation_batch_size=8): # device for inference self.set_device() @@ -140,17 +126,13 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8): self.n_variational_runs = n_variational_runs self.translation_batch_size = translation_batch_size - if not hasattr(self, "ocr"): - self.ocr = DummyPipeline("ocr") - if not hasattr(self, "translator"): - self.translator = DummyPipeline("translator") - if not hasattr(self, "classifier"): - self.classifier = DummyPipeline("classifier") + self.ocr = None + self.translator = None + self.classifier = None # map pipeline names to their pipeline counterparts self.topic_labels = None # This should be defined in subclass if needed - self._init_pipeline_map() def _init_pipeline_map(self): """ @@ -212,7 +194,13 @@ def check_dropout(self): """ logger.debug("\n\n------------------ Testing Dropout --------------------") for model_key, pl in self.pipeline_map.items(): - if not isinstance(pl, Pipeline): + # only test models that exist + if pl is None: + pipeline_none_msg_key = ( + f"pipeline under model key, `{model_key}`, is currently" + " set to None. Was this intended?" + ) + logger.debug(pipeline_none_msg_key) continue # turn on dropout for this model set_dropout(model=pl.model, dropout_flag=True) @@ -240,6 +228,8 @@ def recognise(self, inp) -> dict[str, str]: dictionary of outputs """ # Until the OCR data is available + # This will need the below comment: + # type: ignore[misc] # TODO https://github.com/alan-turing-institute/ARC-SPICE/issues/14 return {"outputs": inp["source_text"]} @@ -256,7 +246,7 @@ def translate(self, text: str) -> dict[str, torch.Tensor | str]: # split text into sentences text_splits = self.split_translate_inputs(text, ".") # perform translation on sentences - translator_outputs = self.translator( + translator_outputs = self.translator( # type: ignore[misc] text_splits, output_logits=True, return_dict_in_generate=True, @@ -297,7 +287,7 @@ def classify_topic(self, text: str) -> dict[str, str]: Returns: Dictionary of classification outputs, namely the output scores. """ - forward = self.classifier(text, self.topic_labels) + forward = self.classifier(text, self.topic_labels) # type: ignore[misc] return {"scores": forward["scores"]} def stack_translator_sentence_metrics( diff --git a/tests/test_inference.py b/tests/test_inference.py index 047a4d2..a7e6b08 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -12,7 +12,6 @@ from arc_spice.variational_pipelines.RTC_variational_pipeline import ( RTCVariationalPipeline, ) -from arc_spice.variational_pipelines.utils import DummyPipeline CONFIG_ROOT = f"{os.path.dirname(os.path.abspath(__file__))}/../config/" @@ -46,7 +45,7 @@ def test_pipeline_inputs(dummy_data, dummy_metadata): with patch( # noqa: SIM117 "arc_spice.variational_pipelines.RTC_variational_pipeline.pipeline", - return_value=DummyPipeline("dummy_model"), + return_value=None, ): with patch( ( @@ -83,8 +82,7 @@ def test_single_component_inputs(dummy_data, dummy_metadata): dummy_classification_output = {"outputs": "classification"} with patch( # noqa: SIM117 - "arc_spice.variational_pipelines.RTC_single_component_pipeline.pipeline", - return_value=DummyPipeline("dummy_model"), + "arc_spice.variational_pipelines.RTC_single_component_pipeline.pipeline" ): with patch( (