Skip to content

Commit

Permalink
Chained figure
Browse files Browse the repository at this point in the history
  • Loading branch information
liujch1998 committed Nov 26, 2024
1 parent e06ffc4 commit cba738c
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 17 deletions.
180 changes: 170 additions & 10 deletions scripts/scaling/predict.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 6887575552 -d 3945065873408 -t 7B-4T --skip_perc 0.1 --moving_avg 5
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 13202396160 -d 5000088518656 -t 13B-5T --skip_perc 0.1 --moving_avg 5
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 6887575552 -d 3945065873408 -t 7B-4T --skip_perc 0.1 --moving_avg 5 --x_metric c4
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 13202396160 -d 5000088518656 -t 13B-5T --skip_perc 0.1 --moving_avg 5 --x_metric c4
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -y mc_acc -n 6887575552 -d 3945065873408 -t 7B-4T --moving_avg 5
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -y mc_acc -n 13202396160 -d 5000088518656 -t 13B-5T --moving_avg 5
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -o figure/peteish-moreeval/chained_main.pdf -n 6887575552 -d 3945065873408 -t 7B-4T --skip_perc 0.1 --moving_avg 5
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -o figure/peteish-moreeval/chained_main.pdf -n 13202396160 -d 5000088518656 -t 13B-5T --skip_perc 0.1 --moving_avg 5
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -o figure/peteish-moreeval/chained_c4_main.pdf -n 6887575552 -d 3945065873408 -t 7B-4T --skip_perc 0.1 --moving_avg 5 --x_metric c4
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -o figure/peteish-moreeval/chained_c4_main.pdf -n 13202396160 -d 5000088518656 -t 13B-5T --skip_perc 0.1 --moving_avg 5 --x_metric c4
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -o figure/peteish-moreeval/chained_mc_main.pdf -y mc_acc -n 6887575552 -d 3945065873408 -t 7B-4T --moving_avg 5
# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -o figure/peteish-moreeval/chained_mc_main.pdf -y mc_acc -n 13202396160 -d 5000088518656 -t 13B-5T --moving_avg 5

import argparse

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from step1 import fit_step1
from step2 import fit_step2

Expand All @@ -17,8 +19,12 @@
get_step1_data_by_name,
get_step2_data_by_name,
get_task_sets,
tasks,
)

MARKERS = ["s", "P", "p", "*", "o"]
FONTSIZE = 9


def parse_args():
parser = argparse.ArgumentParser()
Expand All @@ -38,6 +44,7 @@ def parse_args():
)
parser.add_argument("-c", "--config-path", type=str, required=True, help="Path to config file")
parser.add_argument("--step2-config-path", type=str, default=None, help="Path to config file for step2")
parser.add_argument("-o", "--output-path", type=str, required=True, help="Path to write output figure")
parser.add_argument("-n", "--n", type=int, required=True, help="Model size of the target model")
parser.add_argument("-d", "--d", type=int, required=True, help="Data size of the target model")
parser.add_argument(
Expand All @@ -51,6 +58,116 @@ def parse_args():
return args


def predict_chained(data_by_name, step1_coefficients, step2_coefficients):
predicted_data_by_name = {}
plotted_predicted_data_by_name = {}

dmin = 0.8 * min([min(data["ds"]) for data in data_by_name.values()])
dmax = 1.5 * max([max(data["ds"]) for data in data_by_name.values()])

for name, data in data_by_name.items():
predicted_data_by_name[name] = {
"ds": data["ds"],
"ys": [sigmoid(chinchilla_n_d_fit([n, d], step1_coefficients), *step2_coefficients) for n, d in zip(data["ns"], data["ds"])],
}
ds = np.exp(np.linspace(np.log(dmin), np.log(dmax), 100))
ns = [data["ns"][0]] * len(ds)
plotted_predicted_data_by_name[name] = {
"ds": ds,
"ys": [sigmoid(chinchilla_n_d_fit([n, d], step1_coefficients), *step2_coefficients) for n, d in zip(ns, ds)],
}

if data["mode"] == "eval":
predicted_data = predicted_data_by_name[name]
for d, y, y_pred in zip(data["ds"], data["ys"], predicted_data["ys"]):
rel_error = (y_pred - y) / y

return predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error)


