
[Question] Does TorchRec support distributed checkpointing (DCP)? #2534

Open
JacoCheung opened this issue Nov 4, 2024 · 3 comments

Comments

@JacoCheung

Hi team, I would like to know how to load and dump a sharded embedding collection via state_dict. Basically:

  1. How many files should I save? Should each rank save an exclusive shard file, or should a single rank gather the whole embedding and store it as one file? How should I handle the case where both DP and MP are applied?

  2. If each rank maintains its own shard file, how can I load and re-shard in a new distributed environment where the number of GPUs differs from the one the model was saved on?

  3. If there is a single saved file, how should I load and re-shard it, especially in a multi-node environment?

It would be even more helpful if anyone could provide sample code. Thanks!

@iamzainhuda
Contributor

  1. You should have a checkpoint per rank, since we do not collectively gather the whole embedding. If you wanted to, you could do that and then reconstruct the sharded state dict from the original sharding plan, although I wouldn't recommend it. You should be able to use the torch.distributed.checkpoint utilities for TorchRec models (see the sketch after this list).

  2. For changing the number of GPUs, you would need to understand how the sharding changes. Am I correct in understanding that you want to go from a model sharded on 8 GPUs to loading it onto 16 GPUs? Resharding is important here, and you would have to do it yourself before you load; TorchRec doesn't have any utilities for this.

  3. You can broadcast the parameters/state to the other ranks as you load, via a pre_load_state_dict_hook on top of DMP.
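
For point 1, a minimal sketch of what the torch.distributed.checkpoint (DCP) flow could look like on top of a DMP-wrapped model; the checkpoint directory and the "model" wrapper key are assumptions, and, per point 2, DCP will not reshard the tables for you if the world size or sharding plan changes:

import torch.distributed.checkpoint as dcp

CHECKPOINT_DIR = "/tmp/torchrec_ckpt"  # assumption: any path visible to all ranks

# save: every rank participates; DCP writes the shards each rank owns into the
# directory, so no rank ever has to gather the full embedding tables
dcp.save({"model": model.state_dict()}, storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR))

# load: rebuild the model and wrap it in DMP first, then let DCP fill in the
# shards belonging to this rank before calling load_state_dict
state_dict = {"model": model.state_dict()}
dcp.load(state_dict, storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR))
model.load_state_dict(state_dict["model"])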

@JacoCheung
Author

Thanks for your reply! Closing this as completed.

@JacoCheung
Author

JacoCheung commented Nov 21, 2024

Sorry @iamzainhuda, I have to reopen this because I encountered another issue regarding the Adam optimizer. Say I have an EmbeddingCollection whose optimizer is fused with the backward pass. optimizer.state_dict() returns only the "momentum1/2" tensors; other state such as lr and weight decay is gone. I think the problem is here.


import os
import sys
sys.path.append(os.path.abspath('/home/scratch.junzhang_sw/workspace/torchrec'))
import torch
import torchrec
import torch.distributed as dist

os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend="nccl")
ebc = torchrec.EmbeddingCollection(
    device=torch.device("meta"),
    tables=[
        torchrec.EmbeddingConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
        ),
        torchrec.EmbeddingConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
        )
    ]
)

from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

apply_optimizer_in_backward(
    optimizer_class=torch.optim.Adam,
    params=ebc.parameters(),
    optimizer_kwargs={"lr": 0.02},
)
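# note: with the optimizer fused into the backward pass, DMP later exposes it as
# model.fused_optimizer rather than as a standalone torch.optim.Adam instance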

from torchrec.distributed.fbgemm_qcomm_codec import get_qcomm_codecs_registry, QCommsConfig, CommType
from torchrec.distributed.embedding import EmbeddingCollectionSharder

# qcomm codecs (FP16 forward / BF16 backward) were tried but are left disabled here
sharder = EmbeddingCollectionSharder(use_index_dedup=True)
dp_rank = dist.get_rank()
model = torchrec.distributed.DistributedModelParallel(ebc, sharders=[sharder], device=torch.device("cuda"))
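# DMP plans and applies the sharding of both tables across the available ranks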
mb = torchrec.KeyedJaggedTensor(
    keys = ["product", "user"],
    values = torch.tensor([101, 201, 101, 404, 404, 606, 606, 606]).cuda(),
    lengths = torch.tensor([2, 0, 1, 1, 1, 3], dtype=torch.int64).cuda(),
)
# import pdb; pdb.set_trace()  # left commented so the repro runs end to end
ret = model(mb)  # => this is an awaitable
product = ret['product']  # implicitly calls awaitable.wait()
# import pdb; pdb.set_trace()

The model above gives me an optimizer state like this:

>>> model.fused_optimizer.state_dict()['state']['embeddings.product_table.weight'].keys()
dict_keys(['product_table.momentum1', 'product_table.exp_avg_sq'])
# only the `state` key/values; there is no param_groups entry containing lr, beta1, etc.
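
One workaround sketch for the missing hyperparameters: since lr, betas, and weight decay are exactly what you pass to apply_optimizer_in_backward, they can be persisted separately and re-applied when the model is rebuilt, while only the per-parameter momentum tensors go through the sharded checkpoint. The paths and the idea of bundling the kwargs next to the checkpoint are assumptions rather than a TorchRec API, and whether DCP accepts the fused optimizer's state_dict as-is may depend on the TorchRec/PyTorch versions in use:

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

CHECKPOINT_DIR = "/tmp/torchrec_ckpt"   # assumption: path visible to all ranks
optimizer_kwargs = {"lr": 0.02}         # same dict passed to apply_optimizer_in_backward

# the sharded per-parameter state (momentum1 / exp_avg_sq) goes through DCP
dcp.save(
    {"model": model.state_dict(), "optim": model.fused_optimizer.state_dict()},
    storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR),
)

# the hyperparameters are tiny and identical on every rank, so save them once
if dist.get_rank() == 0:
    torch.save(optimizer_kwargs, f"{CHECKPOINT_DIR}/optimizer_kwargs.pt")

# at load time: rebuild the EmbeddingCollection, call apply_optimizer_in_backward
# with torch.load(f"{CHECKPOINT_DIR}/optimizer_kwargs.pt"), wrap in DMP, then
# dcp.load the "model"/"optim" entries as in the save path above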

@JacoCheung reopened this Nov 21, 2024
@JacoCheung changed the title from "[Question] Does TorchRec support checkpointing? (load/save)" to "[Question] Does TorchRec support distributed checkpointing (DCP)?" Nov 23, 2024