From 74381f99100d48a5986e18a3dfdd65e59a1fe8d1 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Thu, 10 Oct 2024 12:54:14 -0400 Subject: [PATCH] Fixup Zero3 + `save_model` (#3146) * Fixup + test * Easier diff * Move os.makedirs to under return statement --- src/accelerate/accelerator.py | 7 +++++-- .../test_utils/scripts/external_deps/test_performance.py | 8 ++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index de0ce0f374d..5362f4d9710 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -2860,8 +2860,6 @@ def save_model( logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return - os.makedirs(save_directory, exist_ok=True) - # get the state_dict of the model if any( [ @@ -2876,6 +2874,11 @@ def save_model( raise RuntimeError("You can't save the model since some parameters are on the meta device.") state_dict = self.get_state_dict(model) + # Case: DeepSpeed zero3 gets gathered and `state_dict` is empty + if state_dict is None: + return + os.makedirs(save_directory, exist_ok=True) + if safe_serialization: state_dict = clean_state_dict_for_safetensors(state_dict) weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME diff --git a/src/accelerate/test_utils/scripts/external_deps/test_performance.py b/src/accelerate/test_utils/scripts/external_deps/test_performance.py index f1f7ddd579f..57fb1a01884 100644 --- a/src/accelerate/test_utils/scripts/external_deps/test_performance.py +++ b/src/accelerate/test_utils/scripts/external_deps/test_performance.py @@ -14,6 +14,7 @@ import argparse import json import os +from pathlib import Path import evaluate import torch @@ -205,6 +206,13 @@ def training_function(config, args): if accelerator.is_main_process: with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: json.dump(performance_metric, f) + + # Finally try saving the model + accelerator.save_model(model, args.output_dir) + accelerator.wait_for_everyone() + assert Path( + args.output_dir, "model.safetensors" + ).exists(), "Model was not saved when calling `Accelerator.save_model`" accelerator.end_training()