From 27c879636dfc74810bc4dad71fca80f6defc22d7 Mon Sep 17 00:00:00 2001 From: J-Dymond Date: Fri, 29 Nov 2024 12:32:35 +0000 Subject: [PATCH] updated to ignore mypy error relating to undefined models in base class --- .../variational_pipelines/RTC_variational_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py index a0f79ff..eba0c16 100644 --- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py +++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py @@ -110,7 +110,7 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]: # for each model in pipeline for model_key, pl in self.pipeline_map.items(): # turn on dropout for this model - set_dropout(model=pl.model, dropout_flag=True) + set_dropout(model=pl.model, dropout_flag=True) # type: ignore[union-attr] torch.nn.functional.dropout = dropout_on # do n runs of the inference for run_idx in range(self.n_variational_runs): @@ -118,7 +118,7 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]: input_map[model_key] ) # turn off dropout for this model - set_dropout(model=pl.model, dropout_flag=False) + set_dropout(model=pl.model, dropout_flag=False) # type: ignore[union-attr] torch.nn.functional.dropout = dropout_off # run metric helper functions