Skip to content

Commit

Permalink
Fixup Zero3 + save_model (#3146)
Browse files Browse the repository at this point in the history
* Fixup + test

* Easier diff

* Move os.makedirs to under return statement
  • Loading branch information
muellerzr committed Oct 12, 2024
1 parent a650d1b commit 74381f9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2860,8 +2860,6 @@ def save_model(
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

os.makedirs(save_directory, exist_ok=True)

# get the state_dict of the model
if any(
[
Expand All @@ -2876,6 +2874,11 @@ def save_model(
raise RuntimeError("You can't save the model since some parameters are on the meta device.")
state_dict = self.get_state_dict(model)

# Case: DeepSpeed zero3 gets gathered and `state_dict` is empty
if state_dict is None:
return
os.makedirs(save_directory, exist_ok=True)

if safe_serialization:
state_dict = clean_state_dict_for_safetensors(state_dict)
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import argparse
import json
import os
from pathlib import Path

import evaluate
import torch
Expand Down Expand Up @@ -205,6 +206,13 @@ def training_function(config, args):
if accelerator.is_main_process:
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
json.dump(performance_metric, f)

# Finally try saving the model
accelerator.save_model(model, args.output_dir)
accelerator.wait_for_everyone()
assert Path(
args.output_dir, "model.safetensors"
).exists(), "Model was not saved when calling `Accelerator.save_model`"
accelerator.end_training()


Expand Down

0 comments on commit 74381f9

Please sign in to comment.