From 03f099a8dd720cb338448f77e011bb88942c975c Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Tue, 3 Sep 2024 09:38:09 -0700 Subject: [PATCH] [torch/multiprocessing] Use multiprocessing.reduction.register ForkingPickler.register to register custom tensor and storage reductions --- torch/multiprocessing/reductions.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index fa0818571a93c0..e11a707b7a3cb6 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -2,7 +2,7 @@ import multiprocessing import os import threading -from multiprocessing.reduction import ForkingPickler +from multiprocessing.reduction import register as _register from multiprocessing.util import register_after_fork from typing import Union @@ -626,22 +626,22 @@ def reduce_storage(storage): def init_reductions(): - ForkingPickler.register(torch.cuda.Event, reduce_event) + _register(torch.cuda.Event, reduce_event) for t in torch._storage_classes: if t.__name__ == "UntypedStorage": - ForkingPickler.register(t, reduce_storage) + _register(t, reduce_storage) else: - ForkingPickler.register(t, reduce_typed_storage_child) + _register(t, reduce_typed_storage_child) - ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage) + _register(torch.storage.TypedStorage, reduce_typed_storage) for t in torch._tensor_classes: - ForkingPickler.register(t, reduce_tensor) + _register(t, reduce_tensor) # TODO: Maybe this should be in tensor_classes? :) - ForkingPickler.register(torch.Tensor, reduce_tensor) + _register(torch.Tensor, reduce_tensor) from torch.nn.parameter import Parameter - ForkingPickler.register(Parameter, reduce_tensor) + _register(Parameter, reduce_tensor)