Skip to content

Commit

Permalink
In HF there a single bias to rule them all
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrei-Aksionov committed Mar 21, 2024
1 parent 3761ab9 commit 00ee6b6
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def test_against_hf_config_bias_parity(model_name):
ours_config = Config.from_name(model_name)
try:
theirs_config = AutoConfig.from_pretrained(
"/".join(ours_config.hf_config.values()), token=os.getenv("HF_TOKEN")
"/".join(ours_config.hf_config.values()),
token=os.getenv("HF_TOKEN"),
)
except OSError as e_info:
if "You are trying to access a gated repo." in str(e_info):
Expand All @@ -105,16 +106,9 @@ def test_against_hf_config_bias_parity(model_name):

b_names = [name for name in theirs_config.__dict__ if "bias" in name]
assert len(b_names) <= 1
if len(b_names):
if len(b_names) == 1:
assert b_names[0] in ("attention_bias", "use_qkv_bias", "bias")

if hasattr(theirs_config, "attention_bias"):
assert theirs_config.attention_bias == ours_config.attn_qkv_bias
assert theirs_config.attention_bias == ours_config.attn_proj_bias
if hasattr(theirs_config, "use_qkv_bias"):
assert theirs_config.use_qkv_bias == ours_config.attn_qkv_bias
assert theirs_config.use_qkv_bias == ours_config.attn_proj_bias
if hasattr(theirs_config, "bias"):
assert theirs_config.bias == ours_config.attn_qkv_bias
assert theirs_config.bias == ours_config.attn_proj_bias
assert theirs_config.bias == ours_config.mlp_bias
hf_bias = getattr(theirs_config, b_names[0])
assert hf_bias == ours_config.attn_qkv_bias
assert hf_bias == ours_config.attn_proj_bias
assert hf_bias == ours_config.mlp_bias

0 comments on commit 00ee6b6

Please sign in to comment.