diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py index 8a518519e44..4eaa89f9603 100644 --- a/src/accelerate/utils/other.py +++ b/src/accelerate/utils/other.py @@ -23,7 +23,6 @@ import numpy as np import torch -from packaging import version from packaging.version import Version from safetensors.torch import save_file as safe_save_file @@ -216,7 +215,7 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = Fal # The following are considered "safe" globals to reconstruct various types of objects when using `weights_only=True` # These should be added and then removed after loading in the file -np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core +np_core = np._core if Version(np.__version__) >= Version("2.0.0") else np.core TORCH_SAFE_GLOBALS = [ # numpy arrays are just numbers, not objects, so we can reconstruct them safely np_core.multiarray._reconstruct,