Skip to content

Commit

Permalink
prediction interval bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
AkshitaB committed Nov 19, 2024
1 parent 9b80ad9 commit e5ebb23
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 21 deletions.
16 changes: 16 additions & 0 deletions olmo/scaling/scaling_laws/fitting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,22 @@ def sigmoid(x, a, x0, k, b):
o = a / (1 + np.exp(-k * (x - x0))) + b
return o

def sigmoid_fit(x, p):
    """Evaluate a four-parameter sigmoid at ``x`` using a packed parameter vector.

    Args:
        x: Point (or NumPy array of points) at which to evaluate the curve.
        p: Indexable sequence ``[a, x0, k, b]`` holding the scale ``a``, the
            midpoint ``x0``, the slope ``k`` and the vertical offset ``b``.

    Returns:
        ``a / (1 + exp(-k * (x - x0))) + b``.
    """
    scale, midpoint, slope, offset = p[0], p[1], p[2], p[3]
    decay = np.exp(-slope * (x - midpoint))
    return scale / (1 + decay) + offset

def grad_sigmoid_fit(x, p):
    """Return the partial derivatives of the sigmoid fit with respect to its parameters.

    The fitted curve is ``f(x) = a / (1 + exp(-k * (x - x0))) + b`` with the
    parameters packed as ``p = [a, x0, k, b]``.

    Args:
        x: Point at which to evaluate the gradient.
        p: Indexable sequence ``[a, x0, k, b]``.

    Returns:
        A list ``[df/da, df/dx0, df/dk, df/db]`` in the same order as ``p``.
    """
    a, x0, k = p[0], p[1], p[2]
    exp_term = np.exp(-k * (x - x0))
    denom = 1 + exp_term

    grad_a = 1 / denom
    # d(exp_term)/d(x0) = +k * exp_term, and f depends on exp_term through
    # a / denom, so the chain rule yields a leading minus sign here.
    grad_x0 = -a * k * exp_term / (denom ** 2)
    # d(exp_term)/d(k) = -(x - x0) * exp_term, so the two minus signs cancel.
    grad_k = a * (x - x0) * exp_term / (denom ** 2)
    # The offset b enters additively, so its derivative is constant.
    grad_b = 1

    return [grad_a, grad_x0, grad_k, grad_b]

def exponential_fit(x, a, b, c):
    """Evaluate the exponential curve ``a * exp(b * x) + c`` at ``x``.

    Args:
        x: Point (or NumPy array of points) at which to evaluate the curve.
        a: Multiplicative scale applied to the exponential term.
        b: Rate inside the exponent.
        c: Additive offset.

    Returns:
        The value of ``a * exp(b * x) + c``.
    """
    growth = np.exp(b * x)
    scaled = a * growth
    return scaled + c

Expand Down
3 changes: 2 additions & 1 deletion olmo/scaling/scaling_laws/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,10 @@ def get_step1_data_by_name(configs, task_name, y_metric="rc_bpb", moving_avg=1):
reader = csv.DictReader(file_ref)
rows = [row for row in reader]
rows = rows[-moving_avg:]
ds, ys = [], []
ds, ys, fs = [], [], []
for row in rows:
d = int(float(row["throughput/total_tokens"]))
f = d * MODEL_FLOPS[name]
y = np.average(
[float(row[key]) for key in keys], weights=[WEIGHT_BY_KEY.get(key, 1.0) for key in keys]
)
Expand Down
1 change: 0 additions & 1 deletion olmo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,6 @@ def check_if_cancelled(self) -> Tuple[bool, int]:
# Finally, check if someone canceled the run from W&B by adding the 'cancel' / 'canceled' tag..
# We won't see it in the run object. So we have to use the import/export API to check.
from requests.exceptions import RequestException

from wandb.errors import CommError

try:
Expand Down
3 changes: 2 additions & 1 deletion scripts/ladder_peteish.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,10 @@ def flops_for_model(model_config: Union[ModelConfig, str]) -> int:
def flops_cmd(args: argparse.Namespace):
cfg = config_from_args(args)

from tqdm import tqdm

from olmo.eval import build_evaluator
from olmo.tokenizer import Tokenizer
from tqdm import tqdm

