
Question on custom models #88

Open · vince62s opened this issue Jun 13, 2023 · 23 comments

@vince62s

Hi,
without using transformers / accelerate and so on, what are the constraints for a model to be tensor-parallelizable?

Does it need to be an nn.Sequential? Do the input dimensions always need to be in the same order?

I am trying to load a model on two GPUs but only the first one gets allocated (both are visible).

@BlackSamorez
Owner

Hi @vince62s!
For a model to be tensor-parallelizable automatically, it should consist of basic Linear/Convolution/Embedding PyTorch layers. Other basic layers will still work, but they will be replicated on each GPU. A problem with the automatic config may arise at runtime if those basic layers' weights are accessed directly (e.g. with linear.weight), since such accesses are not taken into account when planning device communications.
Otherwise, everything should generally work, especially when deploying the model, since it handles any PyTorch parameters and buffers without problems. If that's not the case for you, feel free to submit a more detailed report on your problem and I'll take a look.
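
As a rough sketch of what "consists of basic layers" means in practice (a toy module and a hypothetical device list, not code from the demo):

import torch
import torch.nn as nn
import tensor_parallel as tp

# Toy model built only from standard PyTorch layers: the Linear layers can be
# split across GPUs, while LayerNorm is simply replicated on each device.
class ToyBlock(nn.Module):
    def __init__(self, dim=1024, hidden=4096):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.up = nn.Linear(dim, hidden)
        self.down = nn.Linear(hidden, dim)

    def forward(self, x):
        return x + self.down(torch.relu(self.up(self.norm(x))))

# Shard the model over two visible GPUs; the splits and the required
# cross-device communications are planned automatically.
model = tp.tensor_parallel(ToyBlock(), ["cuda:0", "cuda:1"])
out = model(torch.randn(2, 16, 1024, device="cuda:0"))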

@vince62s
Author

OK, let me explain a bit.
When you use bitsandbytes to quantize (8-bit or 4-bit), you need to load the data into a layer and then call cuda(), which triggers the quantization. You also cannot call cuda() twice, otherwise it breaks everything.
To keep CPU RAM usage down when loading a big model, here is what I do:
I build an empty model in CPU RAM, then, module by module, I load the data and move it to GPU.
here is where it happens:
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/models/model.py#L71-L92
Now, I tried to use tensor_parallel at each module level (by modifying this code) and it only uses one GPU.
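
To make the pattern concrete, here is a minimal, hypothetical sketch of the quantize-on-cuda() behavior I mean (placeholder shapes, not the actual OpenNMT-py code):

import torch.nn as nn
import bitsandbytes as bnb

# Module with weights already loaded on CPU (stand-in for one real layer).
fp_linear = nn.Linear(4096, 4096, bias=False)

# Swap in a bitsandbytes 8-bit layer and hand it the fp16 weights.
int8_linear = bnb.nn.Linear8bitLt(4096, 4096, bias=False, has_fp16_weights=False)
int8_linear.weight = bnb.nn.Int8Params(fp_linear.weight.data.half(), requires_grad=False)

# Moving the layer to GPU is what actually quantizes the weights,
# which is why cuda() must be called exactly once per module.
int8_linear = int8_linear.cuda()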

@BlackSamorez
Owner

I'll have a more detailed look later, but for now you can take a look at how quantized deployment is done in the tensor_parallel int4 demo on Kaggle.

@BlackSamorez
Owner

Are you quantizing the model before making it tensor_parallel? That wouldn't work. What you need to do is create the model on CPU/meta, then make it tensor_parallel, and only then dispatch it onto the GPUs (quantization happens during dispatch).
Here's a simplified example of how I did it for the Kaggle int4 demo:

import torch
import transformers
import accelerate
from transformers.utils.bitsandbytes import replace_with_bnb_linear

import tensor_parallel as tp

# Initialize a model on meta device (zero allocations)
with accelerate.init_empty_weights():
    model = transformers.AutoModelForCausalLM.from_config(...)

# Make it tensor_parallel
model = tp.TensorParallelPreTrainedModel(model)

# Mark linear layers to be quantized
model = replace_with_bnb_linear(
    model,
    quantization_config=transformers.utils.quantization_config.BitsAndBytesConfig(...),
)

# Load, dispatch and quantize the weights
device_map = tp.infer_sharded_device_map(model) # infer where to put each weight
for <EACH CHECKPOINT>: 
    converted_state_dict = tp.convert_state_dict( # <- tensor_parallel helper function.
        torch.load(<CHECKPOINT PATH>),            #    Creates a tensor_parallel checkpoint from a normal one
        ...
    )

    # Dispatch the checkpoint
    for param_name, param in converted_state_dict.items():
        module_name = param_name
        while len(module_name) > 0 and module_name not in device_map:
            module_name = ".".join(module_name.split(".")[:-1])
        accelerate.utils.set_module_tensor_to_device(model, param_name, device_map[module_name], value=param)

From what I can see, you're trying to reuse already-written dispatch code (from within load_state_dict). That will probably not work with tensor_parallel, since it doesn't convert the state_dict and (obviously) doesn't account for the internal tensor_parallel wrapper prefixes.
Please try using accelerate to dispatch the model as shown in the int4 demo. If you have any questions I'll be happy to help!
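
P.S. A quick way to see the prefix mismatch I mean (assuming the model is already wrapped): the wrapped model's parameter names carry the tensor_parallel prefixes, which a vanilla state_dict's keys won't match.

# Print the (prefixed) parameter names of the tensor_parallel model,
# e.g. wrapped_model.module_shards.0. ... .tp_wrapped_module.weight
for name, _ in model.named_parameters():
    print(name)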

@vince62s
Author

I'll try, but then you're confirming that this line, model = tp.TensorParallelPreTrainedModel(model), does not move anything to GPU, right?

@BlackSamorez
Owner

It does, unless the model is on the meta device, in which case it remains on meta.

@vince62s
Author

So it HAS to be on meta, otherwise quantization won't work in the snippet above. Am I correct?

@vince62s
Author

This step is OK:
model = tp.tensor_parallel(model)
but the next one is not:

File "/home/vincent/nlp/OpenNMT-py/onmt/model_builder.py", line 418, in build_model
device_map = tp.infer_sharded_device_map(model)
File "/home/vincent/miniconda3/envs/pytorch1.13/lib/python3.10/site-packages/tensor_parallel/dispatch.py", line 61, in infer_sharded_device_map
id_to_device[id(param)] = tp_model.devices[infer_sharded_data_device_id(name)]
File "/home/vincent/miniconda3/envs/pytorch1.13/lib/python3.10/site-packages/tensor_parallel/dispatch.py", line 42, in infer_sharded_data_device_id
raise KeyError(
KeyError: "Can't decide where to put {name} in a sharded model state dict. Are you sure it's a sharded dict?"

@BlackSamorez
Owner

Is the model unfrozen? Is "Using ZeRO-3 sharding for ... non tensor-parallel parameters" ever printed? That message means ZeRO-3 is used for some trainable but unsplittable parameters (e.g. LayerNorm weights). This functionality wasn't properly integrated with manual deployment (since it's incompatible with meta devices).
For now you could try passing sharded=False to tp.tensor_parallel and then calling either model = tp.Sharded(model) if the model is a TensorParallel, or model.wrapped_model = tp.Sharded(model.wrapped_model) if it is a TensorParallelPreTrainedModel.
Meanwhile I'll think of a way to make this process simpler and more obvious.
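
As a sketch (with the device list as a placeholder), the workaround looks like this:

import tensor_parallel as tp

# Build the tensor-parallel wrapper without ZeRO-3 sharding for now.
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"], sharded=False)

# ... load/dispatch the weights and freeze/unfreeze parameters here ...

# Then enable averaging of the replicated trainable parameters:
if isinstance(model, tp.TensorParallelPreTrainedModel):
    model.wrapped_model = tp.Sharded(model.wrapped_model)
else:  # plain tp.TensorParallel
    model = tp.Sharded(model)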

@vince62s
Author

vince62s commented Jun 19, 2023

I am not seeing that printout. I tried various things as you mentioned, but no luck.
I can't figure out how to make this work.

my steps (without talking about tensor_parallel) are as follows:

  1. build the model on meta with empty / skip_init

  2. replace nn.Linear() by a combination of bnb layers / Lora Layers

  3. Iterating over modules, one by one, loading data from state_dict tensors, and moving each to GPU (to trigger quant)

  4. train() or infer()

I tried calling tp.tensor_parallel between 1) and 2), which seems OK, but then what to do with tp.Sharded or tp.infer_sharded_device_map is a mystery to me.

@BlackSamorez
Owner

Firstly, about tp.infer_sharded_device_map: it has nothing to do with tp.Sharded (I should probably rename it). All it does is create a mapping of parameter name -> parameter device. I created this function to simplify deployment with accelerate.

Secondly, about tp.Sharded: when trainable parameters are replicated on multiple GPUs with tensor parallelism (e.g. LayerNorm weights can't be split), we average them once in a while so they don't diverge. To do this, a tp.TensorParallel model can be wrapped with tp.Sharded via model = tp.Sharded(model). This needs to be done after the model weights are loaded and after they are properly frozen/unfrozen (to decide which weights need averaging). To wrap a tp.TensorParallelPreTrainedModel, however, you need to do model.wrapped_model = tp.Sharded(model.wrapped_model), for implementation reasons. It's a bit of a mess and I'm hoping to refactor it soon.

tl;dr: call tp.tensor_parallel(..., sharded=False) between 1) and 2). In 3), convert the state dicts with tp.convert_state_dict and use tp.infer_sharded_device_map to determine which GPU to put the resulting weights on. After 4), use tp.Sharded as described above.
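
Mapped onto your numbered steps, a compressed sketch would look like this (build_empty_model and the checkpoint path are placeholders, and the ... arguments are the same elided ones as in my earlier example):

import torch
import tensor_parallel as tp

model = build_empty_model()                                    # your step 1 (meta / skip_init)
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"], sharded=False)
# your step 2: replace nn.Linear with bnb / LoRA layers here

device_map = tp.infer_sharded_device_map(model)                # your step 3
converted_state_dict = tp.convert_state_dict(torch.load("checkpoint.pt"), ...)
# ... dispatch converted_state_dict onto device_map as in my earlier example ...

# after your step 4, once weights are loaded and frozen/unfrozen as needed:
model = tp.Sharded(model)   # or model.wrapped_model = tp.Sharded(model.wrapped_model)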

@vince62s
Author

But the issue is that tp.infer_sharded_device_map works when called before the bnb replacement (the mapping comes out fine); after it, I get the same error as above:
raise KeyError(
KeyError: "Can't decide where to put {name} in a sharded model state dict. Are you sure it's a sharded dict?"

@BlackSamorez
Owner

BlackSamorez commented Jun 19, 2023

I'll merge PR #90 fixing this message (and possibly the problem behind it) today. Then I'll be able to tell what's going wrong there.

@vince62s
Author

Well, to be sure, I just tested the Kaggle notebook and it's still very unclear to me what is going on:
I am loading decapoda/llama-7b-hf as a test.
When printing the device_map, it returns all params duplicated on each device, for instance:
'wrapped_model.module_shards.0.model.layers.0.self_attn.tp_wrapped_module.q_proj.weight': device(type='cuda', index=0)
'wrapped_model.module_shards.1.model.layers.0.self_attn.tp_wrapped_module.q_proj.weight': device(type='cuda', index=1)

After the state_dict load, both GPUs are loaded with 13-14GB, which means each of them carries the full model.

Even more troublesome: in 4-bit I would expect the model footprint to be 4GB, split over two GPUs, hence 2GB each.

@BlackSamorez
Owner

'wrapped_model.module_shards.0.model.layers.0.self_attn.tp_wrapped_module.q_proj.weight': device(type='cuda', index=0)
'wrapped_model.module_shards.1.model.layers.0.self_attn.tp_wrapped_module.q_proj.weight': device(type='cuda', index=1)

Those are module_shards.0 and module_shards.1. They are supposed to have identical structure since those are two parts of the same model put on two devices.
As for 13-14GB for decapoda/llama-7b-hf: this should not be the case. The original demo showcases decapoda-research/llama-30b-hf which wouldn't physically fit if it wasn't properly quantized and split between two GPUs.
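
If you want to sanity-check how the weights ended up distributed, you can iterate over the shards directly (attribute names taken from the device_map keys you printed above):

for i, shard in enumerate(model.wrapped_model.module_shards):
    numel = sum(p.numel() for p in shard.parameters())
    devices = {p.device for p in shard.parameters()}
    print(f"shard {i}: {numel / 1e9:.2f}B weight elements on {devices}")
# Each shard should sit on its own GPU and hold roughly half of the model
# (the raw element count is lower when layers are quantized and packed).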

@BlackSamorez
Owner

By the way, v1.2.6 has just been released, which fixes the error message discussed above and, hopefully, some other aspects of dispatch.

@vince62s
Author

vince62s commented Jun 19, 2023

Those are module_shards.0 and module_shards.1. They are supposed to have identical structure since those are two parts of the same model put on two devices.

Then I don't understand anything. How is the model sharded across the 2 devices? Does layer N go to device 0 and layer M to device 1? Or does part of layer N go to device 0 and the other part of layer N to device 1?
I don't understand from this device_map how things are dispatched.

@vince62s
Author

As for 13-14GB for decapoda/llama-7b-hf: this should not be the case. The original demo showcases decapoda-research/llama-30b-hf which wouldn't physically fit if it wasn't properly quantized and split between two GPUs.

I updated the accelerate library and the footprint is OK now, so I need to see how to replicate this with my code.

@BlackSamorez
Owner

As for 13-14GB for decapoda/llama-7b-hf: this should not be the case. The original demo showcases decapoda-research/llama-30b-hf which wouldn't physically fit if it wasn't properly quantized and split between two GPUs.

I updated the accelerate library and the footprint is OK now, so I need to see how to replicate this with my code.

Yes, fp4 support for dispatch was only added a few weeks ago.

@BlackSamorez
Owner

Those are module_shards.0 and module_shards.1. They are supposed to have identical structure since those are two parts of the same model put on two devices.

Then I don't understand anything. How is the model sharded across the 2 devices? Does layer N go to device 0 and layer M to device 1? Or does part of layer N go to device 0 and the other part of layer N to device 1? I don't understand from this device_map how things are dispatched.

The entire point of this library is that each layer is split individually and each GPU holds only a portion of each layer. This way, all of the GPUs can be utilized simultaneously when running the model.
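
Here's a hand-rolled illustration of the idea (not the library's internals, and assuming two visible GPUs): a single Linear layer split column-wise, with each GPU computing half of the outputs:

import torch
import torch.nn as nn

full = nn.Linear(1024, 4096, bias=False)
w0, w1 = full.weight.chunk(2, dim=0)   # each half produces 2048 of the 4096 outputs

part0 = nn.Linear(1024, 2048, bias=False).to("cuda:0")
part1 = nn.Linear(1024, 2048, bias=False).to("cuda:1")

with torch.no_grad():
    part0.weight.copy_(w0)
    part1.weight.copy_(w1)

    x = torch.randn(8, 1024)
    # Both GPUs compute their slice of the output simultaneously;
    # concatenating the slices reproduces the full layer's output.
    y = torch.cat([part0(x.to("cuda:0")).cpu(), part1(x.to("cuda:1")).cpu()], dim=-1)
    assert torch.allclose(y, full(x), atol=1e-4)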

@vince62s
Author

vince62s commented Jun 19, 2023

While digging I found the following issue:

When using the notebook, I see that "tp_wrapped_module" comes before the layer name (o_proj, for instance).
When using my code, it comes AFTER it, just before "weight".

This has an impact on the set_module_tensor_to_device() function, which uses getattr(module, split) recursively.

I have the feeling that tp_wrapped_module does not handle the str method properly.

When I don't use TP, the following code gives me the correct class in both cases:

module = model
tensor_name = 'module_shards.0.decoder.transformer_layers.0.feed_forward.w_1.tp_wrapped_module.weight'

# walk down the module tree with getattr, the way set_module_tensor_to_device does
if "." in tensor_name:
    splits = tensor_name.split(".")
    print(splits)
    for split in splits[:-1]:
        new_module = getattr(module, split)
        if new_module is None:
            raise ValueError(f"{module} has no attribute {split}.")
        module = new_module
    tensor_name = splits[-1]
print("module: ", module)
print("***********************")

# look the same module up via named_modules instead
for name, module in model.named_modules():
    if name == 'module_shards.0.decoder.transformer_layers.0.feed_forward.w_1.tp_wrapped_module':
        print(module.__class__.__name__)

When I switch my model to tp, the first print spits out "Linear" instead of "Linear4bit".
I think this does not happen with the HF example because tp_wrapped_module is NOT the last module in the chain.

@vince62s
Author

On another note, would it be possible to have the same behavior with a model on "cpu" as on "meta"?
If we build an empty model on "cpu" (with the skip_init function), then we could also call tp.tensor_parallel() without actually moving to CUDA and dispatch manually after that. The difference would be that we could avoid the hacky set_module_tensor_to_device() method and just use param.to(device), which works.
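
For reference, this is the kind of construction I mean (a toy layer instead of the real model):

import torch
import torch.nn as nn

# skip_init materializes the layer on CPU with uninitialized (but real, non-meta)
# storage, so it can later be filled from a state_dict and moved with a plain .to().
layer = nn.utils.skip_init(nn.Linear, 4096, 4096, bias=False)
print(layer.weight.device)   # cpu
layer = layer.to("cuda:0")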

@BlackSamorez
Owner

On another note, would it be possible to have the same behavior with a model on "cpu" as on "meta"?
If we build an empty model on "cpu" (with the skip_init function), then we could also call tp.tensor_parallel() without actually moving to CUDA and dispatch manually after that. The difference would be that we could avoid the hacky set_module_tensor_to_device() method and just use param.to(device), which works.

That could come in handy. I'll see what I can do.
