Skip to content

Commit

Permalink
addressed comments from pull request
Browse files Browse the repository at this point in the history
  • Loading branch information
J-Dymond committed Nov 29, 2024
1 parent 4906b26 commit 2b3db5d
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class RTCSingleComponentPipeline(RTCVariationalPipelineBase):

def __init__(
self,
model,
step_name,
input_key,
forward_function,
Expand All @@ -48,7 +47,6 @@ def __init__(
"scores",
],
}
self.model = model
self.step_name = step_name
self.input_key = input_key
self.forward_function = forward_function
Expand All @@ -73,13 +71,14 @@ def variational_inference(self, x):
}
# variational stage is the same as the full pipeline model, with different input
# turn on dropout for this model
set_dropout(model=self.model, dropout_flag=True)
# model will be defined in the subclass
set_dropout(model=self.model, dropout_flag=True) # type: ignore[attr-defined]
torch.nn.functional.dropout = dropout_on
# do n runs of the inference
for run_idx in range(self.n_variational_runs):
var_output[self.step_name][run_idx] = self.forward_function(inp)
# turn off dropout for this model
set_dropout(model=self.model, dropout_flag=False)
set_dropout(model=self.model, dropout_flag=False) # type: ignore[attr-defined]
torch.nn.functional.dropout = dropout_off
var_output = self.stack_variational_outputs(var_output)
# For confidence function we need to pass both outputs in all cases
Expand All @@ -104,15 +103,16 @@ def __init__(
device=self.device,
**kwargs,
)
self.model = self.ocr.model
super().__init__(
model=self.ocr.model,
step_name="recognition",
input_key="ocr_data",
forward_function=self.recognise,
confidence_function=self.recognise, # THIS WILL NEED UPDATING : #issue 14
n_variational_runs=n_variational_runs,
**kwargs,
)
self._init_pipeline_map()


class TranslationVariationalPipeline(RTCSingleComponentPipeline):
Expand All @@ -124,24 +124,25 @@ def __init__(
**kwargs,
):
self.set_device()
self.translator = pipeline(
task=model_pars["translator"]["specific_task"],
model=model_pars["translator"]["model"],
max_length=512,
pipeline_class=CustomTranslationPipeline,
device=self.device,
)
# need to initialise the NLI models in this case
self._init_semantic_density()
super().__init__(
model=self.translator.model,
step_name="translation",
input_key="source_text",
forward_function=self.translate,
confidence_function=self.translation_semantic_density,
n_variational_runs=n_variational_runs,
translation_batch_size=translation_batch_size,
)
self.translator = pipeline(
task=model_pars["translator"]["specific_task"],
model=model_pars["translator"]["model"],
max_length=512,
pipeline_class=CustomTranslationPipeline,
device=self.device,
)
self.model = self.translator.model
self._init_pipeline_map()


class ClassificationVariationalPipeline(RTCSingleComponentPipeline):
Expand All @@ -160,23 +161,24 @@ def __init__(
**kwargs,
):
self.set_device()
self.classifier = pipeline(
task=model_pars["classifier"]["specific_task"],
model=model_pars["classifier"]["model"],
multi_label=True,
device=self.device,
)
super().__init__(
model=self.classifier.model,
step_name="classification",
input_key="target_text",
forward_function=self.classify_topic,
confidence_function=self.get_classification_confidence,
n_variational_runs=n_variational_runs,
**kwargs,
)
self.classifier = pipeline(
task=model_pars["classifier"]["specific_task"],
model=model_pars["classifier"]["model"],
multi_label=True,
device=self.device,
)
self.model = self.classifier.model
# topic description labels for the classifier
self.topic_labels = [
class_names_dict["en"]
for class_names_dict in data_pars["class_descriptors"]
]
self._init_pipeline_map()
68 changes: 29 additions & 39 deletions src/arc_spice/variational_pipelines/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from abc import ABC, abstractmethod
from functools import partial
from typing import Any

Expand All @@ -9,34 +10,6 @@
logger = logging.Logger("RTC_variational_pipeline")


class DummyPipeline:
"""
For initialising a base pipeline which needs to be overwritten by a subclass
"""

def __init__(self, model_name):
"""
Gives the dummy pipeline the required attributes for the method definitions
Args:
model_name: name of the pipeline that is being given a dummy
"""
self.model = model_name

def __call__(self, *args, **kwargs):
"""
Needs to be defined in subclass
Raises:
NotImplementedError: when called to prevent base class being used
"""
error_msg = (
f"{self.model} cannot be called directly and needs to be"
" defined within a subclass."
)
raise NotImplementedError(error_msg)


