Skip to content

Commit

Permalink
some refactoring changes to satisfy mypy.
Browse files Browse the repository at this point in the history
  • Loading branch information
J-Dymond committed Nov 26, 2024
1 parent 9de6d44 commit de365f3
Show file tree
Hide file tree
Showing 5 changed files with 490 additions and 480 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,14 @@
import torch
from transformers import pipeline

from arc_spice.variational_pipelines.dropout_utils import (
dropout_off,
dropout_on,
set_dropout,
)
from arc_spice.variational_pipelines.RTC_variational_pipeline import (
CustomTranslationPipeline,
RTCVariationalPipeline,
RTCPipelineBase,
)
from arc_spice.variational_pipelines.utils import dropout_off, dropout_on, set_dropout


class RTCSingleComponentPipeline(RTCVariationalPipeline):
class RTCSingleComponentPipeline(RTCPipelineBase):
"""
Single component version of the variational pipeline, which inherits methods from
the main `RTCVariationalPipeline` class, without initialising models by overwriting
Expand All @@ -31,24 +27,18 @@ def __init__(
model_pars: dict[str, dict[str, str]],
model_key: str,
data_pars: dict[str, Any],
n_variational_runs: int = 5,
translation_batch_size: int = 8,
n_variational_runs=5,
translation_batch_size=8,
) -> None:
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
super().__init__(n_variational_runs, translation_batch_size)
# define objects that are needed and nothing else
if model_key == "ocr":
self.step_name = "recognition"
self.input_key = "ocr_data"
self.ocr = pipeline(
task=model_pars["OCR"]["specific_task"],
model=model_pars["OCR"]["model"],
device=device,
device=self.device,
)
self.model = self.ocr.model

Expand All @@ -61,7 +51,7 @@ def __init__(
model=model_pars["translator"]["model"],
max_length=512,
pipeline_class=CustomTranslationPipeline,
device=device,
device=self.device,
)
self.model = self.translator.model
# need to initialise the NLI models in this case
Expand All @@ -74,7 +64,7 @@ def __init__(
task=model_pars["classifier"]["specific_task"],
model=model_pars["classifier"]["model"],
multi_label=True,
device=device,
device=self.device,
)
self.model = self.classifier.model
# topic description labels for the classifier
Expand Down Expand Up @@ -113,7 +103,7 @@ def __init__(
"classification": self.classify_topic,
}
self.confidence_func_map: dict[str, Callable] = {
"recognition": self.recognise,
"recognition": self.recognise, ### THIS NEEDS REPLACING WHEN COMPLETED
"translation": self.translation_semantic_density,
"classification": self.get_classification_confidence,
}
Expand Down
Loading

0 comments on commit de365f3

Please sign in to comment.