Skip to content

Commit

Permalink
Add prediction intervals (plotting call left commented out)
Browse files Browse the repository at this point in the history
  • Loading branch information
AkshitaB committed Nov 20, 2024
1 parent 74e44bd commit e88233a
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 6 deletions.
23 changes: 22 additions & 1 deletion scripts/scaling/step1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_coefficients_huber,
grad_chinchilla_n_d_fit,
grad_chinchilla_n_d_negated_fit,
get_std_errors,
)
from olmo.scaling.scaling_laws.utils import (
get_final_configs,
Expand Down Expand Up @@ -138,11 +139,29 @@ def plot_step1(
task_name,
fit_str,
y_metric,
coefficients,
cov,
ax=plt.gca(),
):
# plot the fitted curve
for name, data in plotted_predicted_data_by_name.items():
config = configs[name]

if config.mode == "eval":
std_errors = get_std_errors(
[[config.n, d] for d in data["ds"]],
data["ys"],
coefficients,
cov,
chinchilla_n_d_fit,
grad_chinchilla_n_d_fit,
)

# Compute prediction intervals
plotted_y_lower = data["ys"] - 1.96 * std_errors
plotted_y_upper = data["ys"] + 1.96 * std_errors
# ax.fill_between(data["ds"], plotted_y_lower, plotted_y_upper, color="pink", alpha=0.3)

ax.plot(
data["ds"],
data["ys"],
Expand Down Expand Up @@ -179,7 +198,7 @@ def plot_step1(
color=config.color,
marker="o",
s=10,
label=f"{config.label} ({'predicted'})",
# label=f"{config.label} ({'predicted'})",
)
ax.annotate(
f"{prettify(rel_error)}",
Expand Down Expand Up @@ -254,6 +273,8 @@ def main():
task_name,
str_chinchilla_n_d_fit(coefficients),
args.y_metric,
coefficients,
cov,
axes[i // num_cols][i % num_cols],
)

Expand Down
38 changes: 33 additions & 5 deletions scripts/scaling/step1_flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import seaborn as sns

from olmo.scaling.scaling_laws.fitting_functions import get_coefficients
from olmo.scaling.scaling_laws.fitting_functions import get_coefficients, chinchilla_flops_fit, grad_chinchilla_flops_fit, get_std_errors
from olmo.scaling.scaling_laws.utils import (
get_final_configs,
get_step1_data_by_name,
Expand Down Expand Up @@ -36,7 +36,7 @@ def parse_args():
return args


def chinchilla_flops_fit(x, a, b, E):
def chinchilla_flops(x, a, b, E):
    """Evaluate the power law ``a * x**b + E``.

    Args:
        x: Input value(s); a scalar or array-like that NumPy can broadcast.
            Presumably compute in FLOPs, given the function name and the
            ``fs`` values passed by callers -- confirm against the data loader.
        a: Multiplicative coefficient of the power-law term.
        b: Exponent applied to ``x``.
        E: Additive offset.

    Returns:
        ``a * x**b + E`` as a NumPy scalar or array shaped like ``x``.
    """
    # Use np.power rather than np.pow: the ``pow`` alias only exists in
    # NumPy >= 2.0, while ``power`` is available in every NumPy release and
    # computes the same element-wise result.
    return a * np.power(x, b) + E

Expand All @@ -54,7 +54,7 @@ def fit_step1(data_by_name, y_metric):
coefficients, cov = get_coefficients(
train_fs,
train_ys,
chinchilla_flops_fit,
chinchilla_flops,
p0,
bounds=bounds,
disp=False,
Expand All @@ -76,9 +76,9 @@ def predict_step1(configs, data_by_name, coefficients, y_metric):
fmax = 1.2 * max([max(data["fs"]) for data in data_by_name.values()])

if y_metric == "rc_bpb":
func = chinchilla_flops_fit
func = chinchilla_flops
elif y_metric == "rc_acc":
func = chinchilla_flops_fit
func = chinchilla_flops
else:
raise ValueError(f"Unknown y_metric: {y_metric}")

Expand Down Expand Up @@ -119,8 +119,34 @@ def plot_step1(
task_name,
fit_str,
y_metric,
coefficients,
cov,
ax=plt.gca(),
):

fmin = min(min(data["fs"]) for data in plotted_predicted_data_by_name.values())
fmax = max(max(data["fs"]) for data in plotted_predicted_data_by_name.values())
fs = np.linspace(fmin, fmax, 100)
plotted_predicted_data = {
"fs": fs,
"ys": [chinchilla_flops(f, *coefficients) for f in fs],
}

std_errors = get_std_errors(
plotted_predicted_data["fs"],
plotted_predicted_data["ys"],
coefficients,
cov,
chinchilla_flops_fit,
grad_chinchilla_flops_fit,
)

# Compute prediction intervals
plotted_y_lower = plotted_predicted_data["ys"] - 1.96 * std_errors
plotted_y_upper = plotted_predicted_data["ys"] + 1.96 * std_errors

# ax.fill_between(plotted_predicted_data["fs"], plotted_y_lower, plotted_y_upper, color="pink", alpha=0.3)

# plot the fitted curve
for name, data in plotted_predicted_data_by_name.items():
config = configs[name]
Expand Down Expand Up @@ -235,6 +261,8 @@ def main():
task_name,
str_chinchilla_flops_fit(coefficients),
args.y_metric,
coefficients,
cov,
axes[i // num_cols][i % num_cols],
)

Expand Down

0 comments on commit e88233a

Please sign in to comment.