From e3b22b0e6a54eb52138b155eb2f8964000f46ac9 Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Tue, 3 Sep 2024 09:38:09 -0700 Subject: [PATCH] [torch/multiprocessing] Use multiprocessing.reduction.register ForkingPickler.register to register custom tensor and storage reductions --- test/allowlist_for_publicAPI.json | 1 + torch/multiprocessing/reductions.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 05b1b8223cc03..5a8893e1fb079 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -550,6 +550,7 @@ "ForkingPickler", "Union", "check_serializing_named_tensor", + "register", "register_after_fork" ], "torch.multiprocessing.spawn": [ diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index fa0818571a93c..874ea70b0bec3 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -2,7 +2,7 @@ import multiprocessing import os import threading -from multiprocessing.reduction import ForkingPickler +from multiprocessing.reduction import register from multiprocessing.util import register_after_fork from typing import Union @@ -626,22 +626,22 @@ def reduce_storage(storage): def init_reductions(): - ForkingPickler.register(torch.cuda.Event, reduce_event) + register(torch.cuda.Event, reduce_event) for t in torch._storage_classes: if t.__name__ == "UntypedStorage": - ForkingPickler.register(t, reduce_storage) + register(t, reduce_storage) else: - ForkingPickler.register(t, reduce_typed_storage_child) + register(t, reduce_typed_storage_child) - ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage) + register(torch.storage.TypedStorage, reduce_typed_storage) for t in torch._tensor_classes: - ForkingPickler.register(t, reduce_tensor) + register(t, reduce_tensor) # TODO: Maybe this should be in tensor_classes? :) - ForkingPickler.register(torch.Tensor, reduce_tensor) + register(torch.Tensor, reduce_tensor) from torch.nn.parameter import Parameter - ForkingPickler.register(Parameter, reduce_tensor) + register(Parameter, reduce_tensor)