From e5ebb23a55cb1b2ac2184e14f90cf3ef240ce1e4 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 18 Nov 2024 22:36:57 -0800 Subject: [PATCH] prediction interval bounds --- .../scaling/scaling_laws/fitting_functions.py | 16 +++++++++ olmo/scaling/scaling_laws/utils.py | 3 +- olmo/train.py | 1 - scripts/ladder_peteish.py | 3 +- scripts/scaling/stacked.py | 1 + scripts/scaling/step1.py | 34 +++++++++++-------- scripts/scaling/step2.py | 15 +++++--- 7 files changed, 52 insertions(+), 21 deletions(-) diff --git a/olmo/scaling/scaling_laws/fitting_functions.py b/olmo/scaling/scaling_laws/fitting_functions.py index fbe81fb48..6b063debd 100644 --- a/olmo/scaling/scaling_laws/fitting_functions.py +++ b/olmo/scaling/scaling_laws/fitting_functions.py @@ -328,6 +328,22 @@ def sigmoid(x, a, x0, k, b): o = a / (1 + np.exp(-k * (x - x0))) + b return o +def sigmoid_fit(x, p): + o = p[0] / (1 + np.exp(-p[2] * (x - p[1]))) + p[3] + return o + +def grad_sigmoid_fit(x, p): + exp_term = np.exp(-p[2] * (x - p[1])) + denom = (1 + exp_term) + # Partial derivatives of sigmoid_fit w.r.t. p = (a, x0, k, b); d/dx0 is negative because raising x0 shifts the curve right. + + grad_a = 1 / denom + grad_x0 = -p[0] * p[2] * exp_term / (denom ** 2) + grad_k = p[0] * (x - p[1]) * exp_term / (denom ** 2) + grad_b = 1 + + return [grad_a, grad_x0, grad_k, grad_b] + def exponential_fit(x, a, b, c): return a * np.exp(b * x) + c diff --git a/olmo/scaling/scaling_laws/utils.py b/olmo/scaling/scaling_laws/utils.py index ed8f6d002..c3b866019 100644 --- a/olmo/scaling/scaling_laws/utils.py +++ b/olmo/scaling/scaling_laws/utils.py @@ -502,9 +502,10 @@ def get_step1_data_by_name(configs, task_name, y_metric="rc_bpb", moving_avg=1): reader = csv.DictReader(file_ref) rows = [row for row in reader] rows = rows[-moving_avg:] - ds, ys = [], [] + ds, ys, fs = [], [], [] for row in rows: d = int(float(row["throughput/total_tokens"])) + f = d * MODEL_FLOPS[name] y = np.average( [float(row[key]) for key in keys], weights=[WEIGHT_BY_KEY.get(key, 1.0) for key in keys] ) diff --git a/olmo/train.py b/olmo/train.py 
index f9d75be0d..481bbe676 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -1065,7 +1065,6 @@ def check_if_cancelled(self) -> Tuple[bool, int]: # Finally, check if someone canceled the run from W&B by adding the 'cancel' / 'canceled' tag.. # We won't see it in the run object. So we have to use the import/export API to check. from requests.exceptions import RequestException - from wandb.errors import CommError try: diff --git a/scripts/ladder_peteish.py b/scripts/ladder_peteish.py index 1079f08bd..ea6a181db 100644 --- a/scripts/ladder_peteish.py +++ b/scripts/ladder_peteish.py @@ -452,9 +452,10 @@ def flops_for_model(model_config: Union[ModelConfig, str]) -> int: def flops_cmd(args: argparse.Namespace): cfg = config_from_args(args) + from tqdm import tqdm + from olmo.eval import build_evaluator from olmo.tokenizer import Tokenizer - from tqdm import tqdm device = torch.device("cpu") tokenizer = Tokenizer.from_train_config(cfg) diff --git a/scripts/scaling/stacked.py b/scripts/scaling/stacked.py index cef1a7a73..c42176f77 100644 --- a/scripts/scaling/stacked.py +++ b/scripts/scaling/stacked.py @@ -1,5 +1,6 @@ import argparse import re + import matplotlib.pyplot as plt import numpy as np import seaborn as sns diff --git a/scripts/scaling/step1.py b/scripts/scaling/step1.py index ba1cf9557..93c358f3c 100644 --- a/scripts/scaling/step1.py +++ b/scripts/scaling/step1.py @@ -32,9 +32,11 @@ def parse_args(): ) parser.add_argument("--moving_avg", type=int, default=1, help="Moving average for bpb loss") parser.add_argument("-c", "--config-path", type=str, required=True, help="Path to config file") - parser.add_argument("-o", "--output-path", type=str, required=True, help="Path to write output figure") + parser.add_argument("-o", "--output-path", type=str, required=False, help="Path to write output figure") args = parser.parse_args() + if not args.keys: + args.keys = ["main"] args.keys = get_task_sets(args.keys) return args @@ -198,7 +200,9 @@ def main(): num_tasks = 
len(args.keys) num_cols = min(4, num_tasks) num_rows = (num_tasks + num_cols - 1) // num_cols - fig, axes = plt.subplots(num_rows, num_cols, figsize=(3.75 * num_cols, 3.25 * num_rows), squeeze=False) + + if args.output_path: + fig, axes = plt.subplots(num_rows, num_cols, figsize=(3.75 * num_cols, 3.25 * num_rows), squeeze=False) results = "Task Name | Actual Value | Predicted Value | Relative Error" @@ -208,7 +212,7 @@ def main(): ) # fit the parameters - coefficients, cov = fit_step1(data_by_name, args.accuracy) + coefficients, cov = fit_step1(data_by_name, args.y_metric) # make predictions predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error) = predict_step1( @@ -216,18 +220,20 @@ def main(): ) results += f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} | {prettify(rel_error)}" - plot_step1(configs, - data_by_name, - predicted_data_by_name, - plotted_predicted_data_by_name, - task_name, - str_chinchilla_n_d_fit(coefficients), - args.y_metric, - axes[i // num_cols][i % num_cols], - ) + if args.output_path: + plot_step1(configs, + data_by_name, + predicted_data_by_name, + plotted_predicted_data_by_name, + task_name, + str_chinchilla_n_d_fit(coefficients), + args.y_metric, + axes[i // num_cols][i % num_cols], + ) - fig.tight_layout() - fig.savefig(args.output_path, dpi=300) + if args.output_path: + fig.tight_layout() + fig.savefig(args.output_path, dpi=300) print(results) diff --git a/scripts/scaling/step2.py b/scripts/scaling/step2.py index 422c4d814..d4e7e4a95 100644 --- a/scripts/scaling/step2.py +++ b/scripts/scaling/step2.py @@ -7,7 +7,13 @@ import numpy as np import seaborn as sns -from olmo.scaling.scaling_laws.fitting_functions import get_coefficients, sigmoid, get_std_errors +from olmo.scaling.scaling_laws.fitting_functions import ( + get_coefficients, + get_std_errors, + grad_sigmoid_fit, + sigmoid, + sigmoid_fit, +) from olmo.scaling.scaling_laws.utils import ( get_final_configs, get_step2_data_by_name, @@ -106,7 
+112,7 @@ def main(): "ys": [sigmoid(x, *coefficients) for x in xs], } - std_errors = get_std_errors(plotted_predicted_data["xs"], plotted_predicted_data["ys"], coefficients, cov) + std_errors = get_std_errors(plotted_predicted_data["xs"], plotted_predicted_data["ys"], coefficients, cov, sigmoid_fit, grad_sigmoid_fit) # Compute prediction intervals plotted_y_lower = plotted_predicted_data["ys"] - 1.96 * std_errors @@ -130,7 +136,8 @@ def main(): ) for x, y, y_pred in zip(data["xs"], data["ys"], predicted_data["ys"]): rel_error = (y_pred - y) / y - std_error = get_std_errors([x], [y_pred], coefficients, cov)[0] + std_error = get_std_errors([x], [y_pred], coefficients, cov, sigmoid_fit, grad_sigmoid_fit)[0] + delta_error = 1.96 * std_error y_lower = y_pred - 1.96 * std_error y_upper = y_pred + 1.96 * std_error rel_error_lower = (y_lower - y) / y @@ -150,7 +157,7 @@ def main(): color=config.color, ) results += ( - f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} +/- {prettify(1.96 * std_error, False)} | {prettify(rel_error)}" + f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} ± {prettify(delta_error, False)} | {prettify(rel_error)}" ) avg_unsigned_rel_err = np.mean(unsigned_rel_errs)