Add eval_hf.py that evaluates external models
Commit d07617f (1 parent: 4250a8d), showing 3 changed files with 170 additions and 0 deletions.
scripts/eval_hf.py (new file, 118 lines):
from itertools import islice
import json
import os
import sys
from tqdm import tqdm
from typing import Any, Dict

import torch
import torch.nn.functional as F
import transformers

from olmo.config import TrainConfig, EvaluatorConfig, EvaluatorType
from olmo.eval import build_evaluator
from olmo.torch_util import move_to_device
from olmo.eval.downstream import label_to_task_map_new
from olmo.exceptions import OLMoCliError


def get_labels(batch: Dict[str, Any]) -> torch.Tensor:
    # Labels are just input IDs shifted to the left (first item is ignored).
    labels, label_mask, attention_mask, instance_mask = (
        batch["input_ids"].clone(),
        batch.get("label_mask"),
        batch.get("attention_mask"),
        batch.get("instance_mask"),
    )
    if label_mask is not None:
        labels.masked_fill_(~label_mask, -100)
    if attention_mask is not None:
        labels.masked_fill_(attention_mask == 0.0, -100)
    if instance_mask is not None:
        labels.masked_fill_(~instance_mask.unsqueeze(-1), value=-100)
    return labels[..., 1:].contiguous()


def main(cfg: TrainConfig, model_name: str):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, token=os.environ.get("HF_TOKEN_DOWNLOAD", None))
    if tokenizer.pad_token_id is None:  # This is to prevent the NoneType error in collate_fn()
        tokenizer.pad_token_id = 0
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name, token=os.environ.get("HF_TOKEN_DOWNLOAD", None))
    model.to(device)
    model.eval()

    cfg.device_eval_batch_size = 4
    cfg.evaluators = [
        EvaluatorConfig(label=label, type=EvaluatorType.downstream)
        for label in label_to_task_map_new.keys()
        if "_train_" not in label and "_mc_" not in label and "_var" not in label
    ]

    evaluators = []
    for eval_cfg in cfg.evaluators:
        evaluators.append(build_evaluator(cfg, eval_cfg, tokenizer, device))

    eval_metrics = {}
    for evaluator in tqdm(evaluators):
        # Reset metrics.
        evaluator.reset_metrics()

        # Initialize data loader iterator.
        eval_batches = iter(evaluator.eval_loader)

        # Adjust how many batches to evaluate on.
        num_eval_batches = (
            evaluator.subset_num_batches
            if evaluator.subset_num_batches is not None
            else cfg.eval_subset_num_batches
        )
        if num_eval_batches > 0:
            num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
            eval_batches = islice(eval_batches, num_eval_batches)

        # Run model over batches.
        for eval_step, eval_batch in enumerate(eval_batches):
            batch = move_to_device(eval_batch, device)
            with torch.no_grad():
                logits = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch.get("attention_mask"),
                ).logits
                logits_for_loss = logits[..., :-1, :].contiguous()
                # shape: (batch_size * seq_len, vocab_size)
                logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1))
                # shape: (batch_size, seq_len)
                labels = get_labels(batch)
                # shape: (batch_size * seq_len,)
                labels = labels.view(-1)
                ce_loss = F.cross_entropy(logits_for_loss, labels, ignore_index=-100, reduction="none")
                # Reshape (batch_size * seq_len,) -> (batch_size, seq_len)
                ce_loss = ce_loss.view(batch["input_ids"].shape[0], -1)
                ce_loss = ce_loss.mean(dim=-1)
                evaluator.update_metrics(batch, ce_loss, logits)

        # Get final metrics.
        metrics = evaluator.compute_metrics()
        eval_metrics.update(metrics)
        print(metrics)

        del eval_batches

    print(eval_metrics)

    save_folder = '/weka/oe-training-default/jiachengl/hc-law/eval_bpb_mc_v2'
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    with open(f'{save_folder}/{model_name.replace("/", "_")}.json', 'w') as f:
        json.dump(eval_metrics, f)


if __name__ == "__main__":
    try:
        yaml_path, model_name = sys.argv[1], sys.argv[2]
    except IndexError:
        raise OLMoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [MODEL_NAME]")

    cfg = TrainConfig.load(yaml_path)
    main(cfg, model_name)
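For reference, the script takes a training config and a Hugging Face model name on the command line; a single-GPU invocation would look something like the line below, which simply mirrors the command the launcher script (next file) runs inside its job, using its config and one of the models from the list at the end of this commit:

torchrun --nproc-per-node 1 scripts/eval_hf.py configs/peteish1-weka.yaml allenai/OLMo-1B-hf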
scripts/eval_hf.sh (new file, 36 lines):
#!/usr/bin/env bash

set -ex

MODEL_NAME=$1

gantry run \
  --allow-dirty \
  --workspace ai2/OLMo-tiny \
  --task-name eval-bpb-mc \
  --description "Evaluate bpb and mc for ${MODEL_NAME}" \
  --priority urgent \
  --preemptible \
  --beaker-image petew/olmo-torch23-gantry \
  --cluster ai2/jupiter-cirrascale-2 \
  --gpus 1 \
  --budget ai2/oe-training \
  --no-nfs \
  --weka oe-training-default:/weka/oe-training-default \
  --no-python \
  --env LOG_FILTER_TYPE=local_rank0_only \
  --env OMP_NUM_THREADS=8 \
  --env OLMO_TASK=model \
  --env-secret HF_TOKEN_DOWNLOAD=JIACHENGL_HF_TOKEN \
  --shared-memory 10GiB \
  --yes \
  --timeout=0 \
  -- /bin/bash -c "\
    set -exuo pipefail; \
    IFS=$'\n\t'; \
    conda shell.bash activate base; \
    pip install '.[train]'; \
    pip install -U transformers==4.46.2; \
    pip install -U sentencepiece; \
    torchrun --nproc-per-node 1 scripts/eval_hf.py configs/peteish1-weka.yaml ${MODEL_NAME}; \
    "
A third new file (16 lines) invokes the launcher once per external model:
bash scripts/eval_hf.sh allenai/OLMo-7B-0724-hf
bash scripts/eval_hf.sh allenai/OLMo-7B-hf
bash scripts/eval_hf.sh allenai/OLMo-1B-hf
bash scripts/eval_hf.sh meta-llama/Llama-3.2-3B
bash scripts/eval_hf.sh meta-llama/Llama-3.2-1B
bash scripts/eval_hf.sh meta-llama/Llama-3.1-8B
bash scripts/eval_hf.sh meta-llama/Meta-Llama-3-8B
bash scripts/eval_hf.sh Qwen/Qwen2.5-14B
bash scripts/eval_hf.sh Qwen/Qwen2.5-7B
bash scripts/eval_hf.sh Qwen/Qwen2.5-3B
bash scripts/eval_hf.sh Qwen/Qwen2.5-1.5B
bash scripts/eval_hf.sh Qwen/Qwen2-7B
bash scripts/eval_hf.sh Qwen/Qwen2-1.5B
bash scripts/eval_hf.sh mistralai/Mistral-Nemo-Base-2407
bash scripts/eval_hf.sh mistralai/Mistral-7B-v0.3
bash scripts/eval_hf.sh mistralai/Mistral-7B-v0.1
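Each job writes one JSON file of metrics per model (with "/" in the model name replaced by "_") into the save_folder hard-coded in eval_hf.py. A minimal sketch for collecting the results afterwards, assuming that same folder is readable from where you run it:

import glob
import json

# Same save_folder as in eval_hf.py above.
save_folder = "/weka/oe-training-default/jiachengl/hc-law/eval_bpb_mc_v2"
for path in sorted(glob.glob(f"{save_folder}/*.json")):
    with open(path) as f:
        metrics = json.load(f)
    # Print the model file and how many metrics it reports.
    print(path, len(metrics))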