diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py
index 885eabf..d0b1570 100644
--- a/auto_gptq/modeling/_base.py
+++ b/auto_gptq/modeling/_base.py
@@ -14,6 +14,7 @@
 from safetensors.torch import save_file as safe_save
 from tqdm import tqdm
 from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
+from peft import AutoPeftModelForCausalLM
 from transformers.modeling_utils import no_init_weights
 from transformers.utils.generic import ContextManagers
 from transformers.utils.hub import (
@@ -22,6 +23,8 @@
     create_commit,
     create_repo,
 )
+from peft.tuners import rosa
+from bitsandbytes.functional import dequantize_4bit as dq4bit

 from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
 from ..nn_modules.qlinear import GeneralQuantLinear
@@ -171,6 +174,46 @@ def _convert_tensor_to_list(tensor):
         return new_examples

+    def swap_module(self, network, module_name, new_module):
+        name_parts = module_name.split('.')
+        parent = network
+        for part in name_parts[:-1]:
+            if part.isdigit():
+                parent = parent[int(part)]
+            else:
+                parent = getattr(parent, part)
+
+        last_part = name_parts[-1]
+        if last_part.isdigit():
+            parent[int(last_part)] = new_module
+        else:
+            setattr(parent, last_part, new_module)
+
+    def merge_rosa_layer(self, module, target_dtype=torch.bfloat16):
+        assert isinstance(module, rosa.RosaLayer)
+
+        adapt = module.get_delta_weight('default').to(target_dtype)
+        base = module.find_weight()
+        if isinstance(module, rosa.Linear4bit):
+            base = dq4bit(base.data, base.quant_state).to(target_dtype)
+        else:
+            assert isinstance(module, rosa.Linear)
+            base = base.to(target_dtype)
+
+        merged_module = torch.nn.Linear(
+            module.in_features,
+            module.out_features,
+            bias=module.base_layer.bias is not None,
+            device=adapt.device,
+            dtype=target_dtype
+        )
+
+        merged_module.weight.mul_(0).add_(base + adapt)
+        if module.base_layer.bias is not None:
+            merged_module.bias = module.base_layer.bias.clone().to(target_dtype)
+
+        return merged_module
+
     @torch.inference_mode()
     def quantize(
         self,
@@ -181,6 +224,12 @@
         autotune_warmup_after_quantized: bool = False,
         cache_examples_on_gpu: bool = True,
     ):
+
+        # for name, module in self.model.named_modules():
+        #     if isinstance(module, rosa.RosaLayer):
+        #         module = module.to(torch.float16)
+
+
         if self.quantized:
             raise EnvironmentError("can't execute quantize because the model is quantized.")
@@ -267,7 +316,6 @@ def store_input_hook(_, args, kwargs):
         ori_outside_layer_module_devices = {}
         for module_name in self.outside_layer_modules:
             module = get_module_by_name_prefix(self.model, module_name)
-
             if module is None:
                 continue
@@ -309,9 +357,14 @@
                 force_layer_back_to_cpu = True
             cur_layer_device = get_device(layer)

-            full = find_layers(layer)
+            if self.is_rosa:
+                full = find_layers(layer, layers=[rosa.RosaLayer])
+            else:
+                full = find_layers(layer)
+
             for names in inside_layer_modules:
                 subset = {n: full[n] for n in names if n in full}
+
                 gptq = {}
                 for name in subset:
                     gptq[name] = GPTQ(subset[name])
@@ -352,6 +405,12 @@ def tmp(_, inp, out):

                 for name in subset:
                     logger.info(f"Quantizing {name} in layer {i + 1}/{len(layers)}...")
+
+                    merged = None  # needed for RoSA
+                    if isinstance(subset[name], rosa.RosaLayer):
+                        merged = self.merge_rosa_layer(subset[name])
+                        gptq[name].layer = merged
+
                     scale, zero, g_idx = gptq[name].fasterquant(
                         percdamp=self.quantize_config.damp_percent,
                         group_size=self.quantize_config.group_size,
@@ -366,6 +425,15 @@
                     )
                     gptq[name].free()
+                    if merged is not None:
+                        subset[name] = merged
+
+                for n in subset:
+                    full[n] = subset[n]
+
+            for name, module in full.items():
+                self.swap_module(layer, name, module)
+

             for j in range(num_batches):
                 layer_input = []
                 for k, layer_inp in enumerate(layer_inputs[j]):
@@ -610,6 +678,7 @@ def from_pretrained(
         max_memory: Optional[dict] = None,
         trust_remote_code: bool = False,
         torch_dtype: torch.dtype = torch.float16,
+        rosa_name_or_path: str = None,
         **model_init_kwargs,
     ):
         """load un-quantized pretrained model to cpu"""
@@ -686,7 +755,13 @@ def skip(*args, **kwargs):
             torch.cuda.empty_cache()

         merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
-        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)
+
+        if rosa_name_or_path is not None:
+            model = AutoPeftModelForCausalLM.from_pretrained(
+                rosa_name_or_path, **merged_kwargs
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)

         model_config = model.config.to_dict()
         seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
@@ -700,7 +775,9 @@ def skip(*args, **kwargs):
             model.seqlen = 4096
         model.eval()

-        return cls(model, False, quantize_config)
+        obj = cls(model, False, quantize_config)
+        obj.is_rosa = rosa_name_or_path is not None
+        return obj

     @classmethod
     def from_quantized(
diff --git a/auto_gptq/quantization/gptq.py b/auto_gptq/quantization/gptq.py
index cda3e7a..2d9f0dd 100644
--- a/auto_gptq/quantization/gptq.py
+++ b/auto_gptq/quantization/gptq.py
@@ -8,6 +8,7 @@
 import transformers

 from .quantizer import Quantizer
+from peft.tuners import rosa


 logger = getLogger(__name__)
@@ -38,7 +39,7 @@ def add_batch(self, inp, out):
         if len(inp.shape) == 2:
             inp = inp.unsqueeze(0)
         tmp = inp.shape[0]
-        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
+        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D) or isinstance(self.layer, rosa.RosaLayer):
             if len(inp.shape) == 3:
                 inp = inp.reshape((-1, inp.shape[-1]))
             inp = inp.t()
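
Usage sketch (illustrative, not part of the patch): quantizing a RoSA-adapted model through the new `rosa_name_or_path` argument. The base model ID, adapter path, and calibration text below are placeholders, and the snippet assumes `AutoGPTQForCausalLM.from_pretrained` forwards extra keyword arguments to the patched `BaseGPTQForCausalLM.from_pretrained`.

    from transformers import AutoTokenizer
    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

    base_model = "meta-llama/Llama-2-7b-hf"  # placeholder base checkpoint
    rosa_adapter = "path/to/rosa-adapter"    # placeholder RoSA adapter directory

    quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)

    # Passing rosa_name_or_path makes from_pretrained load the adapted model via
    # AutoPeftModelForCausalLM and sets is_rosa, so quantize() collects
    # rosa.RosaLayer modules, merges each into a dense nn.Linear
    # (merge_rosa_layer), and swaps the merged layer in before fasterquant runs.
    model = AutoGPTQForCausalLM.from_pretrained(
        base_model,
        quantize_config,
        rosa_name_or_path=rosa_adapter,
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    examples = [tokenizer("auto-gptq quantizes a RoSA-merged model.", return_tensors="pt")]

    model.quantize(examples)
    model.save_quantized("llama-2-7b-rosa-4bit-gptq")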