diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 05b1b8223cc038..5a8893e1fb0792 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -550,6 +550,7 @@ "ForkingPickler", "Union", "check_serializing_named_tensor", + "register", "register_after_fork" ], "torch.multiprocessing.spawn": [ diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index fa0818571a93c0..874ea70b0bec3d 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -2,7 +2,7 @@ import multiprocessing import os import threading -from multiprocessing.reduction import ForkingPickler +from multiprocessing.reduction import register from multiprocessing.util import register_after_fork from typing import Union @@ -626,22 +626,22 @@ def reduce_storage(storage): def init_reductions(): - ForkingPickler.register(torch.cuda.Event, reduce_event) + register(torch.cuda.Event, reduce_event) for t in torch._storage_classes: if t.__name__ == "UntypedStorage": - ForkingPickler.register(t, reduce_storage) + register(t, reduce_storage) else: - ForkingPickler.register(t, reduce_typed_storage_child) + register(t, reduce_typed_storage_child) - ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage) + register(torch.storage.TypedStorage, reduce_typed_storage) for t in torch._tensor_classes: - ForkingPickler.register(t, reduce_tensor) + register(t, reduce_tensor) # TODO: Maybe this should be in tensor_classes? :) - ForkingPickler.register(torch.Tensor, reduce_tensor) + register(torch.Tensor, reduce_tensor) from torch.nn.parameter import Parameter - ForkingPickler.register(Parameter, reduce_tensor) + register(Parameter, reduce_tensor)