diff --git a/scripts/eval_hf.py b/scripts/eval_hf.py
new file mode 100644
index 000000000..c16e3850d
--- /dev/null
+++ b/scripts/eval_hf.py
@@ -0,0 +1,144 @@
+"""Evaluate a HuggingFace causal LM with OLMo's downstream evaluators.
+
+Usage: eval_hf.py [CONFIG_PATH] [MODEL_NAME]
+"""
+
+from itertools import islice
+import json
+import os
+import sys
+from tqdm import tqdm
+from typing import Any, Dict
+import torch
+import torch.nn.functional as F
+import transformers
+from olmo.config import TrainConfig, EvaluatorConfig, EvaluatorType
+from olmo.eval import build_evaluator
+from olmo.torch_util import move_to_device
+from olmo.eval.downstream import label_to_task_map_new
+from olmo.exceptions import OLMoCliError
+
+# Where main() writes its JSON results unless the caller overrides it.
+DEFAULT_SAVE_FOLDER = "/weka/oe-training-default/jiachengl/hc-law/eval_bpb_mc_v2"
+
+
+def get_labels(batch: Dict[str, Any]) -> torch.Tensor:
+    """Build next-token labels from a batch, masking ignored positions with -100.
+
+    Args:
+        batch: Must contain "input_ids"; may contain "label_mask",
+            "attention_mask" and "instance_mask".
+
+    Returns:
+        A copy of ``input_ids`` with masked positions set to -100 and the
+        first position along the last dimension dropped.
+    """
+    # Labels are just input IDs shifted to the left (first item is ignored).
+    labels, label_mask, attention_mask, instance_mask = (
+        batch["input_ids"].clone(),
+        batch.get("label_mask"),
+        batch.get("attention_mask"),
+        batch.get("instance_mask"),
+    )
+    if label_mask is not None:
+        labels.masked_fill_(~label_mask, -100)
+    if attention_mask is not None:
+        labels.masked_fill_(attention_mask == 0.0, -100)
+    if instance_mask is not None:
+        labels.masked_fill_(~instance_mask.unsqueeze(-1), value=-100)
+    return labels[..., 1:].contiguous()
+
+
+def main(cfg: TrainConfig, model_name: str, save_folder: str = DEFAULT_SAVE_FOLDER):
+    """Run every selected downstream evaluator on ``model_name`` and save the metrics.
+
+    Args:
+        cfg: Training config; its ``device_eval_batch_size`` and ``evaluators``
+            fields are overwritten here.
+        model_name: Model identifier passed to ``from_pretrained``.
+        save_folder: Directory that receives ``<model_name>.json`` (with "/"
+            replaced by "_").
+    """
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Read the token once; None is passed through when the variable is unset.
+    hf_token = os.environ.get("HF_TOKEN_DOWNLOAD", None)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, token=hf_token)
+    if tokenizer.pad_token_id is None:  # This is to prevent the NoneType error in collate_fn()
+        tokenizer.pad_token_id = 0
+    model = transformers.AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)
+    model.to(device)
+    model.eval()
+
+    cfg.device_eval_batch_size = 4
+    cfg.evaluators = [
+        EvaluatorConfig(label=label, type=EvaluatorType.downstream)
+        for label in label_to_task_map_new.keys()
+        if "_train_" not in label and "_mc_" not in label and "_var" not in label
+    ]
+
+    evaluators = [build_evaluator(cfg, eval_cfg, tokenizer, device) for eval_cfg in cfg.evaluators]
+
+    eval_metrics = {}
+    for evaluator in tqdm(evaluators):
+        # Reset metrics.
+        evaluator.reset_metrics()
+
+        # Initialize data loader iterator.
+        eval_batches = iter(evaluator.eval_loader)
+
+        # Adjust how many batches to evaluate on.
+        num_eval_batches = (
+            evaluator.subset_num_batches
+            if evaluator.subset_num_batches is not None
+            else cfg.eval_subset_num_batches
+        )
+        if num_eval_batches > 0:
+            num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
+            eval_batches = islice(eval_batches, num_eval_batches)
+
+        # Run model over batches.
+        for eval_batch in eval_batches:
+            batch = move_to_device(eval_batch, device)
+            with torch.no_grad():
+                logits = model(
+                    input_ids=batch["input_ids"],
+                    attention_mask=batch.get("attention_mask"),
+                ).logits
+                logits_for_loss = logits[..., :-1, :].contiguous()
+                # shape: (batch_size * seq_len, vocab_size)
+                logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1))
+                # shape: (batch_size, seq_len)
+                labels = get_labels(batch)
+                # shape: (batch_size * seq_len,)
+                labels = labels.view(-1)
+                ce_loss = F.cross_entropy(logits_for_loss, labels, ignore_index=-100, reduction="none")
+                # Reshape (batch_size * seq_len,) -> (batch_size, seq_len)
+                ce_loss = ce_loss.view(batch["input_ids"].shape[0], -1)
+                ce_loss = ce_loss.mean(dim=-1)
+                evaluator.update_metrics(batch, ce_loss, logits)
+
+        # Get final metrics.
+        metrics = evaluator.compute_metrics()
+        eval_metrics.update(metrics)
+        print(metrics)
+
+        del eval_batches
+
+    print(eval_metrics)
+
+    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
+    os.makedirs(save_folder, exist_ok=True)
+    with open(f'{save_folder}/{model_name.replace("/", "_")}.json', 'w') as f:
+        json.dump(eval_metrics, f)
+
+
+if __name__ == "__main__":
+
+    try:
+        yaml_path, model_name = sys.argv[1], sys.argv[2]
+    except IndexError:
+        raise OLMoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [MODEL_NAME]")
+
+    cfg = TrainConfig.load(yaml_path)
+    main(cfg, model_name)
\ No newline at end of file
diff --git a/scripts/eval_hf.sh b/scripts/eval_hf.sh
new file mode 100644
index 000000000..97b319b6d
--- /dev/null
+++ b/scripts/eval_hf.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+
+set -ex
+
+# Fail fast with a usage message instead of submitting a job for an empty model name.
+MODEL_NAME=${1:?Usage: $0 MODEL_NAME}
+
+# ${MODEL_NAME} is expanded by this shell inside the double-quoted command string;
+# the single quotes around it keep the value a single word in the job's shell.
+gantry run \
+  --allow-dirty \
+  --workspace ai2/OLMo-tiny \
+  --task-name eval-bpb-mc \
+  --description "Evaluate bpb and mc for ${MODEL_NAME}" \
+  --priority urgent \
+  --preemptible \
+  --beaker-image petew/olmo-torch23-gantry \
+  --cluster ai2/jupiter-cirrascale-2 \
+  --gpus 1 \
+  --budget ai2/oe-training \
+  --no-nfs \
+  --weka oe-training-default:/weka/oe-training-default \
+  --no-python \
+  --env LOG_FILTER_TYPE=local_rank0_only \
+  --env OMP_NUM_THREADS=8 \
+  --env OLMO_TASK=model \
+  --env-secret HF_TOKEN_DOWNLOAD=JIACHENGL_HF_TOKEN \
+  --shared-memory 10GiB \
+  --yes \
+  --timeout=0 \
+  -- /bin/bash -c "\
+    set -exuo pipefail; \
+    IFS=$'\n\t'; \
+    conda shell.bash activate base; \
+    pip install '.[train]'; \
+    pip install -U transformers==4.46.2; \
+    pip install -U sentencepiece; \
+    torchrun --nproc-per-node 1 scripts/eval_hf.py configs/peteish1-weka.yaml '${MODEL_NAME}'; \
+    "
\ No newline at end of file
diff --git a/scripts/eval_hf_launch.sh b/scripts/eval_hf_launch.sh
new file mode 100644
index 000000000..ecd770328
--- /dev/null
+++ b/scripts/eval_hf_launch.sh
@@ -0,0 +1,16 @@
+bash scripts/eval_hf.sh allenai/OLMo-7B-0724-hf
+bash scripts/eval_hf.sh allenai/OLMo-7B-hf
+bash scripts/eval_hf.sh allenai/OLMo-1B-hf
+bash scripts/eval_hf.sh meta-llama/Llama-3.2-3B
+bash scripts/eval_hf.sh meta-llama/Llama-3.2-1B
+bash scripts/eval_hf.sh meta-llama/Llama-3.1-8B
+bash scripts/eval_hf.sh meta-llama/Meta-Llama-3-8B
+bash scripts/eval_hf.sh Qwen/Qwen2.5-14B
+bash scripts/eval_hf.sh Qwen/Qwen2.5-7B
+bash scripts/eval_hf.sh Qwen/Qwen2.5-3B
+bash scripts/eval_hf.sh Qwen/Qwen2.5-1.5B
+bash scripts/eval_hf.sh Qwen/Qwen2-7B
+bash scripts/eval_hf.sh Qwen/Qwen2-1.5B
+bash scripts/eval_hf.sh mistralai/Mistral-Nemo-Base-2407
+bash scripts/eval_hf.sh mistralai/Mistral-7B-v0.3
+bash scripts/eval_hf.sh mistralai/Mistral-7B-v0.1