From 759fcc696b7f3710552d244a0e6e3b6169d19210 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 20 Nov 2023 21:36:03 +0100
Subject: [PATCH] Fix weight initialization for single-device pretrain/redpajama.py and pretrain/openwebtext.py (#759)

---
 pretrain/openwebtext.py | 4 ++--
 pretrain/redpajama.py   | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/pretrain/openwebtext.py b/pretrain/openwebtext.py
index e431a5b10e..2ea6c096fb 100644
--- a/pretrain/openwebtext.py
+++ b/pretrain/openwebtext.py
@@ -77,9 +77,9 @@ def main(fabric: L.Fabric, resume: Union[bool, Path]) -> None:
     config = Config.from_name(model_name)
     fabric.print(f"Loading model with {config.__dict__}")
     t0 = time.perf_counter()
-    with fabric.init_module(empty_init=True):
+    with fabric.init_module(empty_init=(fabric.world_size > 1)):
         model = GPT(config)
-        model.apply(model._init_weights)
+    model.apply(model._init_weights)
 
     fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
     fabric.print(f"Total parameters {num_parameters(model):,}")
diff --git a/pretrain/redpajama.py b/pretrain/redpajama.py
index f8edd79312..7110a08522 100644
--- a/pretrain/redpajama.py
+++ b/pretrain/redpajama.py
@@ -107,9 +107,9 @@ def main(fabric: L.Fabric, train_data_dir: Path, val_data_dir: Path, resume: Uni
 
     fabric.print(f"Loading model with {config.__dict__}")
     t0 = time.perf_counter()
-    with fabric.init_module(empty_init=True):
+    with fabric.init_module(empty_init=(fabric.world_size > 1)):
         model = GPT(config)
-        model.apply(model._init_weights)
+    model.apply(model._init_weights)
 
     fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
     fabric.print(f"Total parameters {num_parameters(model):,}")