Add Dolphin-vision and bunny (#50)
* add Dolphin-vision

* add bunny

* bump version
Blaizzy authored Jul 4, 2024
1 parent 8ca2a55 commit 9421149
Showing 8 changed files with 51 additions and 51 deletions.
@@ -1,4 +1,4 @@
-from .nanoLlava import (
+from .llava_bunny import (
     ImageProcessor,
     LanguageModel,
     Model,
@@ -15,6 +15,7 @@ class TextConfig:
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
+    attention_bias: bool = True
     num_key_value_heads: int = None
     rope_theta: float = 1000000
     rope_traditional: bool = False
@@ -55,9 +56,14 @@ def __init__(self, args: TextConfig):
         head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5

-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+        if hasattr(args, "attention_bias"):
+            attention_bias = args.attention_bias
+        else:
+            attention_bias = False
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

         rope_scale = (
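Note: the q/k/v projections now read their bias setting from the config instead of hard-coding bias=True. A minimal standalone sketch of that pattern (TinyTextConfig is hypothetical, not the package's real config class; the getattr form is equivalent to the hasattr/else branch above):

```python
from dataclasses import dataclass

import mlx.nn as nn


@dataclass
class TinyTextConfig:
    hidden_size: int = 64
    num_attention_heads: int = 4
    attention_bias: bool = True  # new field; older configs simply omit it


def make_q_proj(args):
    head_dim = args.hidden_size // args.num_attention_heads
    # Same effect as the hasattr/else branch: fall back to no bias when the
    # config predates the attention_bias field.
    attention_bias = getattr(args, "attention_bias", False)
    return nn.Linear(args.hidden_size, args.num_attention_heads * head_dim, bias=attention_bias)


print("bias" in make_q_proj(TinyTextConfig()).parameters())                      # True
print("bias" in make_q_proj(TinyTextConfig(attention_bias=False)).parameters())  # False
```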
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions mlx_vlm/prompt_utils.py
@@ -3,7 +3,7 @@ def get_message_json(model_name, prompt):
     Get the appropriate JSON message based on the specified model.

     Args:
-        model_name (str): The model for which to generate the message. Options: 'Idefics 2', 'nanollava', 'llava'.
+        model_name (str): The model for which to generate the message.
         prompt (str): The text prompt to be included in the message.
         *args: Additional positional arguments (unused).
         **kwargs: Additional keyword arguments (unused).
@@ -16,7 +16,7 @@ def get_message_json(model_name, prompt):
             "role": "user",
             "content": [{"type": "image"}, {"type": "text", "text": prompt}],
         }
-    elif model_name.lower() in ["llava-qwen2", "llava", "llava_next"]:
+    elif model_name.lower() in ["llava-qwen2", "llava", "llava_next", "bunny-llama"]:
         message = {"role": "user", "content": f"<image>\n{prompt}"}
     elif model_name.lower() == "phi3_v":
         message = {"role": "user", "content": f"<|image_1|>\n{prompt}"}
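A quick usage sketch of the new branch (the expected output is inferred from the elif above; exact return behavior depends on the rest of the function):

```python
from mlx_vlm.prompt_utils import get_message_json

# "bunny-llama" now takes the same <image> prompt path as llava-qwen2 / llava.
message = get_message_json("bunny-llama", "Describe this image.")
print(message)
# Expected: {'role': 'user', 'content': '<image>\nDescribe this image.'}
```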
12 changes: 6 additions & 6 deletions mlx_vlm/tests/test_models.py
@@ -70,10 +70,10 @@ def vision_test_runner(
             hidden_states[vision_feature_layer][-1][-1].shape, (vision_hidden_size,)
         )

-    def test_nanoLlava(self):
-        from mlx_vlm.models import nanoLlava
+    def test_llava_bunny(self):
+        from mlx_vlm.models import llava_bunny

-        text_config = nanoLlava.TextConfig(
+        text_config = llava_bunny.TextConfig(
             model_type="qwen2",
             hidden_size=4096,
             num_hidden_layers=32,
@@ -87,7 +87,7 @@ def test_nanoLlava(self):
             rope_scaling=None,
         )

-        vision_config = nanoLlava.VisionConfig(
+        vision_config = llava_bunny.VisionConfig(
             model_type="siglip_vision_model",
             num_hidden_layers=27,
             hidden_size=1152,
@@ -101,7 +101,7 @@ def test_nanoLlava(self):
             layer_norm_eps=1e-6,
         )

-        args = nanoLlava.ModelConfig(
+        args = llava_bunny.ModelConfig(
             text_config=text_config,
             vision_config=vision_config,
             model_type="llava-qwen2",
@@ -118,7 +118,7 @@ def test_nanoLlava(self):
             vocab_size=151936,
         )

-        model = nanoLlava.Model(args)
+        model = llava_bunny.Model(args)

         self.language_test_runner(
             model.language_model,
70 changes: 32 additions & 38 deletions mlx_vlm/utils.py
@@ -29,9 +29,7 @@
 from .tokenizer_utils import load_tokenizer

 # Constants
-MODEL_REMAPPING = {
-    "llava-qwen2": "nanoLlava",
-}
+MODEL_REMAPPING = {"llava-qwen2": "llava_bunny", "bunny-llama": "llava_bunny"}

 MAX_FILE_SIZE_GB = 5

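As a rough illustration (the helper below is hypothetical; the real lookup happens inside get_model_and_args), the remapping lets both Bunny variants resolve to the single llava_bunny module before it is imported:

```python
import importlib

# Both Qwen2-based and Llama-based Bunny checkpoints map to the same module.
MODEL_REMAPPING = {"llava-qwen2": "llava_bunny", "bunny-llama": "llava_bunny"}


def resolve_model_module(config: dict):
    # Hypothetical sketch: translate the checkpoint's model_type into the
    # module name under mlx_vlm.models, falling back to the type itself.
    model_type = MODEL_REMAPPING.get(config["model_type"], config["model_type"])
    return importlib.import_module(f"mlx_vlm.models.{model_type}")


print(resolve_model_module({"model_type": "bunny-llama"}).__name__)
# mlx_vlm.models.llava_bunny
```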
@@ -150,12 +148,15 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:

     model_class, model_type = get_model_and_args(config=config)

-    if model_type == "nanoLlava":
+    if model_type == "llava_bunny":
         vision_config = AutoConfig.from_pretrained(config["mm_vision_tower"])
         text_config = AutoConfig.from_pretrained(config["language_model"])
         vision_config = vision_config.to_dict()
         text_config = text_config.to_dict()
-        config["vision_config"] = vision_config["vision_config"]
+        config["vision_config"] = {
+            **vision_config["vision_config"],
+            **config.get("vision_config", {}),
+        }
         config["text_config"] = text_config
     if model_type == "idefics2":
         config = AutoConfig.from_pretrained(model_path).to_dict()
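The merged vision_config gives precedence to values already present in the checkpoint's own config.json, because later entries win in a {**a, **b} merge. A tiny illustration with made-up values:

```python
# Defaults derived from the vision tower's AutoConfig ...
tower_defaults = {"hidden_size": 1152, "image_size": 384}
# ... are overridden by anything the checkpoint's config.json already sets.
checkpoint_overrides = {"image_size": 448}

merged = {**tower_defaults, **checkpoint_overrides}
print(merged)  # {'hidden_size': 1152, 'image_size': 448}
```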
@@ -194,7 +195,6 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         weights = model_class.LanguageModel(model_config.text_config).sanitize(
             weights=weights
         )
-
     if (quantization := config.get("quantization", None)) is not None:
         # Handle legacy models which may not have everything quantized
         class_predicate = (
@@ -502,10 +502,8 @@ def quantize_model(
     divisor = 64
     if any(vision_intermediate_size % size != 0 for size in [64, 128]):
         for name, module in model.named_modules():
-            if (
-                isinstance(module, nn.Linear)
-                or isinstance(module, nn.Embedding)
-                and ("vision_model" in name or "vision_tower" in name)
+            if isinstance(module, nn.Linear) and (
+                "vision_model" in name or "vision_tower" in name
             ):
                 out_features, in_features = module.weight.shape

@@ -520,34 +518,30 @@ def quantize_model(
                     if in_features % divisor != 0
                     else in_features
                 )
-                if (
-                    out_features == vision_intermediate_size
-                    or in_features == vision_intermediate_size
-                ):
-
-                    # If padding is needed, proceed
-                    if (
-                        new_out_features != out_features
-                        or new_in_features != in_features
-                    ):
-                        # Create new weight and bias tensors
-                        new_weight = mx.zeros((new_out_features, new_in_features))
-                        new_bias = mx.zeros((new_out_features))
-
-                        # Copy existing weights and biases to the new tensors
-                        new_weight[:out_features, :in_features] = module.weight
-                        module.weight = new_weight
-
-                        if hasattr(module, "bias"):
-                            new_bias[:out_features] = module.bias
-                            module.bias = new_bias
-
-        if "vision_config" in quantized_config:
-            quantized_config["vision_config"]["intermediate_size"] = (
-                ((vision_intermediate_size // divisor) + 1) * divisor
-                if vision_intermediate_size % divisor != 0
-                else vision_intermediate_size
-            )
+
+                # If padding is needed, proceed
+                if new_out_features != out_features or new_in_features != in_features:
+                    # Create new weight and bias tensors
+                    new_weight = mx.zeros((new_out_features, new_in_features))
+                    new_bias = mx.zeros((new_out_features))
+
+                    # Copy existing weights and biases to the new tensors
+                    new_weight[:out_features, :in_features] = module.weight
+                    module.weight = new_weight
+
+                    if hasattr(module, "bias"):
+                        new_bias[:out_features] = module.bias
+                        module.bias = new_bias
+
+        # Ensure vision_config exists in quantized_config
+        quantized_config.setdefault("vision_config", {})
+
+        # Update intermediate_size
+        quantized_config["vision_config"]["intermediate_size"] = (
+            ((vision_intermediate_size // divisor) + 1) * divisor
+            if vision_intermediate_size % divisor != 0
+            else vision_intermediate_size
+        )

     nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
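The rounding above pads a dimension up to the next multiple of divisor so the vision tower quantizes cleanly. A small standalone sketch of the same arithmetic and zero-padding (4304 is only an illustrative, SigLIP-sized intermediate dimension):

```python
import mlx.core as mx

divisor = 64
vision_intermediate_size = 4304  # illustrative; 4304 % 64 == 16, so padding is needed

padded = (
    ((vision_intermediate_size // divisor) + 1) * divisor
    if vision_intermediate_size % divisor != 0
    else vision_intermediate_size
)
print(padded)  # 4352

# Zero-pad a weight matrix the same way the loop above does.
out_features, in_features = 8, vision_intermediate_size
new_weight = mx.zeros((out_features, padded))
new_weight[:, :in_features] = mx.ones((out_features, in_features))
print(new_weight.shape)  # (8, 4352)
```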
2 changes: 1 addition & 1 deletion mlx_vlm/version.py
@@ -1 +1 @@
-__version__ = "0.0.10"
+__version__ = "0.0.11"
