Skip to content

Commit

Permalink
Avoid the intermediate save step in evaluate (#1249)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Apr 16, 2024
1 parent 5f838a4 commit ca5816f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 18 deletions.
23 changes: 7 additions & 16 deletions litgpt/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,14 @@
import json
import os
from pathlib import Path
from typing import Optional
from typing import Optional, Union
import yaml
import torch

from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint
from litgpt.utils import CLI, copy_config_files


def save_safetensors(out_dir, repo_id):
from transformers import AutoModel

state_dict = torch.load(out_dir/"model.pth")
model = AutoModel.from_pretrained(
repo_id, state_dict=state_dict
)
model.save_pretrained(out_dir)


def prepare_results(results, save_filepath, print_results=True):
from lm_eval.utils import make_table

Expand All @@ -43,6 +33,7 @@ def convert_and_evaluate(
num_fewshot: Optional[int] = None,
batch_size: int = 1,
device: Optional[str] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
limit: Optional[float] = None,
seed: int = 1234,
save_filepath: Optional[str] = None,
Expand Down Expand Up @@ -102,15 +93,15 @@ def convert_and_evaluate(
if not model_path.exists() or force_conversion:
convert_lit_checkpoint(checkpoint_dir=checkpoint_dir, output_dir=out_dir)

safetensors_path = out_dir / "model.safetensors"
if not safetensors_path.exists() or force_conversion:
save_safetensors(out_dir, repo_id)
from lm_eval.models.huggingface import HFLM

state_dict = torch.load(model_path)
model = HFLM(repo_id, state_dict=state_dict, device=device, batch_size=batch_size, dtype=dtype)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

results = evaluator.simple_evaluate(
model="hf",
model_args=f"pretrained={out_dir}",
model=model,
tasks=tasks.split(","),
num_fewshot=num_fewshot,
batch_size=batch_size,
Expand Down
3 changes: 1 addition & 2 deletions tests/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import os
import shutil
import subprocess
import sys
Expand All @@ -25,7 +24,6 @@
strict=False,
match="Loading a dataset cached in a LocalFileSystem is not supported",
)
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
def test_evaluate_script(tmp_path, monkeypatch):
ours_config = Config.from_name("pythia-14m")
download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
Expand All @@ -42,6 +40,7 @@ def test_evaluate_script(tmp_path, monkeypatch):
checkpoint_dir=tmp_path,
out_dir=tmp_path / "out_dir",
device="cpu",
dtype=torch.float32,
limit=5,
tasks="mathqa"
)
Expand Down

0 comments on commit ca5816f

Please sign in to comment.