def str_chained_fit(step1_coefficients, step2_coefficients):
a, b, alpha, beta, E = step1_coefficients
A, B = np.exp(a), np.exp(b)
a, x0, k, b = step2_coefficients
return (
f"L(N, D) = {A:.2f} / N^{alpha:.2f} + {B:.2f} / D^{beta:.2f} + {E:.2f}; Acc(L) = {a:.2f} / (1 + e^(-{k:.2f}(L - {x0:.2f}))) + {b:.2f}"
)


def plot_chained(
configs,
data_by_name,
predicted_data_by_name,
plotted_predicted_data_by_name,
task_name,
fit_str,
ax=plt.gca(),
):
# plot the fitted curve
for name, data in plotted_predicted_data_by_name.items():
config = configs[name]
ax.plot(
data["ds"],
data["ys"],
color=config.color,
linestyle="--",
alpha=0.7,
linewidth=1.5,
label=f'{config.label} (fitted)' if config.mode == "train" else None,
)

# plot the actual and predicted data
num_eval_annotation = 0
for name, data in data_by_name.items():
config = configs[name]
predicted_data = predicted_data_by_name[name]

for i, (d, y) in enumerate(zip(data["ds"], data["ys"])):
ax.scatter(
d,
y,
color=config.color,
marker=MARKERS[i] if config.mode == "train" else "o",
s=50 if config.mode == "train" else 20,
label=f"{config.label} (target)" if config.mode == "eval" else None,
)

for d, y, y_pred in zip(data["ds"], data["ys"], predicted_data["ys"]):
rel_error = (y_pred - y) / y
if config.mode == "train":
pass
else:
ax.scatter(
d,
y_pred,
color=config.color,
marker="x",
s=20,
label=f"{config.label} (predicted)",
)
ax.annotate(
f"{abs(rel_error * 100):.1f}%",
(d, y_pred),
textcoords="offset points",
xytext=(10, -5 + 10*num_eval_annotation),
ha="left",
va="bottom",
fontsize=FONTSIZE,
color=config.color,
)
num_eval_annotation += 1

ax.set_xscale("log")
ax.legend(loc="upper right", ncols=1, fontsize=FONTSIZE)
ax.set_xlabel("Tokens (D)", fontsize=FONTSIZE)
ax.set_ylabel("Task RC accuracy", fontsize=FONTSIZE)
ax.set_title(
f"{tasks[task_name].display_name}",
fontsize=FONTSIZE,
fontweight="bold",
)


def main():
args = parse_args()
configs = get_final_configs(args.config_path)
Expand All @@ -59,16 +176,18 @@ def main():
else:
step2_configs = configs

sns.set_style("whitegrid")
num_tasks = len(args.keys)
num_cols = min(4, num_tasks)
num_rows = (num_tasks + num_cols - 1) // num_cols
fig, axes = plt.subplots(num_rows, num_cols, figsize=(2.75 * num_cols, 2.25 * num_rows), squeeze=False)

results = "Task Name | Prediction | Actual | Rel Error"

for r, task_name in enumerate(args.keys):
# Step 1
step1_data_by_name = get_step1_data_by_name(
configs, task_name, y_metric=args.x_metric, moving_avg=args.moving_avg
)
step1_coefficients, _ = fit_step1(step1_data_by_name, y_metric=args.x_metric)

# Step 2
step2_data_by_name = get_step2_data_by_name(
step2_configs,
task_name,
Expand All @@ -77,8 +196,29 @@ def main():
moving_avg=args.moving_avg,
skip_perc=args.skip_perc,
)
single_step_data_by_name = get_step1_data_by_name(
configs, task_name, y_metric="rc_acc", moving_avg=args.moving_avg
)

