diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 3685d2b..86d5b9b 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -631,8 +631,8 @@ def _pad_vision_model(model: nn.Module, vision_size: int, divisor: int = 64) -> ) if new_out != out_features or new_in != in_features: - new_weight = mx.zeros((new_out, new_in)) - new_bias = mx.zeros((new_out)) + new_weight = mx.zeros((new_out, new_in), dtype=module.weight.dtype) + new_bias = mx.zeros((new_out), dtype=module.bias.dtype) new_weight[:out_features, :in_features] = module.weight module.weight = new_weight