Merge branch 'main' into carmocca/revert-stuff
carmocca authored Apr 4, 2024
2 parents fffcce7 + c634a28 commit 597bf52
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tests/test_merge_lora.py
@@ -18,7 +18,10 @@
 
 
 @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
-def test_merge_lora(tmp_path, fake_checkpoint_dir):
+@pytest.mark.parametrize(
+    ("pretrained_dtype", "lora_dtype"), [(None, None), (torch.float16, torch.float32), (torch.float16, torch.bfloat16)]
+)
+def test_merge_lora(tmp_path, fake_checkpoint_dir, pretrained_dtype, lora_dtype):
     pretrained_checkpoint_dir = tmp_path / "pretrained"
     lora_checkpoint_dir = tmp_path / "lora"
     shutil.copytree(fake_checkpoint_dir, pretrained_checkpoint_dir)
@@ -30,14 +33,14 @@ def test_merge_lora(tmp_path, fake_checkpoint_dir):
     config = dict(block_size=128, padded_vocab_size=256, n_layer=3, n_head=8, n_embd=16)
     with open(pretrained_checkpoint_dir / "model_config.yaml", "w") as fp:
         yaml.dump(config, fp)
-    base_model = GPT.from_name("pythia-14m", **config)
+    base_model = GPT.from_name("pythia-14m", **config).to(dtype=pretrained_dtype)
     state_dict = base_model.state_dict()
     assert len(state_dict) == 40
     torch.save(state_dict, pretrained_checkpoint_dir / "lit_model.pth")
 
     # Create a fake LoRA checkpoint
     lora_kwargs = dict(lora_r=8, lora_alpha=16, lora_dropout=0.05, lora_query=True, lora_value=True)
-    lora_model = LoRAGPT.from_name("pythia-14m", **config, **lora_kwargs)
+    lora_model = LoRAGPT.from_name("pythia-14m", **config, **lora_kwargs).to(dtype=lora_dtype)
     state_dict = {k: v for k, v in lora_model.state_dict().items() if lora_filter(k, v)}
     assert len(state_dict) == 6
     torch.save(state_dict, lora_checkpoint_dir / "lit_model.pth.lora")
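Context for the parametrization (not part of the commit): the three cases cover a no-conversion baseline and two mixed-precision combinations of pretrained and LoRA weights. A minimal sketch, using stock PyTorch rather than code from this repository, of why (None, None) leaves the fake checkpoints in their default float32 dtype while the other cases cast the weights:

import torch
import torch.nn as nn

# Minimal sketch, not part of the commit: Module.to(dtype=None) is a no-op,
# so the (None, None) case keeps the parameters in the default float32,
# while explicit dtypes cast them in place.
layer = nn.Linear(4, 4)
assert layer.to(dtype=None).weight.dtype == torch.float32
assert layer.to(dtype=torch.float16).weight.dtype == torch.float16
assert layer.to(dtype=torch.bfloat16).weight.dtype == torch.bfloat16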
