Max Recursion Error when using with lora #122

Open
Ar-Kareem opened this issue Oct 2, 2023 · 2 comments
I get the following error when attempting to use LoRA with Llama 2:


File "/path/libraries/conda/lib/python3.9/site-packages/tensor_parallel/wrapper.py", line 75, in __getattr__
    return getattr(self.tp_wrapped_module, attr)
  [Previous line repeated 2979 more times]
RecursionError: maximum recursion depth exceeded

The recursion is triggered when the peft module executes: if getattr(model, "is_gradient_checkpointing", True):
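
For context, here is a minimal sketch of why a forwarding __getattr__ like the one in tensor_parallel/wrapper.py can recurse (WrapperSketch is a simplified, hypothetical stand-in, not the library's actual class): __getattr__ only runs when normal attribute lookup fails, so if tp_wrapped_module itself is not reachable through normal lookup, the access inside __getattr__ re-enters __getattr__ until the interpreter's recursion limit is hit.

import torch.nn as nn

class WrapperSketch(nn.Module):
    """Simplified, hypothetical stand-in for a tensor-parallel wrapper."""
    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. If
        # "tp_wrapped_module" is itself not reachable through normal lookup,
        # this line calls __getattr__("tp_wrapped_module") again, forever.
        return getattr(self.tp_wrapped_module, attr)

w = WrapperSketch()
hasattr(w, "gradient_checkpointing")  # -> RecursionError: maximum recursion depth exceeded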

Below is a minimal reproducible example that breaks when tensor parallel is enabled and works when it is disabled:

import os
import functools
import torch
from transformers import LlamaTokenizer, LlamaConfig, LlamaForCausalLM
import tensor_parallel as tp
from peft import get_peft_model, LoraConfig

USE_TENSOR_PARALLEL = True
LLAMA_HF_PATH = "./models/llama2/llama_hf_converted/7b"


def spawn(main_fn, world_size):
    wrapped = functools.partial(_wrap_main_fn, main_fn=main_fn)
    torch.multiprocessing.spawn(wrapped, args=(world_size, ), nprocs=world_size, daemon=True)

def _wrap_main_fn(rank, world_size, main_fn):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12357'
    torch.distributed.init_process_group('nccl', rank=rank, world_size=world_size)
    main_fn(rank)
    torch.distributed.destroy_process_group()

def get_model(device):
    torch.set_default_dtype(torch.bfloat16)
    config = LlamaConfig.from_pretrained(LLAMA_HF_PATH)
    model = LlamaForCausalLM.from_pretrained(LLAMA_HF_PATH, config=config)
    model.half()

    torch.cuda.set_device(device)
    if USE_TENSOR_PARALLEL:
        tpmodel = tp.tensor_parallel(model, [device], distributed=True)
        model = tpmodel[0]
        lora_target_modules = ['q_proj.tp_wrapped_module', 'v_proj.tp_wrapped_module']  # target modules when model is wrapped by tp
    else:
        lora_target_modules = ['q_proj', 'v_proj']

    print('Before lora, is_gradient_checkpointing=', getattr(model, "is_gradient_checkpointing", None))
    peft_config = LoraConfig(
        # inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
        inference_mode=False, r=int(8), lora_alpha=int(32), lora_dropout=float(0.01),
        target_modules=lora_target_modules,
    )
    model_peft = get_peft_model(model, peft_config)
    print('After lora, is_gradient_checkpointing=', getattr(model, "is_gradient_checkpointing", None))
    return model_peft


def main_fn(rank):
    devices = ['cuda:0', 'cuda:1']
    get_model(devices[rank])

if __name__ == '__main__':
    spawn(main_fn, world_size=2)

When setting USE_TENSOR_PARALLEL = False the code works, but when setting USE_TENSOR_PARALLEL = True I get the following error:

-- Process 0 terminated with the following error:                                                                                                                                   
Traceback (most recent call last):                                                                                                                                                  
  File "/path/libraries/conda/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap                                                            
    fn(i, *args)                                                                                                                                                                    
  File "/path/projects/silos/TEST_tensor_parallel.py", line 24, in _wrap_main_fn                                                                                        
    main_fn(rank)                                                                                                                                                                   
  File "/path/projects/silos/TEST_tensor_parallel.py", line 56, in main_fn                                                                                              
    get_model(devices[rank])                                                                                                                                                        
  File "/path/projects/silos/TEST_tensor_parallel.py", line 49, in get_model
    model_peft = get_peft_model(model, peft_config)
  File "/path/libraries/conda/lib/python3.9/site-packages/peft/mapping.py", line 105, in get_peft_model
    return PeftModel(model, peft_config, adapter_name=adapter_name)
  File "/path/libraries/conda/lib/python3.9/site-packages/peft/peft_model.py", line 120, in __init__
    if getattr(model, "is_gradient_checkpointing", True):
  File "/path/libraries/conda/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1745, in is_gradient_checkpointing
    return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
  File "/path/libraries/conda/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1745, in <genexpr>
File "/path/libraries/conda/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1745, in <genexpr>
    return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
  File "/path/libraries/conda/lib/python3.9/site-packages/tensor_parallel/wrapper.py", line 75, in __getattr__
    return getattr(self.tp_wrapped_module, attr)
  File "/path/libraries/conda/lib/python3.9/site-packages/tensor_parallel/wrapper.py", line 75, in __getattr__
    return getattr(self.tp_wrapped_module, attr)
  File "/path/libraries/conda/lib/python3.9/site-packages/tensor_parallel/wrapper.py", line 75, in __getattr__
    return getattr(self.tp_wrapped_module, attr)
  [Previous line repeated 2979 more times]
RecursionError: maximum recursion depth exceeded
@Ar-Kareem (Author) commented:

I've identified that the error happens exactly at this setattr line in peft's LoRA tuner:

https://github.com/huggingface/peft/blob/52ff0cde9f2cc64059e171c2cfd94512914c85df/src/peft/tuners/lora/model.py#L225

When setattr(parent, child_name, new_module) is executed, parent is a tensor parallel wrapper, child_name is the string "tp_wrapped_module", and new_module is a LoRA linear layer.
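
I'm not sure this matches the commit referenced below, but one way to make such a forwarding __getattr__ recursion-safe is to resolve tp_wrapped_module through nn.Module's own lookup instead of through self. A minimal, hypothetical sketch (TensorParallelWrapperSketch is a stand-in, not the library's real class or its committed fix):

import torch.nn as nn

class TensorParallelWrapperSketch(nn.Module):
    """Simplified, hypothetical stand-in for tensor_parallel's wrapper."""
    def __init__(self, wrapped: nn.Module):
        super().__init__()
        self.tp_wrapped_module = wrapped  # registered in self._modules by nn.Module

    def __getattr__(self, attr):
        # Resolve the wrapped module through nn.Module's own lookup
        # (self._modules) so this method never re-enters itself; if the
        # module is genuinely missing, nn.Module raises AttributeError.
        wrapped = super().__getattr__("tp_wrapped_module")
        if attr == "tp_wrapped_module":
            return wrapped
        return getattr(wrapped, attr)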

@Ar-Kareem (Author) commented:

I think I fixed it.

Ar-Kareem added a commit to Ar-Kareem/tensor_parallel that referenced this issue Oct 2, 2023