diff --git a/tests/test_config.py b/tests/test_config.py index b2a6a87dab..db0ec3be69 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -95,7 +95,8 @@ def test_against_hf_config_bias_parity(model_name): ours_config = Config.from_name(model_name) try: theirs_config = AutoConfig.from_pretrained( - "/".join(ours_config.hf_config.values()), token=os.getenv("HF_TOKEN") + "/".join(ours_config.hf_config.values()), + token=os.getenv("HF_TOKEN"), ) except OSError as e_info: if "You are trying to access a gated repo." in str(e_info): @@ -105,16 +106,9 @@ def test_against_hf_config_bias_parity(model_name): b_names = [name for name in theirs_config.__dict__ if "bias" in name] assert len(b_names) <= 1 - if len(b_names): + if len(b_names) == 1: assert b_names[0] in ("attention_bias", "use_qkv_bias", "bias") - - if hasattr(theirs_config, "attention_bias"): - assert theirs_config.attention_bias == ours_config.attn_qkv_bias - assert theirs_config.attention_bias == ours_config.attn_proj_bias - if hasattr(theirs_config, "use_qkv_bias"): - assert theirs_config.use_qkv_bias == ours_config.attn_qkv_bias - assert theirs_config.use_qkv_bias == ours_config.attn_proj_bias - if hasattr(theirs_config, "bias"): - assert theirs_config.bias == ours_config.attn_qkv_bias - assert theirs_config.bias == ours_config.attn_proj_bias - assert theirs_config.bias == ours_config.mlp_bias + hf_bias = getattr(theirs_config, b_names[0]) + assert hf_bias == ours_config.attn_qkv_bias + assert hf_bias == ours_config.attn_proj_bias + assert hf_bias == ours_config.mlp_bias