Skip to content

Commit

Permalink
[torch/multiprocessing] Use multiprocessing.reduction.register Forkin…
Browse files Browse the repository at this point in the history
…gPickler.register to register custom tensor and storage reductions
  • Loading branch information
kiukchung committed Sep 16, 2024
1 parent 0aa41eb commit 5c4c662
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions torch/multiprocessing/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import multiprocessing
import os
import threading
from multiprocessing.reduction import ForkingPickler
from multiprocessing import reduction
from multiprocessing.util import register_after_fork
from typing import Union

Expand Down Expand Up @@ -626,22 +626,22 @@ def reduce_storage(storage):


def init_reductions():
ForkingPickler.register(torch.cuda.Event, reduce_event)
reduction.register(torch.cuda.Event, reduce_event)

for t in torch._storage_classes:
if t.__name__ == "UntypedStorage":
ForkingPickler.register(t, reduce_storage)
reduction.register(t, reduce_storage)
else:
ForkingPickler.register(t, reduce_typed_storage_child)
reduction.register(t, reduce_typed_storage_child)

ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage)
reduction.register(torch.storage.TypedStorage, reduce_typed_storage)

for t in torch._tensor_classes:
ForkingPickler.register(t, reduce_tensor)
reduction.register(t, reduce_tensor)

# TODO: Maybe this should be in tensor_classes? :)
ForkingPickler.register(torch.Tensor, reduce_tensor)
reduction.register(torch.Tensor, reduce_tensor)

from torch.nn.parameter import Parameter

ForkingPickler.register(Parameter, reduce_tensor)
reduction.register(Parameter, reduce_tensor)

0 comments on commit 5c4c662

Please sign in to comment.