Skip to content

Commit

Permalink
Update test_change_bias.py
Browse files Browse the repository at this point in the history
Signed-off-by: Anchor Yu <[email protected]>
  • Loading branch information
1azyking authored Dec 15, 2024
1 parent 7dead9c commit 6a31b02
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions source/tests/pt/test_change_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def setUp(self) -> None:
self.model_path_user_bias = Path(current_path) / (
model_name + "user_bias" + ".pt"
)
self.loss_params = self.config["loss"]

def test_change_bias_with_data(self) -> None:
run_dp(
Expand All @@ -96,7 +97,10 @@ def test_change_bias_with_data(self) -> None:
str(self.model_path_data_bias), map_location=DEVICE, weights_only=True
)
model_params = state_dict["model"]["_extra_state"]["model_params"]
model_for_wrapper = get_model_for_wrapper(model_params)
model_for_wrapper = get_model_for_wrapper(
model_params,
_loss_params=self.loss_params,
)
wrapper = ModelWrapper(model_for_wrapper)
wrapper.load_state_dict(state_dict["model"])
updated_bias = wrapper.model["Default"].get_out_bias()
Expand All @@ -119,7 +123,10 @@ def test_change_bias_with_data_sys_file(self) -> None:
str(self.model_path_data_file_bias), map_location=DEVICE, weights_only=True
)
model_params = state_dict["model"]["_extra_state"]["model_params"]
model_for_wrapper = get_model_for_wrapper(model_params)
model_for_wrapper = get_model_for_wrapper(
model_params,
_loss_params=self.loss_params,
)
wrapper = ModelWrapper(model_for_wrapper)
wrapper.load_state_dict(state_dict["model"])
updated_bias = wrapper.model["Default"].get_out_bias()
Expand All @@ -140,7 +147,10 @@ def test_change_bias_with_user_defined(self) -> None:
str(self.model_path_user_bias), map_location=DEVICE, weights_only=True
)
model_params = state_dict["model"]["_extra_state"]["model_params"]
model_for_wrapper = get_model_for_wrapper(model_params)
model_for_wrapper = get_model_for_wrapper(
model_params,
_loss_params=self.loss_params,
)
wrapper = ModelWrapper(model_for_wrapper)
wrapper.load_state_dict(state_dict["model"])
updated_bias = wrapper.model["Default"].get_out_bias()
Expand Down

0 comments on commit 6a31b02

Please sign in to comment.