Skip to content

Commit

Permalink
added some comments to the single component pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
J-Dymond committed Nov 26, 2024
1 parent 504009f commit 9de6d44
Showing 1 changed file with 6 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(

# clean inference
def clean_inference(self, x):
# run only the model that is defined
inp = x[self.input_key]
clean_output: dict[str, Any] = {
self.step_name: {},
Expand All @@ -130,11 +131,13 @@ def clean_inference(self, x):
return clean_output

def variational_inference(self, x):
# run only model that is defined in clean output
clean_output = self.clean_inference(x)
inp = x[self.input_key]
var_output: dict[str, Any] = {
self.step_name: {},
}
# variational stage is the same as the full pipeline model, with different input
# turn on dropout for this model
set_dropout(model=self.model, dropout_flag=True)
torch.nn.functional.dropout = dropout_on
Expand All @@ -147,6 +150,9 @@ def variational_inference(self, x):
set_dropout(model=self.model, dropout_flag=False)
torch.nn.functional.dropout = dropout_off
var_output = self.stack_variational_outputs(var_output)
# For confidence function we need to pass both outputs in all cases
# This allows the abstraction to self.confidence_func_map[self.step_name]
conf_args = {"clean_output": clean_output, "var_output": var_output}
var_output = self.confidence_func_map[self.step_name](**conf_args)
# return both as in base function method
return clean_output, var_output

0 comments on commit 9de6d44

Please sign in to comment.