From b8e4e6b1eb29cdd515f2259df01f7b18bfb73b89 Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Tue, 3 Sep 2024 09:38:09 -0700 Subject: [PATCH] [torch/multiprocessing] Use multiprocessing.reduction.register ForkingPickler.register to register custom tensor and storage reductions --- torch/multiprocessing/reductions.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index fa0818571a93c..e4f751e459c75 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -2,7 +2,7 @@ import multiprocessing import os import threading -from multiprocessing.reduction import ForkingPickler +from multiprocessing.reduction import register from multiprocessing.util import register_after_fork from typing import Union @@ -36,7 +36,7 @@ def __init__(self, storage): # might be cleared during Python shutdown before this module is cleared. self._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined] - @classmethod + @classmethodior def from_weakref(cls, cdata): instance = cls.__new__(cls) instance.cdata = cdata @@ -626,22 +626,22 @@ def reduce_storage(storage): def init_reductions(): - ForkingPickler.register(torch.cuda.Event, reduce_event) + register(torch.cuda.Event, reduce_event) for t in torch._storage_classes: if t.__name__ == "UntypedStorage": - ForkingPickler.register(t, reduce_storage) + register(t, reduce_storage) else: - ForkingPickler.register(t, reduce_typed_storage_child) + register(t, reduce_typed_storage_child) - ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage) + register(torch.storage.TypedStorage, reduce_typed_storage) for t in torch._tensor_classes: - ForkingPickler.register(t, reduce_tensor) + register(t, reduce_tensor) # TODO: Maybe this should be in tensor_classes? :) - ForkingPickler.register(torch.Tensor, reduce_tensor) + register(torch.Tensor, reduce_tensor) from torch.nn.parameter import Parameter - ForkingPickler.register(Parameter, reduce_tensor) + register(Parameter, reduce_tensor)