diff --git a/source/tests/pt/test_change_bias.py b/source/tests/pt/test_change_bias.py index a3cf3edbbc..58fd953656 100644 --- a/source/tests/pt/test_change_bias.py +++ b/source/tests/pt/test_change_bias.py @@ -87,6 +87,7 @@ def setUp(self) -> None: self.model_path_user_bias = Path(current_path) / ( model_name + "user_bias" + ".pt" ) + self.loss_params = self.config["loss"] def test_change_bias_with_data(self) -> None: run_dp( @@ -96,7 +97,10 @@ def test_change_bias_with_data(self) -> None: str(self.model_path_data_bias), map_location=DEVICE, weights_only=True ) model_params = state_dict["model"]["_extra_state"]["model_params"] - model_for_wrapper = get_model_for_wrapper(model_params) + model_for_wrapper = get_model_for_wrapper( + model_params, + _loss_params=self.loss_params, + ) wrapper = ModelWrapper(model_for_wrapper) wrapper.load_state_dict(state_dict["model"]) updated_bias = wrapper.model["Default"].get_out_bias() @@ -119,7 +123,10 @@ def test_change_bias_with_data_sys_file(self) -> None: str(self.model_path_data_file_bias), map_location=DEVICE, weights_only=True ) model_params = state_dict["model"]["_extra_state"]["model_params"] - model_for_wrapper = get_model_for_wrapper(model_params) + model_for_wrapper = get_model_for_wrapper( + model_params, + _loss_params=self.loss_params, + ) wrapper = ModelWrapper(model_for_wrapper) wrapper.load_state_dict(state_dict["model"]) updated_bias = wrapper.model["Default"].get_out_bias() @@ -140,7 +147,10 @@ def test_change_bias_with_user_defined(self) -> None: str(self.model_path_user_bias), map_location=DEVICE, weights_only=True ) model_params = state_dict["model"]["_extra_state"]["model_params"] - model_for_wrapper = get_model_for_wrapper(model_params) + model_for_wrapper = get_model_for_wrapper( + model_params, + _loss_params=self.loss_params, + ) wrapper = ModelWrapper(model_for_wrapper) wrapper.load_state_dict(state_dict["model"]) updated_bias = wrapper.model["Default"].get_out_bias()