Skip to content

Commit

Permalink
Fix weight initialization for single-device pretrain/redpajama.py and…
Browse files Browse the repository at this point in the history
… pretrain/openwebtext.py (#759)
  • Loading branch information
awaelchli authored Nov 20, 2023
1 parent 9690b70 commit 759fcc6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions pretrain/openwebtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def main(fabric: L.Fabric, resume: Union[bool, Path]) -> None:
config = Config.from_name(model_name)
fabric.print(f"Loading model with {config.__dict__}")
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
with fabric.init_module(empty_init=(fabric.world_size > 1)):
model = GPT(config)
model.apply(model._init_weights)
model.apply(model._init_weights)

fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
fabric.print(f"Total parameters {num_parameters(model):,}")
Expand Down
4 changes: 2 additions & 2 deletions pretrain/redpajama.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def main(fabric: L.Fabric, train_data_dir: Path, val_data_dir: Path, resume: Uni

fabric.print(f"Loading model with {config.__dict__}")
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
with fabric.init_module(empty_init=(fabric.world_size > 1)):
model = GPT(config)
model.apply(model._init_weights)
model.apply(model._init_weights)

fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
fabric.print(f"Total parameters {num_parameters(model):,}")
Expand Down

0 comments on commit 759fcc6

Please sign in to comment.