
Commit

add peft model creation
Blaizzy committed Jun 12, 2024
1 parent 2c72233 commit 2391df4
Showing 3 changed files with 136 additions and 19 deletions.
10 changes: 8 additions & 2 deletions mlx_vlm/trainer/__init__.py
@@ -1,2 +1,8 @@
from .lora import LoRaLayer, replace_lora_with_linear
from .utils import (
    collate_fn,
    count_parameters,
    find_all_linear_names,
    get_peft_model,
    print_trainable_parameters,
)
15 changes: 9 additions & 6 deletions mlx_vlm/trainer/lora.py
@@ -1,20 +1,22 @@
import math
from typing import Union

import mlx.core as mx
import mlx.nn as nn


class LoRaLayer(nn.Module):
    def __init__(
        self,
        linear: Union[nn.Linear, nn.QuantizedLinear],
        rank: int,
        alpha: float = 0.1,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.original_layer = linear

        self.dropout = nn.Dropout(p=dropout)

        output_dims, input_dims = linear.weight.shape
@@ -26,13 +28,14 @@ def __init__(
            high=std_dev,
            shape=(input_dims, rank),
        )
        self.B = mx.zeros((rank, output_dims))
        self.alpha = alpha

    def __call__(self, x):
        y = self.original_layer(x)
        lora_update = (self.dropout(x) @ self.A) @ self.B
        return y + (self.alpha * lora_update).astype(x.dtype)


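For reference, a minimal usage sketch (not part of the commit) of what the layer computes: the frozen base projection plus the scaled low-rank update alpha * ((x @ A) @ B).

import mlx.core as mx
import mlx.nn as nn

linear = nn.Linear(4, 8)          # base projection, weight shape (8, 4)
lora = LoRaLayer(linear, rank=2)  # A: (4, 2), B: (2, 8); alpha defaults to 0.1
x = mx.random.normal((1, 4))
y = lora(x)  # shape (1, 8); equals linear(x) at init, since B starts at zero
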
def replace_lora_with_linear(model):
    for i, layer in enumerate(model.layers):
@@ -45,7 +48,9 @@ def replace_lora_with_linear(model):
            updated_bias = layer.original_layer.bias

            # Create a new Linear layer with the updated parameters
            new_linear_layer = nn.Linear(
                updated_weight.shape[1], updated_weight.shape[0], bias=use_bias
            )

            new_linear_layer.weight = updated_weight

@@ -59,7 +64,5 @@ def replace_lora_with_linear(model):
                new_linear_layer.bits,
            )

            # Replace the LoRaLayer with the new Linear layer in the model
            model.layers[i] = new_linear_layer

130 changes: 119 additions & 11 deletions mlx_vlm/trainer/utils.py
@@ -1,30 +1,138 @@
import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .lora import LoRaLayer


def get_module_by_name(model, name):
    parts = name.split(".")
    module = model
    for part in parts:
        if part.isdigit():
            module = module[int(part)]
        else:
            module = getattr(module, part)
    return module


def set_module_by_name(model, name, new_module):
    parts = name.split(".")
    module = model
    for part in parts[:-1]:
        if part.isdigit():
            module = module[int(part)]
        else:
            module = getattr(module, part)
    if parts[-1].isdigit():
        module[int(parts[-1])] = new_module
    else:
        setattr(module, parts[-1], new_module)

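A small illustration (not from the commit) of the dotted-path convention both helpers share: numeric parts index into lists, anything else is attribute access.

import mlx.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4), nn.Linear(4, 4)]

toy = Toy()
set_module_by_name(toy, "layers.1", nn.Linear(4, 8))
assert get_module_by_name(toy, "layers.1").weight.shape == (8, 4)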

def get_peft_model(model, linear_layers, freeze=True, verbose=True):
    source_model_trainable = count_parameters(
        model.language_model.trainable_parameters()
    )

    if freeze:
        freeze_model(model)

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and name.split(".")[-1] in linear_layers:
            lora_layer = LoRaLayer(module, 10, 0.1, 0.1)
            set_module_by_name(model, name, lora_layer)

    lora_model_trainable = count_parameters(model.language_model.trainable_parameters())
    if verbose:
        print_trainable_parameters(source_model_trainable, lora_model_trainable)

    return model

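A hypothetical usage sketch (the checkpoint path is illustrative; load is assumed to be the mlx_vlm loader): target layers come from find_all_linear_names, and each match is wrapped with the hard-coded rank 10, alpha 0.1, dropout 0.1.

from mlx_vlm import load  # assumed to return (model, processor)

model, processor = load("some/vlm-checkpoint")  # hypothetical path
targets = find_all_linear_names(model)
model = get_peft_model(model, targets)  # freezes the base model, prints trainable %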

def freeze_model(model):
    for name, module in model.named_modules():
        if name in [
            "language_model",
            "vision_model",
            "vision_tower",
            "aligner",
            "connector",
            "multi_modal_projector",
            "mm_projector",
        ]:
            model[f"{name}"].freeze()


def find_all_linear_names(model):
    cls = nn.Linear
    lora_module_names = set()
    multimodal_keywords = [
        "mm_projector",
        "vision_tower",
        "vision_resampler",
        "aligner",
    ]
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)

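A toy check (illustrative only) of the filtering: Linear layers under multimodal keywords are skipped, and lm_head is collected but removed at the end.

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(4, 4)
        self.vision_tower = nn.Linear(4, 4)  # skipped: name matches a keyword
        self.lm_head = nn.Linear(4, 4)  # collected, then removed

print(find_all_linear_names(TinyModel()))  # -> ["q_proj"]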

def collate_fn(processor, examples):
    texts = ["answer " + example["question"] for example in examples]
    labels = [example["multiple_choice_answer"] for example in examples]
    images = [example["image"].convert("RGB") for example in examples]
    tokens = processor(
        text=texts,
        images=images,
        suffix=labels,
        return_tensors="np",
        padding="longest",
        tokenize_newline_separately=False,
    )

    tokens = tokens.to(mx.float16)
    return tokens

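The example fields ("question", "multiple_choice_answer", "image") and the suffix / tokenize_newline_separately arguments suggest a PaliGemma-style processor over VQAv2-style records. A hedged usage sketch, assuming such a processor and dataset are already loaded:

batch = collate_fn(processor, [dataset[0], dataset[1]])  # float16 batch for training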

def flatten_dict(dd, separator="_", prefix=""):
    return (
        {
            prefix + separator + k if prefix else k: v
            for kk, vv in dd.items()
            for k, v in flatten_dict(vv, separator, kk).items()
        }
        if isinstance(dd, dict)
        else {prefix: dd}
    )

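A quick illustration: nested keys join with the separator, and lists are not descended into, which is why count_parameters below handles them separately.

print(flatten_dict({"layers": {"0": {"weight": 1}}}))  # -> {"layers_0_weight": 1}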

def count_parameters(trainable_params_dict):
    total_params = 0
    for k, v in flatten_dict(trainable_params_dict).items():
        if hasattr(v, "shape"):
            total_params += np.prod(v.shape)

        if isinstance(v, list):
            for v_ in v:
                v_ = flatten_dict(v_)
                if isinstance(v_, dict):
                    total_params += sum(
                        np.prod(p.shape) for p in v_.values() if hasattr(p, "shape")
                    )

    return total_params

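A toy check: a single nn.Linear(4, 8) has 8 * 4 weights plus 8 biases, i.e. 40 trainable parameters.

lin = nn.Linear(4, 8)
assert count_parameters(lin.trainable_parameters()) == 40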

def print_trainable_parameters(source_model_trainable, lora_model_trainable):
    lora_trainable_percent = (lora_model_trainable / source_model_trainable) * 100
    print(
        f"#trainable params: {lora_model_trainable} || all params: {source_model_trainable} || trainable%: {lora_trainable_percent}"
    )
