added support for RoSA merge&quantize
MNikdan committed May 4, 2024
1 parent ea829c7 commit d603fcd
Showing 2 changed files with 83 additions and 5 deletions.
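In short, the change lets a RoSA adapter (low-rank plus sparse fine-tuning deltas trained with peft) be merged into the base weights layer by layer and quantized with GPTQ in one pass. A rough usage sketch, assuming the AutoGPTQForCausalLM wrapper forwards the new rosa_name_or_path keyword; the model id, adapter path, and calibration text are placeholders:

from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

base_id = "meta-llama/Llama-2-7b-hf"  # placeholder base model
tok = AutoTokenizer.from_pretrained(base_id)
examples = [tok("a short calibration sample", return_tensors="pt")]

model = AutoGPTQForCausalLM.from_pretrained(
    base_id,
    BaseQuantizeConfig(bits=4, group_size=128),
    rosa_name_or_path="path/to/rosa-adapter",  # new argument introduced by this commit
)
model.quantize(examples)                       # RoSA layers are merged per block, then GPTQ-quantized
model.save_quantized("llama-2-7b-rosa-gptq-4bit")
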
85 changes: 81 additions & 4 deletions auto_gptq/modeling/_base.py
@@ -14,6 +14,7 @@
from safetensors.torch import save_file as safe_save
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from peft import AutoPeftModelForCausalLM
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import (
@@ -22,6 +23,8 @@
create_commit,
create_repo,
)
from peft.tuners import rosa
from bitsandbytes.functional import dequantize_4bit as dq4bit

from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
from ..nn_modules.qlinear import GeneralQuantLinear
@@ -171,6 +174,46 @@ def _convert_tensor_to_list(tensor):

return new_examples

def swap_module(self, network, module_name, new_module):
name_parts = module_name.split('.')
parent = network
for part in name_parts[:-1]:
if part.isdigit():
parent = parent[int(part)]
else:
parent = getattr(parent, part)

last_part = name_parts[-1]
if last_part.isdigit():
parent[int(last_part)] = new_module
else:
setattr(parent, last_part, new_module)
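# swap_module replaces the submodule addressed by a dotted name (e.g. "layers.0.self_attn.q_proj"),
# indexing into ModuleList/Sequential containers when a path component is numeric and
# setattr-ing the new module onto its immediate parent.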

def merge_rosa_layer(self, module, target_dtype=torch.bfloat16):
assert isinstance(module, rosa.RosaLayer)

adapt = module.get_delta_weight('default').to(target_dtype)
base = module.find_weight()
if isinstance(module, rosa.Linear4bit):
base = dq4bit(base.data, base.quant_state).to(target_dtype)
else:
assert isinstance(module, rosa.Linear)
base = base.to(target_dtype)

merged_module = torch.nn.Linear(
module.in_features,
module.out_features,
bias=module.base_layer.bias is not None,
device=adapt.device,
dtype=target_dtype
)

merged_module.weight.mul_(0).add_(base + adapt)
if module.base_layer.bias is not None:
merged_module.bias = module.base_layer.bias.clone().to(target_dtype)

return merged_module
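# merge_rosa_layer collapses a peft RoSA wrapper into a plain dense nn.Linear: the base
# weight (dequantized from its bitsandbytes 4-bit format when the wrapper is rosa.Linear4bit)
# is summed with the adapter delta from get_delta_weight('default'), i.e.
# W_merged = W_base + delta_RoSA, so GPTQ sees a single dense weight matrix.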

@torch.inference_mode()
def quantize(
self,
@@ -181,6 +224,12 @@
autotune_warmup_after_quantized: bool = False,
cache_examples_on_gpu: bool = True,
):

# for name, module in self.model.named_modules():
# if isinstance(module, rosa.RosaLayer):
# module = module.to(torch.float16)


if self.quantized:
raise EnvironmentError("can't execute quantize because the model is quantized.")

@@ -267,7 +316,6 @@ def store_input_hook(_, args, kwargs):
ori_outside_layer_module_devices = {}
for module_name in self.outside_layer_modules:
module = get_module_by_name_prefix(self.model, module_name)

if module is None:
continue

@@ -309,9 +357,14 @@ def store_input_hook(_, args, kwargs):
force_layer_back_to_cpu = True
cur_layer_device = get_device(layer)

full = find_layers(layer)
if self.is_rosa:
full = find_layers(layer, layers=[rosa.RosaLayer])
else:
full = find_layers(layer)

for names in inside_layer_modules:
subset = {n: full[n] for n in names if n in full}

gptq = {}
for name in subset:
gptq[name] = GPTQ(subset[name])
@@ -352,6 +405,12 @@ def tmp(_, inp, out):

for name in subset:
logger.info(f"Quantizing {name} in layer {i + 1}/{len(layers)}...")

merged = None # needed for RoSA
if isinstance(subset[name], rosa.RosaLayer):
merged = self.merge_rosa_layer(subset[name])
gptq[name].layer = merged
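# gptq[name] keeps the Hessian statistics gathered from the RosaLayer's calibration
# inputs, but quantizes the merged dense weight in its place.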

scale, zero, g_idx = gptq[name].fasterquant(
percdamp=self.quantize_config.damp_percent,
group_size=self.quantize_config.group_size,
@@ -366,6 +425,15 @@ def tmp(_, inp, out):
)
gptq[name].free()

if merged is not None:
subset[name] = merged
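# Write the quantized modules (for RoSA layers, the merged Linear) back into the block
# below, so the forward pass that produces the next layer's inputs uses the updated weights.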

for n in subset:
full[n] = subset[n]

for name, module in full.items():
self.swap_module(layer, name, module)

for j in range(num_batches):
layer_input = []
for k, layer_inp in enumerate(layer_inputs[j]):
@@ -610,6 +678,7 @@ def from_pretrained(
max_memory: Optional[dict] = None,
trust_remote_code: bool = False,
torch_dtype: torch.dtype = torch.float16,
rosa_name_or_path: str = None,
**model_init_kwargs,
):
"""load un-quantized pretrained model to cpu"""
@@ -686,7 +755,13 @@ def skip(*args, **kwargs):
torch.cuda.empty_cache()

merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)

if rosa_name_or_path is not None:
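# peft's AutoPeftModelForCausalLM reads the adapter config stored at rosa_name_or_path,
# loads the base model recorded in that config, and attaches the RoSA adapter on top.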
model = AutoPeftModelForCausalLM.from_pretrained(
rosa_name_or_path, **merged_kwargs
)
else:
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)

model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
@@ -700,7 +775,9 @@ def skip(*args, **kwargs):
model.seqlen = 4096
model.eval()

return cls(model, False, quantize_config)
obj = cls(model, False, quantize_config)
obj.is_rosa = rosa_name_or_path is not None
return obj

@classmethod
def from_quantized(
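For reference, the arithmetic inside merge_rosa_layer is plain weight addition. A minimal standalone sketch with toy tensors (sizes, rank, and sparsity are made up; only torch is used, no peft or bitsandbytes):

import torch

out_f, in_f, rank = 16, 16, 2
base = torch.randn(out_f, in_f, dtype=torch.bfloat16)                # stands in for the (dequantized) base weight
low_rank = (torch.randn(out_f, rank) @ torch.randn(rank, in_f)).to(torch.bfloat16)
sparse = (torch.randn(out_f, in_f) * (torch.rand(out_f, in_f) < 0.05)).to(torch.bfloat16)
delta = low_rank + sparse                                            # RoSA delta = low-rank term + sparse term

merged = torch.nn.Linear(in_f, out_f, bias=False, dtype=torch.bfloat16)
with torch.no_grad():
    merged.weight.copy_(base + delta)                                # same effect as weight.mul_(0).add_(base + adapt)

x = torch.randn(1, in_f, dtype=torch.bfloat16)
y = merged(x)                                                        # forward now runs on the merged dense weight

In the diff itself the base weight comes from find_weight() and, for rosa.Linear4bit, from dequantize_4bit on the stored quant_state before the addition.
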
3 changes: 2 additions & 1 deletion auto_gptq/quantization/gptq.py
@@ -8,6 +8,7 @@
import transformers

from .quantizer import Quantizer
from peft.tuners import rosa


logger = getLogger(__name__)
@@ -38,7 +39,7 @@ def add_batch(self, inp, out):
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
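# A rosa.RosaLayer wraps a Linear, so its calibration inputs are flattened and
# transposed exactly like plain nn.Linear inputs.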
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D) or isinstance(self.layer, rosa.RosaLayer):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