device = torch.device("cpu")
tokenizer = Tokenizer.from_train_config(cfg)
Expand Down
1 change: 1 addition & 0 deletions scripts/scaling/stacked.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import re

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
Expand Down
34 changes: 20 additions & 14 deletions scripts/scaling/step1.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ def parse_args():
)
parser.add_argument("--moving_avg", type=int, default=1, help="Moving average for bpb loss")
parser.add_argument("-c", "--config-path", type=str, required=True, help="Path to config file")
parser.add_argument("-o", "--output-path", type=str, required=True, help="Path to write output figure")
parser.add_argument("-o", "--output-path", type=str, required=False, help="Path to write output figure")
args = parser.parse_args()

if not args.keys:
args.keys = ["main"]
args.keys = get_task_sets(args.keys)

return args
Expand Down Expand Up @@ -198,7 +200,9 @@ def main():
num_tasks = len(args.keys)
num_cols = min(4, num_tasks)
num_rows = (num_tasks + num_cols - 1) // num_cols
fig, axes = plt.subplots(num_rows, num_cols, figsize=(3.75 * num_cols, 3.25 * num_rows), squeeze=False)

if args.output_path:
fig, axes = plt.subplots(num_rows, num_cols, figsize=(3.75 * num_cols, 3.25 * num_rows), squeeze=False)

results = "Task Name | Actual Value | Predicted Value | Relative Error"

Expand All @@ -208,26 +212,28 @@ def main():
)

# fit the parameters
coefficients, cov = fit_step1(data_by_name, args.accuracy)
coefficients, cov = fit_step1(data_by_name, args.y_metric)

# make predictions
predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error) = predict_step1(
data_by_name, coefficients, y_metric=args.y_metric
)
results += f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} | {prettify(rel_error)}"

plot_step1(configs,
data_by_name,
predicted_data_by_name,
plotted_predicted_data_by_name,
task_name,
str_chinchilla_n_d_fit(coefficients),
args.y_metric,
axes[i // num_cols][i % num_cols],
)
if args.output_path:
plot_step1(configs,
data_by_name,
predicted_data_by_name,
plotted_predicted_data_by_name,
task_name,
str_chinchilla_n_d_fit(coefficients),
args.y_metric,
axes[i // num_cols][i % num_cols],
)

fig.tight_layout()
fig.savefig(args.output_path, dpi=300)
if args.output_path:
fig.tight_layout()
fig.savefig(args.output_path, dpi=300)

print(results)

Expand Down
15 changes: 11 additions & 4 deletions scripts/scaling/step2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
import numpy as np
import seaborn as sns

from olmo.scaling.scaling_laws.fitting_functions import get_coefficients, sigmoid, get_std_errors
from olmo.scaling.scaling_laws.fitting_functions import (
get_coefficients,
get_std_errors,
grad_sigmoid_fit,
sigmoid,
sigmoid_fit,
)
from olmo.scaling.scaling_laws.utils import (
get_final_configs,
get_step2_data_by_name,
Expand Down Expand Up @@ -106,7 +112,7 @@ def main():
"ys": [sigmoid(x, *coefficients) for x in xs],
}

std_errors = get_std_errors(plotted_predicted_data["xs"], plotted_predicted_data["ys"], coefficients, cov)
std_errors = get_std_errors(plotted_predicted_data["xs"], plotted_predicted_data["ys"], coefficients, cov, sigmoid_fit, grad_sigmoid_fit)

# Compute prediction intervals
plotted_y_lower = plotted_predicted_data["ys"] - 1.96 * std_errors
Expand All @@ -130,7 +136,8 @@ def main():
)
for x, y, y_pred in zip(data["xs"], data["ys"], predicted_data["ys"]):
rel_error = (y_pred - y) / y
std_error = get_std_errors([x], [y_pred], coefficients, cov)[0]
std_error = get_std_errors([x], [y_pred], coefficients, cov, sigmoid_fit, grad_sigmoid_fit)[0]
delta_error = 1.96 * std_error
y_lower = y_pred - 1.96 * std_error
y_upper = y_pred + 1.96 * std_error
rel_error_lower = (y_lower - y) / y
Expand All @@ -150,7 +157,7 @@ def main():
color=config.color,
)
results += (
f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} +/- {prettify(1.96 * std_error, False)} | {prettify(rel_error)}"
f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} ± {prettify(delta_error, False)} | {prettify(rel_error)}"
)
avg_unsigned_rel_err = np.mean(unsigned_rel_errs)

Expand Down

0 comments on commit e5ebb23

Please sign in to comment.