
Step2: also take moving average of target model
liujch1998 committed Nov 26, 2024
1 parent 8a9185e commit be5a22b
Showing 10 changed files with 475 additions and 66 deletions.
27 changes: 18 additions & 9 deletions olmo/scaling/scaling_laws/utils.py
@@ -722,7 +722,10 @@ def get_step1_data_by_name(configs, task_name, y_metric="rc_bpb", moving_avg=1):
rows = rows[-moving_avg:]
ds, ys, fs = [], [], []
for row in rows:
d = int(float(row["throughput/total_tokens"]))
if "throughput/total_tokens" in row:
d = int(float(row["throughput/total_tokens"]))
else:
d = int(float(row["_step"])) * int(float(row["batch_size_in_tokens"]))
f = float(d * MODEL_FLOPS[name.split("-")[0]])
y = np.average(
[float(row[key]) for key in keys], weights=[WEIGHT_BY_KEY.get(key, 1.0) for key in keys]
@@ -770,6 +773,8 @@ def get_flops_data_by_name(configs, keys, num_to_avg=1):

def moving_average(arr, n):
ret = np.cumsum(arr, dtype=float)
if len(ret) < n:
return ret / np.arange(1, len(ret) + 1)
ret[n:] = ret[n:] - ret[:-n]
return np.concatenate([ret[: n - 1] / np.arange(1, n), ret[n - 1 :] / n])
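
The smoothing helper above now tolerates series shorter than the averaging window. For reference, here is a self-contained restatement with comments and an illustrative check (the test arrays are made up, not from the repo): position i averages the trailing min(i + 1, n) points, so the output has the same length as the input and early checkpoints are kept.

```python
import numpy as np

def moving_average(arr, n):
    # Trailing moving average, same length as the input: position i averages
    # the last min(i + 1, n) values, so early checkpoints survive.
    ret = np.cumsum(arr, dtype=float)
    if len(ret) < n:  # new guard: series shorter than the window
        return ret / np.arange(1, len(ret) + 1)
    ret[n:] = ret[n:] - ret[:-n]  # ret[i] = sum(arr[i-n+1 : i+1]) for i >= n
    return np.concatenate([ret[: n - 1] / np.arange(1, n), ret[n - 1 :] / n])

# Illustrative check against a direct computation.
arr = np.array([1.0, 2.0, 4.0, 8.0, 16.0])
expected = np.array([arr[max(0, i - 4) : i + 1].mean() for i in range(len(arr))])
assert np.allclose(moving_average(arr, n=5), expected)
assert np.allclose(moving_average(arr[:3], n=5), [1.0, 1.5, 7.0 / 3.0])
```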

@@ -830,7 +835,10 @@ def get_step2_data_by_name(
rows = [row for row in reader]
xs, ys, ds, ns, ls = [], [], [], [], []
for row in rows:
d = int(float(row["throughput/total_tokens"]))
if "throughput/total_tokens" in row:
d = int(float(row["throughput/total_tokens"]))
else:
d = int(float(row["_step"])) * int(float(row["batch_size_in_tokens"]))
x = np.average(
[float(row[key]) for key in loss_keys],
weights=[WEIGHT_BY_KEY.get(key, 1.0) for key in loss_keys],
@@ -860,14 +868,15 @@ def get_step2_data_by_name(
ns = ns[int(np.ceil(skip_perc * len(ns))) :]
ls = ls[int(np.ceil(skip_perc * len(ls))) :]

# apply moving_avg
xs = moving_average(xs, n=moving_avg).tolist()
ys = moving_average(ys, n=moving_avg).tolist()
# ys = ys[moving_avg-1:]
# ds = ds[moving_avg-1:]
# ns = ns[moving_avg-1:]
# ls = ls[moving_avg-1:]
# apply moving_avg
xs = moving_average(xs, n=moving_avg).tolist()
ys = moving_average(ys, n=moving_avg).tolist()
# ys = ys[moving_avg-1:]
# ds = ds[moving_avg-1:]
# ns = ns[moving_avg-1:]
# ls = ls[moving_avg-1:]

if config.mode == "train":
# last n points
if last_n_points > 0:
xs = xs[-last_n_points:]
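
Both loaders (get_step1_data_by_name and get_step2_data_by_name) now derive the token count the same way: use the logged throughput counter when present, otherwise reconstruct it from the step count and batch size, which the new eval CSVs in this commit appear to require. A minimal restatement is sketched below; the helper name tokens_from_row and the example row are mine, only the column names come from the diff. The second step-2 hunk also appears to apply the xs/ys smoothing ahead of the config.mode == "train" branch, which would match the commit title.

```python
def tokens_from_row(row: dict) -> int:
    # Hypothetical helper mirroring the fallback added in both loaders:
    # prefer the logged throughput counter, otherwise reconstruct tokens
    # as step count times batch size (in tokens).
    if "throughput/total_tokens" in row:
        return int(float(row["throughput/total_tokens"]))
    return int(float(row["_step"])) * int(float(row["batch_size_in_tokens"]))

# Illustrative row from a CSV that lacks the throughput column (values made up).
assert tokens_from_row({"_step": "150000", "batch_size_in_tokens": "4194304"}) == 150000 * 4194304
```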
15 changes: 13 additions & 2 deletions scripts/scaling/data/peteish-moreeval/peteish13_eval_final.csv

Large diffs are not rendered by default.

157 changes: 157 additions & 0 deletions scripts/scaling/data/peteish-moreeval/peteish7_eval_150k-end.csv

Large diffs are not rendered by default.

187 changes: 187 additions & 0 deletions scripts/scaling/data/peteish-moreeval/peteish7_eval_full.csv

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions scripts/scaling/predict.py
@@ -2,8 +2,8 @@
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 13202396160 -d 5000088518656 -t 13b --skip_perc 0.1 --moving_avg 5
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 6887575552 -d 3945065873408 -t 7b --skip_perc 0.1 --moving_avg 5 --x_metric c4
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 13202396160 -d 5000088518656 -t 13b --skip_perc 0.1 --moving_avg 5 --x_metric c4
# python scripts/scaling/predict.py -k main_mc -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -y mc_acc -n 6887575552 -d 3945065873408 -t 7b-4T-final
# python scripts/scaling/predict.py -k main_mc -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -y mc_acc -n 13202396160 -d 5000088518656 -t 13b-5T-final
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -y mc_acc -n 6887575552 -d 3945065873408 -t 7b-4T --moving_avg 5
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -y mc_acc -n 13202396160 -d 5000088518656 -t 13b-5T --moving_avg 5

import argparse

72 changes: 47 additions & 25 deletions scripts/scaling/single_step.py
@@ -1,4 +1,4 @@
# python scripts/scaling/single_step.py -k v2_main -c scripts/scaling/final.json -o figure/peteish-moreeval/single_step_main.png --moving_avg 5
# python scripts/scaling/single_step.py -k v2_main -c scripts/scaling/final.json -o figure/peteish-moreeval/single_step_main.pdf --moving_avg 5

import argparse

@@ -20,6 +20,7 @@
)

MARKERS = ["s", "P", "p", "*", "o"]
FONTSIZE = 11


def parse_args():
@@ -102,8 +103,22 @@ def plot_step12(
fit_str,
ax=plt.gca(),
):
# plot the fitted curve
for name, data in plotted_predicted_data_by_name.items():
config = configs[name]
ax.plot(
data["ds"],
data["ys"],
color=config.color,
linestyle="--",
alpha=0.7,
linewidth=1.5,
label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})',
)

# plot the actual and predicted data
unsigned_rel_errors = []
num_eval_annotation = 0
for name, data in data_by_name.items():
config = configs[name]
predicted_data = predicted_data_by_name[name]
@@ -133,35 +148,25 @@ def plot_step12(
)
ax.annotate(
f"{abs(rel_error * 100):.1f}%",
(d, y),
(d, y_pred),
textcoords="offset points",
xytext=(3, 3),
xytext=(10, -5 + 10*num_eval_annotation),
ha="left",
va="bottom",
fontsize=10,
fontsize=FONTSIZE,
color=config.color,
)
num_eval_annotation += 1
avg_unsigned_rel_error = np.mean(unsigned_rel_errors)

# plot the fitted curve
for name, data in plotted_predicted_data_by_name.items():
config = configs[name]
ax.plot(
data["ds"],
data["ys"],
color=config.color,
linestyle="--",
linewidth=1.5,
label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})',
)

ax.set_xscale("log")
ax.legend(ncols=1, fontsize=7)
ax.set_xlabel("Tokens (D)")
ax.set_ylabel("Task accuracy")
ax.legend(loc="upper right", ncols=1, fontsize=FONTSIZE)
ax.set_xlabel("Tokens (D)", fontsize=FONTSIZE)
ax.set_ylabel("Task accuracy", fontsize=FONTSIZE)
ax.set_title(
f"{task_name}\n{fit_str}\navg rel error on fitting = {avg_unsigned_rel_error * 100:.2f}%",
fontsize=9,
f"{tasks[task_name].display_name} ({avg_unsigned_rel_error * 100:.2f}%)",
fontsize=FONTSIZE,
fontweight="bold",
)


@@ -171,9 +176,9 @@ def main():

sns.set_style("whitegrid")
num_tasks = len(args.keys)
num_cols = min(3, num_tasks)
num_cols = min(4, num_tasks)
num_rows = (num_tasks + num_cols - 1) // num_cols
fig, axes = plt.subplots(num_rows, num_cols, figsize=(3.75 * num_cols, 3.25 * num_rows), squeeze=False)
fig, axes = plt.subplots(num_rows, num_cols, figsize=(2.75 * num_cols, 2.25 * num_rows), squeeze=False)

results = "Task Name | Actual Value | Predicted Value | Relative Error"

@@ -199,8 +204,25 @@ def main():
axes[i // num_cols][i % num_cols],
)

fig.tight_layout()
fig.savefig(args.output_path, dpi=300)
handles, labels = axes[-1][-1].get_legend_handles_labels()
# delete x-axis labels for all but the bottom row
for i in range(num_cols):
for j in range(num_rows):
if j != num_rows - 1:
axes[j][i].set_xlabel("")
if i != 0:
axes[j][i].set_ylabel("")

axes[j][i].legend().remove()

fig.tight_layout(w_pad=0.01)
legend = fig.legend(handles, labels, loc='upper center',
ncol=10, fontsize=FONTSIZE, bbox_to_anchor=(0.5, 1.07),
handletextpad=0.3, columnspacing=0.7)
for handle in legend.legend_handles:
handle.set_alpha(1.0)

fig.savefig(args.output_path, dpi=300, bbox_inches='tight')

print(results)

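
single_step.py (and step1_flops.py further down) replaces per-panel legends with a single figure-level legend: handles are read from the last axes, inner x/y labels and per-panel legends are stripped, and fig.legend is anchored above the grid; error annotations are also staggered via num_eval_annotation so they do not overlap. A stripped-down sketch of that layout pattern on dummy data follows (assumes Matplotlib >= 3.7 for Legend.legend_handles; older versions use legendHandles).

```python
import matplotlib.pyplot as plt
import numpy as np

num_rows, num_cols = 2, 4
fig, axes = plt.subplots(num_rows, num_cols,
                         figsize=(2.75 * num_cols, 2.25 * num_rows), squeeze=False)

x = np.logspace(9, 12, 20)  # dummy token counts
for k, ax in enumerate(axes.flat):
    ax.plot(x, 0.20 + 0.02 * k + 0.05 * np.log10(x), "--", alpha=0.7, label="7B (fitted)")
    ax.plot(x, 0.25 + 0.02 * k + 0.05 * np.log10(x), "--", alpha=0.7, label="13B (predicted)")
    ax.set_xscale("log")
    ax.set_xlabel("Tokens (D)")
    ax.set_ylabel("Task accuracy")

# one legend for the whole figure, taken from the last panel
handles, labels = axes[-1][-1].get_legend_handles_labels()
for i in range(num_cols):
    for j in range(num_rows):
        if j != num_rows - 1:
            axes[j][i].set_xlabel("")   # x-label only on the bottom row
        if i != 0:
            axes[j][i].set_ylabel("")   # y-label only on the left column
        axes[j][i].legend().remove()    # drop the per-panel legend

fig.tight_layout(w_pad=0.01)
legend = fig.legend(handles, labels, loc="upper center", ncol=10,
                    bbox_to_anchor=(0.5, 1.07), handletextpad=0.3, columnspacing=0.7)
for handle in legend.legend_handles:    # Matplotlib >= 3.7
    handle.set_alpha(1.0)
fig.savefig("single_step_sketch.png", dpi=300, bbox_inches="tight")
```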
6 changes: 3 additions & 3 deletions scripts/scaling/step1.py
@@ -1,7 +1,7 @@
# python scripts/scaling/step1.py -k v2_main -c scripts/scaling/final.json -o figure/peteish-moreeval/step1_main.pdf --moving_avg 5
# python scripts/scaling/step1.py -k v2_main -c scripts/scaling/final.json -o figure/peteish-moreeval/step1_c4_main.pdf --y_metric c4 --moving_avg 5
# python scripts/scaling/step1.py -k v2_main -c scripts/scaling/final.json -o figure/peteish-moreeval/step1_acc_main.pdf --y_metric rc_acc
# python scripts/scaling/step1.py -o figure/peteish-moreeval/step1_taskce.pdf -c scripts/scaling/step2.json -k v2_main -y rc_soft_log
# python scripts/scaling/step1.py -k v2_main -c scripts/scaling/step2.json -o figure/peteish-moreeval/step1_taskce.pdf -y rc_soft_log

import argparse
from typing import Any, List, Tuple
@@ -26,7 +26,8 @@
)

MARKERS = ["s", "P", "p", "*", "o"]
FONTSIZE=11
FONTSIZE = 11


def parse_args():
parser = argparse.ArgumentParser()
@@ -251,7 +252,6 @@ def plot_step1(
ax.set_xscale("log")
ax.legend(loc="upper right", ncols=1, fontsize=FONTSIZE)
ax.set_xlabel("Tokens (D)", fontsize=FONTSIZE)

y_label_name = {
"rc_bpb": "Task loss",
"rc_acc": "Task RC accuracy",
51 changes: 37 additions & 14 deletions scripts/scaling/step1_flops.py
@@ -1,4 +1,4 @@
# python scripts/scaling/step1_flops.py -k v2_main -c scripts/scaling/final.json -o figure/peteish-moreeval/step1_flops_main.png --moving_avg 5
# python scripts/scaling/step1_flops.py -k v2_main -c scripts/scaling/final.json -o figure/peteish-moreeval/step1_flops_main.pdf --moving_avg 5

import argparse

@@ -16,9 +16,11 @@
get_step1_data_by_name,
get_task_sets,
prettify,
tasks,
)

MARKERS = ["s", "P", "p", "*", "o"]
FONTSIZE = 11


def parse_args():
@@ -166,13 +168,15 @@ def plot_step1(
data["ys"],
color="black",
linestyle="--",
alpha=0.7,
linewidth=1.5,
# label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})',
)
break

# plot the actual and predicted data
unsigned_rel_errors = []
num_eval_annotation = 0
for name, data in data_by_name.items():
config = configs[name]
predicted_data = predicted_data_by_name[name]
@@ -202,28 +206,30 @@ def plot_step1(
)
ax.annotate(
f"{abs(100 * rel_error):.1f}%",
(f, y),
(f, y_pred),
textcoords="offset points",
xytext=(3, 3),
xytext=(10, 1 - 10*num_eval_annotation),
ha="left",
va="bottom",
fontsize=8,
fontsize=FONTSIZE,
color=config.color,
)
num_eval_annotation += 1
avg_unsigned_rel_error = np.mean(unsigned_rel_errors)

ax.set_xscale("log")
ax.legend(loc="upper right", ncols=1, fontsize=8)
ax.set_xlabel("Flops (F)")
ax.legend(loc="upper right", ncols=1, fontsize=FONTSIZE)
ax.set_xlabel("Flops (F)", fontsize=FONTSIZE)
if y_metric == "rc_bpb":
ax.set_ylabel("Task loss")
ax.set_ylabel("Task loss", fontsize=FONTSIZE)
elif y_metric == "rc_acc":
ax.set_ylabel("Task RC accuracy")
ax.set_ylabel("Task RC accuracy", fontsize=FONTSIZE)
else:
raise ValueError(f"Unknown y_metric: {y_metric}")
ax.set_title(
f"{task_name}\n{fit_str}\navg rel error on fitting = {avg_unsigned_rel_error * 100:.2f}%",
fontsize=9,
f"{tasks[task_name].display_name} ({avg_unsigned_rel_error * 100:.2f}%)",
fontsize=FONTSIZE,
fontweight="bold",
)


@@ -233,13 +239,13 @@ def main():

sns.set_style("whitegrid")
num_tasks = len(args.keys)
num_cols = min(3, num_tasks)
num_cols = min(4, num_tasks)
num_rows = (num_tasks + num_cols - 1) // num_cols

fitting_error = 0

if args.output_path:
fig, axes = plt.subplots(num_rows, num_cols, figsize=(3.75 * num_cols, 3.25 * num_rows), squeeze=False)
fig, axes = plt.subplots(num_rows, num_cols, figsize=(2.75 * num_cols, 2.25 * num_rows), squeeze=False)

results = "Task Name | Actual Value | Predicted Value | Relative Error"

@@ -280,9 +286,26 @@ def main():
axes[i // num_cols][i % num_cols],
)

handles, labels = axes[-1][-1].get_legend_handles_labels()
# delete x-axis labels for all but the bottom row
for i in range(num_cols):
for j in range(num_rows):
if j != num_rows - 1:
axes[j][i].set_xlabel("")
if i != 0:
axes[j][i].set_ylabel("")

axes[j][i].legend().remove()

fig.tight_layout(w_pad=0.01)
legend = fig.legend(handles, labels, loc='upper center',
ncol=10, fontsize=FONTSIZE, bbox_to_anchor=(0.5, 1.07),
handletextpad=0.3, columnspacing=0.7)
for handle in legend.legend_handles:
handle.set_alpha(1.0)

if args.output_path:
fig.tight_layout()
fig.savefig(args.output_path, dpi=300)
fig.savefig(args.output_path, dpi=300, bbox_inches='tight')

print(results)
print("Total fitting error: ", prettify(fitting_error / num_tasks))
8 changes: 4 additions & 4 deletions scripts/scaling/step2.py
@@ -1,7 +1,7 @@
# python scripts/scaling/step2.py -k v2_main -c scripts/scaling/step2.json -o figure/peteish-moreeval/step2_main.pdf --skip_perc 0.1 --moving_avg 5
# python scripts/scaling/step2.py -k v2_main -c scripts/scaling/step2.json -o figure/peteish-moreeval/step2_c4_main.pdf --x_metric c4 --skip_perc 0.1 --moving_avg 5
# python scripts/scaling/step2.py -k mmlu_avg_test_5shot -c scripts/scaling/step2_mc.json -o figure/peteish-moreeval/step2_mc_mmlu.pdf -y mc_acc
# python scripts/scaling/step2.py -o figure/peteish-moreeval/step2_taskce.pdf -c scripts/scaling/step2.json -k v2_main --skip_perc 0.5 --use_log_sigmoid --x_metric rc_soft_log
# python scripts/scaling/step2.py -k mmlu_avg_test_5shot -c scripts/scaling/step2_mc.json -o figure/peteish-moreeval/step2_mc_mmlu.pdf -y mc_acc --moving_avg 5
# python scripts/scaling/step2.py -k v2_main -c scripts/scaling/step2.json -o figure/peteish-moreeval/step2_taskce.pdf --skip_perc 0.5 --use_log_sigmoid --x_metric rc_soft_log

import argparse

@@ -27,7 +27,7 @@
tasks,
)

FONTSIZE=11
FONTSIZE = 11

def parse_args():
parser = argparse.ArgumentParser()
@@ -193,7 +193,7 @@ def plot_step2(
f"{np.abs(rel_error) * 100:.1f}%",
(x, y),
textcoords="offset points",
xytext=(8 - 35*num_eval_annotation, -7),
xytext=(8 - 40*num_eval_annotation, -7),
ha="left",
va="bottom",
fontsize=FONTSIZE,