From 9de6d44f80ba5ec4e60b600b9a7e3821ed5f9354 Mon Sep 17 00:00:00 2001 From: J-Dymond Date: Tue, 26 Nov 2024 15:21:41 +0000 Subject: [PATCH] added some comments to the single component pipeline --- .../variational_pipelines/RTC_single_component_pipeline.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py index bac7898..1147202 100644 --- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py +++ b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py @@ -122,6 +122,7 @@ def __init__( # clean inference def clean_inference(self, x): + # run only the model that is defined inp = x[self.input_key] clean_output: dict[str, Any] = { self.step_name: {}, @@ -130,11 +131,13 @@ def clean_inference(self, x): return clean_output def variational_inference(self, x): + # run only model that is defined in clean output clean_output = self.clean_inference(x) inp = x[self.input_key] var_output: dict[str, Any] = { self.step_name: {}, } + # variational stage is the same as the full pipeline model, with different input # turn on dropout for this model set_dropout(model=self.model, dropout_flag=True) torch.nn.functional.dropout = dropout_on @@ -147,6 +150,9 @@ def variational_inference(self, x): set_dropout(model=self.model, dropout_flag=False) torch.nn.functional.dropout = dropout_off var_output = self.stack_variational_outputs(var_output) + # For confidence function we need to pass both outputs in all cases + # This allows the abstraction to self.confidence_func_map[self.step_name] conf_args = {"clean_output": clean_output, "var_output": var_output} var_output = self.confidence_func_map[self.step_name](**conf_args) + # return both as in base function method return clean_output, var_output