def set_dropout(model: torch.nn.Module, dropout_flag: bool) -> None:
"""
Turn on or turn off dropout layers of a model.
Expand Down Expand Up @@ -110,7 +83,20 @@ def dropout_w_training_override(
dropout_off = partial(dropout_w_training_override, training_override=False)


class RTCVariationalPipelineBase:
class RTCVariationalPipelineBase(ABC):
"""
Base class for the RTC variational pipelines, cannot be instantiated directly, needs
to have `clean_inference` and `variational_inference` defined by subclass.
"""

@abstractmethod
def clean_inference(self, x):
pass

@abstractmethod
def variational_inference(self, x):
pass

def __init__(self, n_variational_runs=5, translation_batch_size=8):
# device for inference
self.set_device()
Expand Down Expand Up @@ -140,17 +126,13 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8):
self.n_variational_runs = n_variational_runs
self.translation_batch_size = translation_batch_size

if not hasattr(self, "ocr"):
self.ocr = DummyPipeline("ocr")
if not hasattr(self, "translator"):
self.translator = DummyPipeline("translator")
if not hasattr(self, "classifier"):
self.classifier = DummyPipeline("classifier")
self.ocr = None
self.translator = None
self.classifier = None

# map pipeline names to their pipeline counterparts

self.topic_labels = None # This should be defined in subclass if needed
self._init_pipeline_map()

def _init_pipeline_map(self):
"""
Expand Down Expand Up @@ -212,7 +194,13 @@ def check_dropout(self):
"""
logger.debug("\n\n------------------ Testing Dropout --------------------")
for model_key, pl in self.pipeline_map.items():
if not isinstance(pl, Pipeline):
# only test models that exist
if pl is None:
pipeline_none_msg_key = (
f"pipeline under model key, `{model_key}`, is currently"
" set to None. Was this intended?"
)
logger.debug(pipeline_none_msg_key)
continue
# turn on dropout for this model
set_dropout(model=pl.model, dropout_flag=True)
Expand Down Expand Up @@ -240,6 +228,8 @@ def recognise(self, inp) -> dict[str, str]:
dictionary of outputs
"""
# Until the OCR data is available
# This will need the below comment:
# type: ignore[misc]
# TODO https://github.com/alan-turing-institute/ARC-SPICE/issues/14
return {"outputs": inp["source_text"]}

Expand All @@ -256,7 +246,7 @@ def translate(self, text: str) -> dict[str, torch.Tensor | str]:
# split text into sentences
text_splits = self.split_translate_inputs(text, ".")
# perform translation on sentences
translator_outputs = self.translator(
translator_outputs = self.translator( # type: ignore[misc]
text_splits,
output_logits=True,
return_dict_in_generate=True,
Expand Down Expand Up @@ -297,7 +287,7 @@ def classify_topic(self, text: str) -> dict[str, str]:
Returns:
Dictionary of classification outputs, namely the output scores.
"""
forward = self.classifier(text, self.topic_labels)
forward = self.classifier(text, self.topic_labels) # type: ignore[misc]
return {"scores": forward["scores"]}

def stack_translator_sentence_metrics(
Expand Down
6 changes: 2 additions & 4 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from arc_spice.variational_pipelines.RTC_variational_pipeline import (
RTCVariationalPipeline,
)
from arc_spice.variational_pipelines.utils import DummyPipeline

CONFIG_ROOT = f"{os.path.dirname(os.path.abspath(__file__))}/../config/"

Expand Down Expand Up @@ -46,7 +45,7 @@ def test_pipeline_inputs(dummy_data, dummy_metadata):

with patch( # noqa: SIM117
"arc_spice.variational_pipelines.RTC_variational_pipeline.pipeline",
return_value=DummyPipeline("dummy_model"),
return_value=None,
):
with patch(
(
Expand Down Expand Up @@ -83,8 +82,7 @@ def test_single_component_inputs(dummy_data, dummy_metadata):
dummy_classification_output = {"outputs": "classification"}

with patch( # noqa: SIM117
"arc_spice.variational_pipelines.RTC_single_component_pipeline.pipeline",
return_value=DummyPipeline("dummy_model"),
"arc_spice.variational_pipelines.RTC_single_component_pipeline.pipeline"
):
with patch(
(
Expand Down

0 comments on commit 2b3db5d

Please sign in to comment.