tensor_parallel method distributed=True #114
Comments
With

```python
model_tp = tensor_parallel(
    module=model, device_ids=["cuda:{}".format(LOCAL_RANK)], distributed=True
)
```

each worker will have its own GPU and a portion of the model, and they will all communicate through torch.distributed. Here's a full example of how you could initialize and use the model:

```python
import os
import argparse

import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
from transformers import AutoModelForCausalLM

from tensor_parallel import tensor_parallel

NAME = "bigscience/bloom-560m"

# Environment variables set by torchrun / torch.distributed.launch
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])
WORLD_RANK = int(os.environ["RANK"])


def init_process(backend):
    """Initialize the distributed environment."""
    dist.init_process_group(backend, rank=LOCAL_RANK, world_size=WORLD_SIZE)
    run(backend)


@record
def run(backend):
    torch.manual_seed(0)
    if backend == "nccl":
        device = torch.device("cuda:{}".format(LOCAL_RANK))
    else:
        device = torch.device("cpu")

    # Reference outputs from the plain (non-parallel) model
    model = AutoModelForCausalLM.from_pretrained(NAME).to(device)
    inp1 = torch.randint(1, 1000, size=(2, 3), device=device)
    inp2 = torch.randint(1, 1000, size=(2, 1), device=device)
    inp3 = torch.randint(1, 1000, size=(2, 2), device=device)
    out1_ref = model(inp1, use_cache=True, output_hidden_states=True)
    out2_ref = model(inp2, use_cache=True, past_key_values=out1_ref.past_key_values)
    out3_ref = model(inp3, use_cache=True, past_key_values=out2_ref.past_key_values)

    # Wrap the model: with distributed=True, this worker keeps its shard on its own GPU
    model_tp = tensor_parallel(
        module=model, device_ids=[device], distributed=True
    )
    del model

    out1 = model_tp(inp1, use_cache=True, output_hidden_states=True)
    out2 = model_tp(inp2, use_cache=True, past_key_values=out1.past_key_values)
    out3 = model_tp(inp3, use_cache=True, past_key_values=out2.past_key_values)

    # The tensor-parallel outputs should match the reference outputs
    torch.testing.assert_close(out1_ref.hidden_states[-1], out1.hidden_states[-1], atol=3e-3, rtol=1e-05)
    torch.testing.assert_close(out1_ref.logits, out1.logits, atol=3e-3, rtol=1e-05)
    torch.testing.assert_close(out2_ref.logits, out2.logits, atol=3e-3, rtol=1e-05)
    torch.testing.assert_close(out3_ref.logits, out3.logits, atol=3e-3, rtol=1e-05)
    print(f"Everything seems to work at worker {LOCAL_RANK}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", type=str, default="nccl", choices=["nccl", "gloo"])
    args = parser.parse_args()
    init_process(backend=args.backend)
```
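Assuming the script above is saved as, say, `example_tp.py` (the file name is just a placeholder), it would be launched with one process per GPU via `torchrun --nproc_per_node=4 example_tp.py --backend nccl`.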
Ah, now I see what I wasn't doing right. Thanks for sharing this example, and so quickly too! It looks like this is distributed and tensor parallelism in one, is that right?
These steps work. I tried your method (no init_empty_weights, device_map, or set_module_tensor_to_device), but this seems to consume all of my server's RAM while loading the weights prior to putting them onto the GPU; since I have 4 processes, this maxes out and causes the script to crash. If I'm right in thinking that distributed=True leads to a copy of the model on each device, then you can close this issue.
I've left my older code regarding moving tensors to devices in comments for your convenience.
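As a side note, here's a rough sketch of how the loading step might be made lighter on host RAM, assuming fp16 weights and transformers' `low_cpu_mem_usage` path play nicely with `tensor_parallel(distributed=True)` (not verified):

```python
# Hypothetical variant of the loading step in the example above. It assumes that
# `accelerate` is installed (required for low_cpu_mem_usage=True) and that
# tensor_parallel(distributed=True) accepts a CPU-resident module and moves this
# rank's shard to its GPU by itself; neither assumption is verified here.
import os

import torch
from transformers import AutoModelForCausalLM

from tensor_parallel import tensor_parallel

LOCAL_RANK = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda:{}".format(LOCAL_RANK))

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-40b-instruct",   # the model mentioned in the issue
    torch_dtype=torch.float16,      # load weights in fp16 instead of the default fp32
    low_cpu_mem_usage=True,         # avoids materialising an extra full copy while loading
)

model_tp = tensor_parallel(module=model, device_ids=[device], distributed=True)
```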
Hey, really liking this library!
I want to benchmark the difference between running TP normally (as in the demo notebook) and adding the distributed=True flag to the tp.tensor_parallel method, which I think will use torch.distributed rather than the NCCL backend for process communication between devices; please correct me if I've understood this wrong.
I can see in the docstring that torchrun is required, i.e. starting the script with `torchrun --nproc_per_node=4`.
I have done this and also played around with init_process_group within the code in combination with torchrun.
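For reference, the bare-bones setup I've been experimenting with looks roughly like this (a sketch, assuming the environment variables that torchrun sets):

```python
import os

import torch.distributed as dist

# torchrun sets these for each worker process it spawns
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Single node, so rank == LOCAL_RANK; NCCL backend for GPU-to-GPU communication
dist.init_process_group(backend="nccl", rank=local_rank, world_size=world_size)
```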
To be honest, I'm struggling to understand how to implement this with distributed=True, especially since, with this flag, the device_ids parameter requires only one GPU to be passed.
I'm working with 4x 3090 GPUs that have NVLink and the Falcon-40B-Instruct model.
Any guidance on how to set this up would be much appreciated.
Let me know if you need any additional info from me.
Thanks and keep going!