From 6a31b02081e366dac56e64ea91048914b8a6925b Mon Sep 17 00:00:00 2001 From: Anchor Yu <91590308+1azyking@users.noreply.github.com> Date: Sun, 15 Dec 2024 21:49:45 +0800 Subject: [PATCH] Update test_change_bias.py Signed-off-by: Anchor Yu <91590308+1azyking@users.noreply.github.com> --- source/tests/pt/test_change_bias.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/source/tests/pt/test_change_bias.py b/source/tests/pt/test_change_bias.py index a3cf3edbbc..58fd953656 100644 --- a/source/tests/pt/test_change_bias.py +++ b/source/tests/pt/test_change_bias.py @@ -87,6 +87,7 @@ def setUp(self) -> None: self.model_path_user_bias = Path(current_path) / ( model_name + "user_bias" + ".pt" ) + self.loss_params = self.config["loss"] def test_change_bias_with_data(self) -> None: run_dp( @@ -96,7 +97,10 @@ def test_change_bias_with_data(self) -> None: str(self.model_path_data_bias), map_location=DEVICE, weights_only=True ) model_params = state_dict["model"]["_extra_state"]["model_params"] - model_for_wrapper = get_model_for_wrapper(model_params) + model_for_wrapper = get_model_for_wrapper( + model_params, + _loss_params=self.loss_params, + ) wrapper = ModelWrapper(model_for_wrapper) wrapper.load_state_dict(state_dict["model"]) updated_bias = wrapper.model["Default"].get_out_bias() @@ -119,7 +123,10 @@ def test_change_bias_with_data_sys_file(self) -> None: str(self.model_path_data_file_bias), map_location=DEVICE, weights_only=True ) model_params = state_dict["model"]["_extra_state"]["model_params"] - model_for_wrapper = get_model_for_wrapper(model_params) + model_for_wrapper = get_model_for_wrapper( + model_params, + _loss_params=self.loss_params, + ) wrapper = ModelWrapper(model_for_wrapper) wrapper.load_state_dict(state_dict["model"]) updated_bias = wrapper.model["Default"].get_out_bias() @@ -140,7 +147,10 @@ def test_change_bias_with_user_defined(self) -> None: str(self.model_path_user_bias), map_location=DEVICE, weights_only=True ) model_params = state_dict["model"]["_extra_state"]["model_params"] - model_for_wrapper = get_model_for_wrapper(model_params) + model_for_wrapper = get_model_for_wrapper( + model_params, + _loss_params=self.loss_params, + ) wrapper = ModelWrapper(model_for_wrapper) wrapper.load_state_dict(state_dict["model"]) updated_bias = wrapper.model["Default"].get_out_bias()