Misc improvements
liujch1998 committed Nov 25, 2024
1 parent f77945e commit 8cc954a
Showing 9 changed files with 254 additions and 197 deletions.
12 changes: 6 additions & 6 deletions olmo/scaling/scaling_laws/fitting_functions.py
@@ -43,17 +43,17 @@ def get_std_errors(xs, ys, coefficients, cov, fitting_func, grad_fitting_func):


# x = flops
-# p[0] = A, p[1] = B, p[2] = E
+# p[0] = a = log(A), p[1] = alpha, p[2] = E
def chinchilla_flops_fit(x, p):
-    # return ax**b + E
-    return p[0] * np.pow(x, p[1]) + p[2]
+    # return e**a / x**alpha + E
+    return np.exp(p[0]) / x ** p[1] + p[2]


def grad_chinchilla_flops_fit(x, p):
-    grad_A = np.pow(x, p[1])
-    grad_B = p[0] * np.pow(x, p[1]) * np.log(x)
+    grad_a = np.exp(p[0]) / x ** p[1]
+    grad_alpha = np.exp(p[0]) * (-np.log(x)) / x ** p[1]
    grad_E = 1
-    return [grad_A, grad_B, grad_E]
+    return [grad_a, grad_alpha, grad_E]


# x[0] = d, x[1] = h
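For reference, the new parameterization writes the flops law with a = log(A), so the fitted coefficient A = e^a stays positive; the function and gradients above correspond to

$$
L(C) = \frac{e^{a}}{C^{\alpha}} + E,
\qquad
\frac{\partial L}{\partial a} = \frac{e^{a}}{C^{\alpha}},
\qquad
\frac{\partial L}{\partial \alpha} = -\frac{e^{a}\,\ln C}{C^{\alpha}},
\qquad
\frac{\partial L}{\partial E} = 1,
$$

matching `grad_chinchilla_flops_fit` term by term (here C is the flops value passed as `x`).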
31 changes: 23 additions & 8 deletions olmo/scaling/scaling_laws/utils.py
@@ -236,7 +236,15 @@ def get_log_soft_loss_keys(self):
"boolq_val": 0.5,
"winogrande_val": 0.5,
}
v2_maximums_rc: Dict[str, float] = {}
v2_maximums_rc: Dict[str, float] = {
# "mmlu_avg_test": 1.06,
# "arc_challenge_test": 1.65,
# "arc_easy_test": 1.40,
# "piqa_val": 1.53,
# "csqa_val": 1.10,
# "socialiqa_val": 0.73,
# "openbookqa_test": 1.94,
}

v2_core_names = [
"hellaswag_val",
@@ -317,8 +325,8 @@ def get_log_soft_loss_keys(self):
        task_accuracy_key=[f"eval/downstream/{key}_rc_5shot_len_norm" for key in v2_mmlu_val_names],
        task_mc_loss_key=[f"eval/downstream_bpb/{key}_mc_5shot_bpb" for key in v2_mmlu_val_names],
        task_mc_accuracy_key=[f"eval/downstream/{key}_mc_5shot_len_norm" for key in v2_mmlu_val_names],
-        task_minimum=0.25,
-        task_maximum=1.0,
+        task_minimum=v2_minimums_rc.get("mmlu_avg_val", 0.25),
+        task_maximum=v2_maximums_rc.get("mmlu_avg_val", 1.0),
    )
}

@@ -330,8 +338,8 @@ def get_log_soft_loss_keys(self):
        task_accuracy_key=[f"eval/downstream/{key}_rc_5shot_len_norm" for key in v2_mmlu_test_names],
        task_mc_loss_key=[f"eval/downstream_bpb/{key}_mc_5shot_bpb" for key in v2_mmlu_test_names],
        task_mc_accuracy_key=[f"eval/downstream/{key}_mc_5shot_len_norm" for key in v2_mmlu_test_names],
-        task_minimum=0.25,
-        task_maximum=1.0,  # 0.9,
+        task_minimum=v2_minimums_rc.get("mmlu_avg_test", 0.25),
+        task_maximum=v2_maximums_rc.get("mmlu_avg_test", 1.0),
    )
}

@@ -673,6 +681,8 @@ def get_step1_data_by_name(configs, task_name, y_metric="rc_bpb", moving_avg=1):
        keys = task.get_loss_keys()
    elif y_metric == "rc_acc":
        keys = task.get_accuracy_keys()
+    elif y_metric == "c4":
+        keys = ["eval/c4_en-validation/CrossEntropyLoss"]
    else:
        raise ValueError(f"Invalid y_metric: {y_metric}")

@@ -688,7 +698,7 @@ def get_step1_data_by_name(configs, task_name, y_metric="rc_bpb", moving_avg=1):
        ds, ys, fs = [], [], []
        for row in rows:
            d = int(float(row["throughput/total_tokens"]))
-            f = d * MODEL_FLOPS[name.split("-")[0]]
+            f = float(d * MODEL_FLOPS[name.split("-")[0]])
            y = np.average(
                [float(row[key]) for key in keys], weights=[WEIGHT_BY_KEY.get(key, 1.0) for key in keys]
            )
@@ -744,9 +754,14 @@ def get_length(path):
return ""


def get_step2_data_by_name(configs, task_name, y_metric="rc_acc", moving_avg=1, skip_perc=0.0, last_n_points=-1):
def get_step2_data_by_name(configs, task_name, x_metric="rc_bpb", y_metric="rc_acc", moving_avg=1, skip_perc=0.0, last_n_points=-1):
task = tasks[task_name]
loss_keys = task.get_loss_keys()
if x_metric == "rc_bpb":
loss_keys = task.get_loss_keys()
elif x_metric == "c4":
loss_keys = ["eval/c4_en-validation/CrossEntropyLoss"]
else:
raise ValueError(f"Invalid x_metric: {x_metric}")
if y_metric == "rc_acc":
accuracy_keys = task.get_accuracy_keys()
elif y_metric == "mc_acc":
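
A minimal usage sketch of the new option, assuming `configs` is the run-name-to-config mapping that the scaling scripts build from scripts/scaling/final.json (the loader is not part of this diff) and using a task name that appears in the tables above:

from olmo.scaling.scaling_laws.utils import get_step1_data_by_name, get_step2_data_by_name

configs = ...  # placeholder: run-name -> config mapping, built elsewhere from scripts/scaling/final.json

# Step 1 data with C4 validation cross-entropy as the predicted quantity,
# instead of the task's RC bits-per-byte.
step1_data = get_step1_data_by_name(configs, "mmlu_avg_test", y_metric="c4", moving_avg=5)

# Step 2 data with the C4 loss on the x-axis and downstream RC accuracy on the y-axis.
step2_data = get_step2_data_by_name(
    configs, "mmlu_avg_test", x_metric="c4", y_metric="rc_acc", moving_avg=5, skip_perc=0.1
)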
134 changes: 134 additions & 0 deletions scripts/eval_hf_plot.py
@@ -0,0 +1,134 @@
import json

import matplotlib.pyplot as plt
import numpy as np

MODELS = [
"allenai/OLMo-7B-0724-hf",
# 'allenai/OLMo-1B-0724-hf',
# 'allenai/OLMo-7B-0424-hf',
"allenai/OLMo-7B-hf",
"allenai/OLMo-1B-hf",
"meta-llama/Llama-3.2-3B",
"meta-llama/Llama-3.2-1B",
# 'meta-llama/Llama-3.1-70B',
"meta-llama/Llama-3.1-8B",
# 'meta-llama/Meta-Llama-3-70B',
"meta-llama/Meta-Llama-3-8B",
# 'meta-llama/Llama-2-70b-hf',
# 'meta-llama/Llama-2-13b-hf',
# 'meta-llama/Llama-2-7b-hf',
# 'google/gemma-2-27b',
# 'google/gemma-2-9b',
# 'google/gemma-2-2b',
# 'google/gemma-7b',
# 'google/gemma-2b',
# 'Qwen/Qwen2.5-72B',
# 'Qwen/Qwen2.5-32B',
"Qwen/Qwen2.5-14B",
"Qwen/Qwen2.5-7B",
"Qwen/Qwen2.5-3B",
"Qwen/Qwen2.5-1.5B",
# 'Qwen/Qwen2-72B',
"Qwen/Qwen2-7B",
"Qwen/Qwen2-1.5B",
"mistralai/Mistral-Nemo-Base-2407",
"mistralai/Mistral-7B-v0.3",
"mistralai/Mistral-7B-v0.1",
]

COLOR_BY_MODEL_PREFIX = {
"allenai": "hotpink",
"meta-llama/Llama-3.2": "darkblue",
"meta-llama/Llama-3.1": "mediumblue",
"meta-llama/Meta-Llama-3": "royalblue",
"meta-llama/Llama-2": "cornflowerblue",
"google/gemma-2-": "darkgreen",
"google/gemma-": "forestgreen",
"Qwen/Qwen2.5": "darkviolet",
"Qwen/Qwen2": "violet",
"mistralai": "darkorange",
}


def get_color(model):
    for prefix, color in COLOR_BY_MODEL_PREFIX.items():
        if model.startswith(prefix):
            return color
    return "black"


METRICS_BY_TASK = {
"rc_rc_mmlu": [
("mmlu_stem_test_rc_5shot_bpb", "mmlu_stem_test_rc_5shot_len_norm", 0.215),
("mmlu_humanities_test_rc_5shot_bpb", "mmlu_humanities_test_rc_5shot_len_norm", 0.335),
("mmlu_social_sciences_test_rc_5shot_bpb", "mmlu_social_sciences_test_rc_5shot_len_norm", 0.219),
("mmlu_other_test_rc_5shot_bpb", "mmlu_other_test_rc_5shot_len_norm", 0.231),
],
"rc_rc_hellaswag": [("hellaswag_val_rc_5shot_bpb", "hellaswag_val_rc_5shot_len_norm", 1.0)],
"rc_rc_arc-c": [("arc_challenge_test_rc_5shot_bpb", "arc_challenge_test_rc_5shot_len_norm", 1.0)],
"rc_rc_arc-e": [("arc_easy_test_rc_5shot_bpb", "arc_easy_test_rc_5shot_len_norm", 1.0)],
"rc_rc_piqa": [("piqa_val_rc_5shot_bpb", "piqa_val_rc_5shot_len_norm", 1.0)],
"rc_rc_csqa": [("csqa_val_rc_5shot_bpb", "csqa_val_rc_5shot_len_norm", 1.0)],
"rc_rc_socialiqa": [("socialiqa_val_rc_5shot_bpb", "socialiqa_val_rc_5shot_len_norm", 1.0)],
"rc_rc_openbookqa": [("openbookqa_test_rc_5shot_bpb", "openbookqa_test_rc_5shot_len_norm", 1.0)],
"rc_mc_mmlu": [
("mmlu_stem_test_rc_5shot_bpb", "mmlu_stem_test_mc_5shot_len_norm", 0.215),
("mmlu_humanities_test_rc_5shot_bpb", "mmlu_humanities_test_mc_5shot_len_norm", 0.335),
("mmlu_social_sciences_test_rc_5shot_bpb", "mmlu_social_sciences_test_mc_5shot_len_norm", 0.219),
("mmlu_other_test_rc_5shot_bpb", "mmlu_other_test_mc_5shot_len_norm", 0.231),
],
"rc_mc_hellaswag": [("hellaswag_val_rc_5shot_bpb", "hellaswag_val_mc_5shot_acc", 1.0)],
"rc_mc_arc-c": [("arc_challenge_test_rc_5shot_bpb", "arc_challenge_test_mc_5shot_acc", 1.0)],
"rc_mc_arc-e": [("arc_easy_test_rc_5shot_bpb", "arc_easy_test_mc_5shot_acc", 1.0)],
"rc_mc_piqa": [("piqa_val_rc_5shot_bpb", "piqa_val_mc_5shot_acc", 1.0)],
"rc_mc_csqa": [("csqa_val_rc_5shot_bpb", "csqa_val_mc_5shot_acc", 1.0)],
"rc_mc_socialiqa": [("socialiqa_val_rc_5shot_bpb", "socialiqa_val_mc_5shot_acc", 1.0)],
"rc_mc_openbookqa": [("openbookqa_test_rc_5shot_bpb", "openbookqa_test_mc_5shot_acc", 1.0)],
"mc_mc_mmlu": [
("mmlu_stem_test_mc_5shot_bpb", "mmlu_stem_test_mc_5shot_len_norm", 0.215),
("mmlu_humanities_test_mc_5shot_bpb", "mmlu_humanities_test_mc_5shot_len_norm", 0.335),
("mmlu_social_sciences_test_mc_5shot_bpb", "mmlu_social_sciences_test_mc_5shot_len_norm", 0.219),
("mmlu_other_test_mc_5shot_bpb", "mmlu_other_test_mc_5shot_len_norm", 0.231),
],
"mc_mc_hellaswag": [("hellaswag_val_mc_5shot_bpb", "hellaswag_val_mc_5shot_acc", 1.0)],
"mc_mc_arc-c": [("arc_challenge_test_mc_5shot_bpb", "arc_challenge_test_mc_5shot_acc", 1.0)],
"mc_mc_arc-e": [("arc_easy_test_mc_5shot_bpb", "arc_easy_test_mc_5shot_acc", 1.0)],
"mc_mc_piqa": [("piqa_val_mc_5shot_bpb", "piqa_val_mc_5shot_acc", 1.0)],
"mc_mc_csqa": [("csqa_val_mc_5shot_bpb", "csqa_val_mc_5shot_acc", 1.0)],
"mc_mc_socialiqa": [("socialiqa_val_mc_5shot_bpb", "socialiqa_val_mc_5shot_acc", 1.0)],
"mc_mc_openbookqa": [("openbookqa_test_mc_5shot_bpb", "openbookqa_test_mc_5shot_acc", 1.0)],
}

fig, axs = plt.subplots(8, 3, figsize=(3 * 6, 8 * 4.5))

for i, (task, metrics) in enumerate(METRICS_BY_TASK.items()):
    ax = axs[i % 8, i // 8]
    for model in MODELS:
        with open(f'wandb/eval_bpb_mc_v2/{model.replace("/", "_")}.json') as f:
            data = json.load(f)
        try:
            rc_bpb = np.average(
                [data[f"eval/downstream_bpb/{metric[0]}"] for metric in metrics],
                weights=[metric[2] for metric in metrics],
            )
            acc = np.average(
                [data[f"eval/downstream/{metric[1]}"] for metric in metrics],
                weights=[metric[2] for metric in metrics],
            )
        except KeyError:
            # Skip models whose eval dump is missing any metric for this task.
            continue
        color = get_color(model)
        ax.scatter([rc_bpb], [acc], color=color, s=100)
        ax.annotate(
            text=model.split("/")[1],
            xy=(float(rc_bpb), float(acc)),
            xytext=(8, -3),
            textcoords="offset points",
            fontsize=8,
        )
    ax.set_xlabel(f'{task.split("_")[0]} bpb')
    ax.set_ylabel(f'{task.split("_")[1]} acc')
    ax.set_title(task)

plt.savefig("figure/peteish-moreeval/external.png", dpi=300, bbox_inches="tight")
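
The script above expects one flat JSON file per model under wandb/eval_bpb_mc_v2/, keyed by the same eval/downstream_bpb/... and eval/downstream/... names it reads. The sketch below shows the assumed layout with placeholder numbers; the real files come from a separate eval dump that is not part of this commit.

# Hypothetical input-file layout for eval_hf_plot.py; the values are placeholders, not real results.
import json

record = {
    "eval/downstream_bpb/hellaswag_val_rc_5shot_bpb": 0.62,   # RC bits-per-byte (placeholder)
    "eval/downstream/hellaswag_val_rc_5shot_len_norm": 0.78,  # RC accuracy (placeholder)
    "eval/downstream/hellaswag_val_mc_5shot_acc": 0.81,       # MC accuracy (placeholder)
}
# Saved as, e.g., wandb/eval_bpb_mc_v2/allenai_OLMo-7B-hf.json ("/" in the model name becomes "_").
print(json.dumps(record, indent=2))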
128 changes: 0 additions & 128 deletions scripts/scaling/eval_bpb_mc.py

This file was deleted.

15 changes: 10 additions & 5 deletions scripts/scaling/predict.py
@@ -1,5 +1,7 @@
-# python scripts/scaling/predict.py -k main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 6887575552 -d 3945065873408 -t 7b
-# python scripts/scaling/predict.py -k main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 13202396160 -d 5000088518656 -t 13b
+# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 6887575552 -d 3945065873408 -t 7b --skip_perc 0.1 --moving_avg 5
+# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 13202396160 -d 5000088518656 -t 13b --skip_perc 0.1 --moving_avg 5
+# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 6887575552 -d 3945065873408 -t 7b --skip_perc 0.1 --moving_avg 5 --x_metric c4
+# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 13202396160 -d 5000088518656 -t 13b --skip_perc 0.1 --moving_avg 5 --x_metric c4
# python scripts/scaling/predict.py -k main_mc -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -y mc_acc -n 6887575552 -d 3945065873408 -t 7b-4T-final
# python scripts/scaling/predict.py -k main_mc -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -y mc_acc -n 13202396160 -d 5000088518656 -t 13b-5T-final

@@ -23,6 +25,9 @@ def parse_args():
    parser.add_argument(
        "-k", "--keys", nargs="+", default=[], help="For avg metrics. Use one of [all-val-lm, all-bpb]"
    )
+    parser.add_argument(
+        "-x", "--x_metric", default="rc_bpb", choices=["rc_bpb", "c4"], help="Metric as input"
+    )
    parser.add_argument(
        "-y", "--y_metric", default="rc_acc", choices=["rc_acc", "mc_acc"], help="Metric to predict"
    )
@@ -60,13 +65,13 @@ def main():
    for r, task_name in enumerate(args.keys):
        # Step 1
        step1_data_by_name = get_step1_data_by_name(
-            configs, task_name, y_metric="rc_bpb", moving_avg=args.moving_avg
+            configs, task_name, y_metric=args.x_metric, moving_avg=args.moving_avg
        )
-        step1_coefficients = fit_step1(step1_data_by_name, y_metric="rc_bpb")
+        step1_coefficients, _ = fit_step1(step1_data_by_name, y_metric=args.x_metric)

        # Step 2
        step2_data_by_name = get_step2_data_by_name(
-            step2_configs, task_name, y_metric=args.y_metric, moving_avg=args.moving_avg, skip_perc=args.skip_perc
+            step2_configs, task_name, x_metric=args.x_metric, y_metric=args.y_metric, moving_avg=args.moving_avg, skip_perc=args.skip_perc
        )
        step2_coefficients, _ = fit_step2(step2_data_by_name, task_name, args.y_metric)

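
For intuition, the chain this script drives is: fit a step-1 curve from model scale to the chosen x_metric (task RC bits-per-byte, or C4 loss with --x_metric c4), fit a step-2 curve from that x_metric to the chosen y_metric accuracy, then compose the two at the target (N, D). The sketch below illustrates only the composition; the two toy curves are placeholders for illustration, not the repo's fitted forms (those come from fit_step1 / fit_step2).

import math


def step1_curve(n: float, d: float) -> float:
    """Toy (N, D) -> x_metric curve (loss falls with scale); placeholder, not the fitted law."""
    return 2.5 + 10.0 / (n * d) ** 0.05


def step2_curve(loss: float) -> float:
    """Toy x_metric -> y_metric curve (accuracy falls as loss rises); placeholder, not the fitted law."""
    return 1.0 / (1.0 + math.exp(loss - 3.0))


n, d = 6_887_575_552, 3_945_065_873_408  # the 7B target from the example commands above
print(step2_curve(step1_curve(n, d)))  # compose step 1 and step 2, as main() does with the fitted curves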
(Diffs for the remaining changed files were not loaded.)
