diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py index a0f79ff..eba0c16 100644 --- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py +++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py @@ -110,7 +110,7 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]: # for each model in pipeline for model_key, pl in self.pipeline_map.items(): # turn on dropout for this model - set_dropout(model=pl.model, dropout_flag=True) + set_dropout(model=pl.model, dropout_flag=True) # type: ignore[union-attr] torch.nn.functional.dropout = dropout_on # do n runs of the inference for run_idx in range(self.n_variational_runs): @@ -118,7 +118,7 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]: input_map[model_key] ) # turn off dropout for this model - set_dropout(model=pl.model, dropout_flag=False) + set_dropout(model=pl.model, dropout_flag=False) # type: ignore[union-attr] torch.nn.functional.dropout = dropout_off # run metric helper functions