Skip to content

Commit

Permalink
updated to ignore mypy error relating to undefined models in base class
Browse files Browse the repository at this point in the history
  • Loading branch information
J-Dymond committed Nov 29, 2024
1 parent 2b3db5d commit 27c8796
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
# for each model in pipeline
for model_key, pl in self.pipeline_map.items():
# turn on dropout for this model
set_dropout(model=pl.model, dropout_flag=True)
set_dropout(model=pl.model, dropout_flag=True) # type: ignore[union-attr]
torch.nn.functional.dropout = dropout_on
# do n runs of the inference
for run_idx in range(self.n_variational_runs):
var_output[model_key][run_idx] = self.func_map[model_key](
input_map[model_key]
)
# turn off dropout for this model
set_dropout(model=pl.model, dropout_flag=False)
set_dropout(model=pl.model, dropout_flag=False) # type: ignore[union-attr]
torch.nn.functional.dropout = dropout_off

# run metric helper functions
Expand Down

0 comments on commit 27c8796

Please sign in to comment.