Add eval_hf.py that evaluates external models
liujch1998 committed Nov 24, 2024
1 parent 4250a8d commit d07617f
Showing 3 changed files with 170 additions and 0 deletions.
118 changes: 118 additions & 0 deletions scripts/eval_hf.py
@@ -0,0 +1,118 @@
from itertools import islice
import json
import os
import sys
from tqdm import tqdm
from typing import Any, Dict
import torch
import torch.nn.functional as F
import transformers
from olmo.config import TrainConfig, EvaluatorConfig, EvaluatorType
from olmo.eval import build_evaluator
from olmo.torch_util import move_to_device
from olmo.eval.downstream import label_to_task_map_new
from olmo.exceptions import OLMoCliError


def get_labels(batch: Dict[str, Any]) -> torch.Tensor:
    # Labels are just input IDs shifted to the left (first item is ignored).
    labels, label_mask, attention_mask, instance_mask = (
        batch["input_ids"].clone(),
        batch.get("label_mask"),
        batch.get("attention_mask"),
        batch.get("instance_mask"),
    )
    if label_mask is not None:
        labels.masked_fill_(~label_mask, -100)
    if attention_mask is not None:
        labels.masked_fill_(attention_mask == 0.0, -100)
    if instance_mask is not None:
        labels.masked_fill_(~instance_mask.unsqueeze(-1), value=-100)
    return labels[..., 1:].contiguous()
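
# A minimal illustration of the shift above (hypothetical token IDs): for
# input_ids [[5, 6, 7, 8]], get_labels returns [[6, 7, 8]], which lines up
# with logits[..., :-1, :] in main() so that the logits at position i are
# scored against the token at position i + 1.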

def main(cfg: TrainConfig, model_name: str):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, token=os.environ.get("HF_TOKEN_DOWNLOAD", None))
    if tokenizer.pad_token_id is None:  # This is to prevent the NoneType error in collate_fn()
        tokenizer.pad_token_id = 0
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name, token=os.environ.get("HF_TOKEN_DOWNLOAD", None))
    model.to(device)
    model.eval()

    cfg.device_eval_batch_size = 4
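    # Build one evaluator per downstream task label, skipping the _train_,
    # _mc_, and _var variants of each task.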
    cfg.evaluators = [
        EvaluatorConfig(label=label, type=EvaluatorType.downstream)
        for label in label_to_task_map_new.keys()
        if "_train_" not in label and "_mc_" not in label and "_var" not in label
    ]

    evaluators = []
    for eval_cfg in cfg.evaluators:
        evaluators.append(build_evaluator(cfg, eval_cfg, tokenizer, device))

    eval_metrics = {}
    for evaluator in tqdm(evaluators):
        # Reset metrics.
        evaluator.reset_metrics()

        # Initialize data loader iterator.
        eval_batches = iter(evaluator.eval_loader)

        # Adjust how many batches to evaluate on.
        num_eval_batches = (
            evaluator.subset_num_batches
            if evaluator.subset_num_batches is not None
            else cfg.eval_subset_num_batches
        )
        if num_eval_batches > 0:
            num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
            eval_batches = islice(eval_batches, num_eval_batches)

        # Run model over batches.
        for eval_step, eval_batch in enumerate(eval_batches):
            batch = move_to_device(eval_batch, device)
            with torch.no_grad():
                logits = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch.get("attention_mask"),
                ).logits
            logits_for_loss = logits[..., :-1, :].contiguous()
            # shape: (batch_size * (seq_len - 1), vocab_size)
            logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1))
            # shape: (batch_size, seq_len - 1)
            labels = get_labels(batch)
            # shape: (batch_size * (seq_len - 1),)
            labels = labels.view(-1)
            ce_loss = F.cross_entropy(logits_for_loss, labels, ignore_index=-100, reduction="none")
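            # Note: with reduction="none", positions whose label is -100 incur a
            # loss of exactly 0, so they still count toward the per-sequence mean below.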
            # Reshape (batch_size * (seq_len - 1),) -> (batch_size, seq_len - 1)
            ce_loss = ce_loss.view(batch["input_ids"].shape[0], -1)
            ce_loss = ce_loss.mean(dim=-1)
            evaluator.update_metrics(batch, ce_loss, logits)

        # Get final metrics.
        metrics = evaluator.compute_metrics()
        eval_metrics.update(metrics)
        print(metrics)

        del eval_batches

    print(eval_metrics)

    save_folder = '/weka/oe-training-default/jiachengl/hc-law/eval_bpb_mc_v2'
    os.makedirs(save_folder, exist_ok=True)
    with open(f'{save_folder}/{model_name.replace("/", "_")}.json', 'w') as f:
        json.dump(eval_metrics, f)


if __name__ == "__main__":

    try:
        yaml_path, model_name = sys.argv[1], sys.argv[2]
    except IndexError:
        raise OLMoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [MODEL_NAME]")

    cfg = TrainConfig.load(yaml_path)
    main(cfg, model_name)
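
For reference, eval_hf.py takes an OLMo training config and a Hugging Face model name (see the usage string above). A minimal direct single-GPU invocation, assuming the repo is installed with its [train] extras, would look like:

python scripts/eval_hf.py configs/peteish1-weka.yaml allenai/OLMo-1B-hf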
36 changes: 36 additions & 0 deletions scripts/eval_hf.sh
@@ -0,0 +1,36 @@
#!/usr/bin/env bash

set -ex

MODEL_NAME=$1

gantry run \
  --allow-dirty \
  --workspace ai2/OLMo-tiny \
  --task-name eval-bpb-mc \
  --description "Evaluate bpb and mc for ${MODEL_NAME}" \
  --priority urgent \
  --preemptible \
  --beaker-image petew/olmo-torch23-gantry \
  --cluster ai2/jupiter-cirrascale-2 \
  --gpus 1 \
  --budget ai2/oe-training \
  --no-nfs \
  --weka oe-training-default:/weka/oe-training-default \
  --no-python \
  --env LOG_FILTER_TYPE=local_rank0_only \
  --env OMP_NUM_THREADS=8 \
  --env OLMO_TASK=model \
  --env-secret HF_TOKEN_DOWNLOAD=JIACHENGL_HF_TOKEN \
  --shared-memory 10GiB \
  --yes \
  --timeout=0 \
  -- /bin/bash -c "\
    set -exuo pipefail; \
    IFS=$'\n\t'; \
    conda shell.bash activate base; \
    pip install '.[train]'; \
    pip install -U transformers==4.46.2; \
    pip install -U sentencepiece; \
    torchrun --nproc-per-node 1 scripts/eval_hf.py configs/peteish1-weka.yaml ${MODEL_NAME}; \
  "
16 changes: 16 additions & 0 deletions scripts/eval_hf_launch.sh
@@ -0,0 +1,16 @@
bash scripts/eval_hf.sh allenai/OLMo-7B-0724-hf
bash scripts/eval_hf.sh allenai/OLMo-7B-hf
bash scripts/eval_hf.sh allenai/OLMo-1B-hf
bash scripts/eval_hf.sh meta-llama/Llama-3.2-3B
bash scripts/eval_hf.sh meta-llama/Llama-3.2-1B
bash scripts/eval_hf.sh meta-llama/Llama-3.1-8B
bash scripts/eval_hf.sh meta-llama/Meta-Llama-3-8B
bash scripts/eval_hf.sh Qwen/Qwen2.5-14B
bash scripts/eval_hf.sh Qwen/Qwen2.5-7B
bash scripts/eval_hf.sh Qwen/Qwen2.5-3B
bash scripts/eval_hf.sh Qwen/Qwen2.5-1.5B
bash scripts/eval_hf.sh Qwen/Qwen2-7B
bash scripts/eval_hf.sh Qwen/Qwen2-1.5B
bash scripts/eval_hf.sh mistralai/Mistral-Nemo-Base-2407
bash scripts/eval_hf.sh mistralai/Mistral-7B-v0.3
bash scripts/eval_hf.sh mistralai/Mistral-7B-v0.1
