Finetuning using moondream.hf implementation fails #187

Open
borisloktev opened this issue Jan 3, 2025 · 0 comments
Hi there!
I'm trying to fine-tune the model on a custom dataset for object detection. I've created a script that uses the HF-compatible class from this repo as its model, but when I try to load it, I get the following error while loading the weights:
ValueError: Trying to set a tensor of shape torch.Size([2048, 256]) in "weight" (which has shape torch.Size([2048, 512])), this looks incorrect.
Am I doing something wrong? Maybe there is a config mismatch between versions somewhere that I could try to update?
Thank you in advance!
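
In case it helps narrow this down, here is a rough, untested sketch of how the checkpoint shapes could be inspected directly; the filename "model.safetensors" and the key filter are guesses on my part, not something I've confirmed against the repo:

from huggingface_hub import hf_hub_download
from safetensors import safe_open

from moondream.hf import LATEST_REVISION

# Download only the weights file and print the shapes of the region-model tensors,
# to see whether the checkpoint stores 256- or 512-wide projections.
ckpt_path = hf_hub_download(
    "vikhyatk/moondream2", "model.safetensors", revision=LATEST_REVISION
)
with safe_open(ckpt_path, framework="pt") as f:
    for name in f.keys():
        if "region" in name or "coordinate" in name or "size" in name:
            print(name, f.get_slice(name).get_shape())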

My code for reference:

import json
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms.functional import to_pil_image
from transformers import get_linear_schedule_with_warmup


class DetectionDataset(Dataset):
    def __init__(self, jsonl_file, image_dir, transform=None):
        """
        Args:
            jsonl_file (str): Path to the JSONL file containing annotations
            image_dir (str): Directory containing the images
            transform (callable, optional): Optional transform to be applied on images
        """
        self.image_dir = image_dir
        self.transform = transform
        self.data = []

        # Load annotations from JSONL file
        with open(jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                self.data.append(json.loads(line.strip()))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Load and process image
        image_path = os.path.join(self.image_dir, item["image"])
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        else:
            # If no transform is provided, convert PIL Image to tensor
            image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0

        # Process bounding boxes
        # item['suffix'] contains the normalized coordinates [x_center, y_center, width, height]
        boxes = torch.tensor(item["suffix"], dtype=torch.float32)

        # Fixed query string used in the detection prompt
        # You can modify this based on your specific needs
        query = "blueprint features"

        return {"images": image, "boxes": boxes, "query": query}


def custom_collate_fn(batch):
    images = [item["images"] for item in batch]  # Keep as list of tensors
    boxes = [item["boxes"] for item in batch]
    queries = [item["query"] for item in batch]

    return {
        "images": images,  # List of tensors with different sizes
        "boxes": boxes,
        "query": queries,
    }


def train_detection(
    model,
    tokenizer,
    train_dataset,
    val_dataset=None,
    epochs=3,
    batch_size=8,
    learning_rate=5e-5,
    warmup_steps=0.1,
    max_grad_norm=1.0,
    device="cuda",
):
    model.train()
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn
    )

    # Create scheduler; warmup_steps is given as a fraction of the total training steps
    num_training_steps = len(train_dataloader) * epochs
    num_warmup_steps = int(warmup_steps * num_training_steps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
    )

    # Loss functions
    coord_loss_fn = torch.nn.CrossEntropyLoss()
    size_loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0
        for batch in train_dataloader:
            pil_images = [to_pil_image(img) for img in batch["images"]]

            query = batch["query"]
            target_boxes = [
                b.to(device) for b in batch["boxes"]
            ]  # List of per-image box tensors with normalized coordinates

            # Encode images
            image_embeds = model.encode_image(pil_images)

            # Prepare input prompts
            prompts = [f"<image>\n\nDetect: {q}\n\n" for q in query]

            # Forward pass through the model
            loss = 0
            for i in range(len(prompts)):
                inputs_embeds = model.input_embeds(
                    prompts[i], image_embeds[i : i + 1], tokenizer
                )

                # Get hidden states from text model
                attention_mask = torch.ones(
                    (inputs_embeds.shape[0], inputs_embeds.shape[1]), device=device
                )
                outputs = model.text_model(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    return_dict=True,
                )
                hidden_states = outputs.hidden_states[-1][
                    :, -1, :
                ]  # Get last hidden state

                # Generate coordinates and sizes
                for box in target_boxes[i]:
                    # Predict x coordinate
                    x_logits = model.region_model.decode_coordinate(hidden_states)
                    x_target = (
                        box[0] * 1024
                    ).long()  # Convert normalized coord to discrete steps
                    loss += coord_loss_fn(x_logits, x_target)

                    # Update hidden states with x coordinate
                    x_coord_encoded = model.region_model.encode_coordinate(
                        box[0]
                    ).unsqueeze(0)
                    inputs_embeds = torch.cat(
                        [inputs_embeds, x_coord_encoded.unsqueeze(0)], dim=1
                    )
                    attention_mask = torch.ones(
                        (inputs_embeds.shape[0], inputs_embeds.shape[1]), device=device
                    )
                    outputs = model.text_model(
                        inputs_embeds=inputs_embeds,
                        attention_mask=attention_mask,
                        output_hidden_states=True,
                        return_dict=True,
                    )
                    hidden_states = outputs.hidden_states[-1][:, -1, :]

                    # Predict y coordinate
                    y_logits = model.region_model.decode_coordinate(hidden_states)
                    y_target = (box[1] * 1024).long()
                    loss += coord_loss_fn(y_logits, y_target)

                    # Update hidden states with y coordinate
                    y_coord_encoded = model.region_model.encode_coordinate(
                        box[1]
                    ).unsqueeze(0)
                    inputs_embeds = torch.cat(
                        [inputs_embeds, y_coord_encoded.unsqueeze(0)], dim=1
                    )
                    attention_mask = torch.ones(
                        (inputs_embeds.shape[0], inputs_embeds.shape[1]), device=device
                    )
                    outputs = model.text_model(
                        inputs_embeds=inputs_embeds,
                        attention_mask=attention_mask,
                        output_hidden_states=True,
                        return_dict=True,
                    )
                    hidden_states = outputs.hidden_states[-1][:, -1, :]

                    # Predict size
                    size_logits = model.region_model.decode_size(hidden_states)
                    size_target = (box[2:] * 1024).long()
                    loss += size_loss_fn(size_logits, size_target)
            loss = loss / sum(
                len(boxes) for boxes in target_boxes
            )  # Normalize by total number of boxes

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")

        # # Validation
        # if val_dataset:
        #     evaluate_detection(model, val_dataset, device)
        #     model.train()


from transformers import AutoTokenizer
from moondream.hf.moondream import Moondream
from moondream.hf import LATEST_REVISION

DEVICE = "cuda"
DTYPE = (
    torch.float32 if DEVICE == "cpu" else torch.float16
)  # CPU doesn't support float16

tokenizer = AutoTokenizer.from_pretrained(
    "vikhyatk/moondream2",
    revision=LATEST_REVISION,
    trust_remote_code=True,
)
# Load the HF-compatible Moondream model and weights
moondream = Moondream.from_pretrained(
    "vikhyatk/moondream2",
    trust_remote_code=True,
    revision=LATEST_REVISION,
    attn_implementation="flash_attention_2" if DEVICE == "cuda" else None,
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)
train_labels = "ocr_annotations/train.jsonl"
img_dir = "images"
train_dataset = DetectionDataset(jsonl_file=train_labels, image_dir=img_dir)

train_detection(
    model=moondream,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    epochs=3,
    batch_size=8,
    learning_rate=5e-5,
    warmup_steps=0.1,
    max_grad_norm=1.0,
    device=DEVICE,
)
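
For completeness, a possible workaround I could also try (untested sketch): load the model through AutoModelForCausalLM with trust_remote_code so the modeling code is pulled from the hub at the same revision as the weights, instead of using the local moondream.hf.Moondream class, in case the local class definition has drifted from the checkpoint:

from transformers import AutoModelForCausalLM

# Untested sketch: let transformers fetch the modeling code from the hub revision itself,
# so the class definition and the checkpoint tensor shapes should stay in sync.
moondream_hub = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision=LATEST_REVISION,
    trust_remote_code=True,
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)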