# fit the parameters
step1_coefficients, _ = fit_step1(step1_data_by_name, y_metric=args.x_metric)
step2_coefficients, _ = fit_step2(step2_data_by_name, task_name, args.y_metric, args.use_log_sigmoid)

# make predictions
predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error) = predict_chained(
single_step_data_by_name, step1_coefficients, step2_coefficients
)

plot_chained(
configs,
single_step_data_by_name,
predicted_data_by_name,
plotted_predicted_data_by_name,
task_name,
str_chained_fit(step1_coefficients, step2_coefficients),
axes[r // num_cols][r % num_cols],
)

# make predictions
pred_loss = chinchilla_n_d_fit([args.n, args.d], step1_coefficients)
fit_fn = log_sigmoid if args.use_log_sigmoid else sigmoid
Expand All @@ -91,6 +231,26 @@ def main():
else:
results += f"\n{task_name} | {pred_acc * 100:.1f} | - | -"

handles, labels = axes[-1][-1].get_legend_handles_labels()
# delete x-axis labels for all but the bottom row
for i in range(num_cols):
for j in range(num_rows):
if j != num_rows - 1:
axes[j][i].set_xlabel("")
if i != 0:
axes[j][i].set_ylabel("")

axes[j][i].legend().remove()

fig.tight_layout(w_pad=0.01)
legend = fig.legend(handles, labels, loc='upper center',
ncol=10, fontsize=FONTSIZE, bbox_to_anchor=(0.5, 1.07),
handletextpad=0.3, columnspacing=0.7)
for handle in legend.legend_handles:
handle.set_alpha(1.0)

fig.savefig(args.output_path, dpi=300, bbox_inches='tight')

print(results)


Expand Down
14 changes: 7 additions & 7 deletions scripts/scaling/single_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def parse_args():
return args


def fit_step12(data_by_name, task_name):
def fit_single_step(data_by_name, task_name):
train_nds, train_ys = [], []
for name, data in data_by_name.items():
if data["mode"] == "train":
Expand All @@ -59,7 +59,7 @@ def fit_step12(data_by_name, task_name):
return coefficients


def predict_step12(data_by_name, coefficients):
def predict_single_step(data_by_name, coefficients):
predicted_data_by_name = {}
plotted_predicted_data_by_name = {}

Expand Down Expand Up @@ -94,7 +94,7 @@ def str_combined_fit(coefficients):
)


def plot_step12(
def plot_single_step(
configs,
data_by_name,
predicted_data_by_name,
Expand Down Expand Up @@ -162,7 +162,7 @@ def plot_step12(
ax.set_xscale("log")
ax.legend(loc="upper right", ncols=1, fontsize=FONTSIZE)
ax.set_xlabel("Tokens (D)", fontsize=FONTSIZE)
ax.set_ylabel("Task accuracy", fontsize=FONTSIZE)
ax.set_ylabel("Task RC accuracy", fontsize=FONTSIZE)
ax.set_title(
f"{tasks[task_name].display_name} ({avg_unsigned_rel_error * 100:.2f}%)",
fontsize=FONTSIZE,
Expand All @@ -186,15 +186,15 @@ def main():
data_by_name = get_step1_data_by_name(configs, task_name, y_metric="rc_acc", moving_avg=args.moving_avg)

# fit the parameters
coefficients = fit_step12(data_by_name, task_name)
coefficients = fit_single_step(data_by_name, task_name)

# make predictions
predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error) = predict_step12(
predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error) = predict_single_step(
data_by_name, coefficients
)
results += f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} | {prettify(rel_error)}"

plot_step12(
plot_single_step(
configs,
data_by_name,
predicted_data_by_name,
Expand Down

0 comments on commit cba738c

Please sign in to comment.