Adding image decoder to CoCa #467

Draft · wants to merge 11 commits into main
60 changes: 46 additions & 14 deletions src/open_clip/coca_model.py
@@ -13,6 +13,7 @@
MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
from .tokenizer import VQGANTokenizer

try:
from transformers import (
@@ -48,9 +49,10 @@ class MultimodalCfg(CLIPTextCfg):
heads: int = 8
n_queries: int = 256
attn_pooler_heads: int = 8
vocab_size: int = 1 # TODO: not sure where we put this, here or VisionCfg


def _build_text_decoder_tower(
def _build_decoder_tower(
embed_dim,
multimodal_cfg,
quick_gelu: bool = False,
@@ -80,15 +82,17 @@ class CoCa(nn.Module):
def __init__(
self,
embed_dim,
multimodal_cfg: MultimodalCfg,
multimodal_txt_cfg: MultimodalCfg,
text_cfg: CLIPTextCfg,
vision_cfg: CLIPVisionCfg,
multimodal_img_cfg: MultimodalCfg = None,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0,
):
super().__init__()
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
multimodal_txt_cfg = MultimodalCfg(**multimodal_txt_cfg) if isinstance(multimodal_txt_cfg, dict) else multimodal_txt_cfg
multimodal_img_cfg = MultimodalCfg(**multimodal_img_cfg) if isinstance(multimodal_img_cfg, dict) else multimodal_img_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

@@ -99,7 +103,7 @@ def __init__(
cast_dtype=cast_dtype,
)

vocab_size = (
txt_vocab_size = (
text_cfg.vocab_size # for hf models
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
else text_cfg.vocab_size
@@ -112,13 +116,26 @@ def __init__(
cast_dtype=cast_dtype,
)

self.text_decoder = _build_text_decoder_tower(
vocab_size,
multimodal_cfg=multimodal_cfg,
self.text_decoder = _build_decoder_tower(
txt_vocab_size,
multimodal_cfg=multimodal_txt_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)

self.img_decoder = None
if multimodal_img_cfg is not None:
self.img_decoder = _build_decoder_tower(
multimodal_img_cfg.vocab_size, # VQGAN vocab size?
multimodal_cfg=multimodal_img_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
self.img_tokenizer = VQGANTokenizer("/admin/home-iejmac/taming-transformers/logs/vqgan_imagenet_f16_1024/configs/model.yaml", "/admin/home-iejmac/taming-transformers/logs/vqgan_imagenet_f16_1024/checkpoints/last.ckpt", 80)

for param in self.img_tokenizer.parameters(): # freeze
param.requires_grad = False

self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.pad_id = pad_id

@@ -127,6 +144,8 @@ def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
self.text_decoder.set_grad_checkpointing(enable)
if self.img_decoder is not None:
self.img_decoder.set_grad_checkpointing(enable)

def _encode_image(self, images, normalize=True):
image_latent, tokens_embs = self.visual(images)
@@ -147,23 +166,36 @@ def encode_text(self, text, normalize=True, embed_cls=True):
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
return text_latent

def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None, text_latent=None, text_embs=None, image_tok=None):
if text_latent is None or text_embs is None:
text_latent, text_embs = self._encode_text(text, embed_cls=embed_cls)
if image_latent is None or image_embs is None:
image_latent, image_embs = self._encode_image(image)

# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]
labels_text = text[:, -text_embs.shape[1]:]

logits = self.text_decoder(image_embs, token_embs)
return {
logits_text = self.text_decoder(image_embs, text_embs)
output_dict = {
"image_features": image_latent,
"text_features": text_latent,
"logits": logits,
"labels": labels,
"logits_text": logits_text,
"labels_text": labels_text,
"logit_scale": self.logit_scale.exp()
}

if self.img_decoder is not None:
logits_image = self.img_decoder(text_embs, image_embs)
labels_image = self.img_tokenizer(image_tok)
labels_image = labels_image.to(image.device)

labels_image = labels_image[:, -image_embs.shape[1]:]

output_dict["logits_image"] = logits_image
output_dict["labels_image"] = labels_image

return output_dict

def generate(
self,
image,
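For orientation, here is a toy, self-contained sketch (not the PR's MultimodalTransformer) of the symmetric decoding this diff sets up: the existing text decoder cross-attends from caption embeddings to image tokens to predict text, while the new image decoder cross-attends from image tokens to caption embeddings to predict VQGAN code indices. Widths, vocab sizes, and sequence lengths mirror the configs in this PR but are otherwise illustrative assumptions.

import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    """Stand-in for the multimodal decoder: cross-attend to the other modality, then project to a vocabulary."""
    def __init__(self, width, vocab_size):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(width, num_heads=4, batch_first=True)
        self.to_logits = nn.Linear(width, vocab_size)

    def forward(self, context_embs, target_embs):
        x, _ = self.cross_attn(target_embs, context_embs, context_embs)
        return self.to_logits(x)  # [batch, target_len, vocab_size]

width = 64
text_decoder = ToyDecoder(width, vocab_size=49408)   # text vocabulary
img_decoder = ToyDecoder(width, vocab_size=1024)      # VQGAN codebook size

image_embs = torch.randn(2, 255, width)   # image token embeddings
text_embs = torch.randn(2, 76, width)     # caption token embeddings

logits_text = text_decoder(image_embs, text_embs)    # caption prediction, as before
logits_image = img_decoder(text_embs, image_embs)    # new: VQGAN-code prediction
print(logits_text.shape, logits_image.shape)         # [2, 76, 49408], [2, 255, 1024]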
1 change: 1 addition & 0 deletions src/open_clip/factory.py
@@ -254,6 +254,7 @@ def create_loss(args):
return CoCaLoss(
caption_loss_weight=args.coca_caption_loss_weight,
clip_loss_weight=args.coca_contrastive_loss_weight,
image_generation_loss_weight=args.coca_image_generation_loss_weight,
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
26 changes: 19 additions & 7 deletions src/open_clip/loss.py
@@ -136,6 +136,7 @@ def __init__(
self,
caption_loss_weight,
clip_loss_weight,
image_generation_loss_weight=0.0,
pad_id=0, # pad_token for open_clip custom tokenizer
local_loss=False,
gather_with_grad=False,
@@ -155,22 +156,33 @@ def __init__(

self.clip_loss_weight = clip_loss_weight
self.caption_loss_weight = caption_loss_weight
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
self.image_generation_loss_weight = image_generation_loss_weight
self.generative_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
def forward(self, image_features, text_features, logits_text, labels_text, logit_scale, logits_image=None, labels_image=None, output_dict=False):
clip_loss = super().forward(image_features, text_features, logit_scale)
clip_loss = self.clip_loss_weight * clip_loss

caption_loss = self.caption_loss(
logits.permute(0, 2, 1),
labels,
caption_loss = self.generative_loss(
logits_text.permute(0, 2, 1),
labels_text,
)
caption_loss = caption_loss * self.caption_loss_weight
out_dict = {"contrastive_loss": clip_loss, "caption_loss": caption_loss}

image_gen_loss = None
if labels_image is not None and logits_image is not None:
image_gen_loss = self.generative_loss(
logits_image.permute(0, 2, 1),
labels_image,
)
image_gen_loss = image_gen_loss * self.image_generation_loss_weight
out_dict["image_gen_loss"] = image_gen_loss

if output_dict:
return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
return out_dict

return clip_loss, caption_loss
return (clip_loss, caption_loss) if image_gen_loss is None else (clip_loss, caption_loss, image_gen_loss)


class DistillClipLoss(ClipLoss):
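A hedged usage sketch of the extended CoCaLoss in a single-process setting (no distributed gathering). Shapes and loss weights are illustrative assumptions; each returned term is already multiplied by its weight, so the total is simply their sum.

import torch
import torch.nn.functional as F
from open_clip.loss import CoCaLoss

loss_fn = CoCaLoss(
    caption_loss_weight=2.0,
    clip_loss_weight=1.0,
    image_generation_loss_weight=2.0,
)

batch, embed_dim = 4, 512
image_features = F.normalize(torch.randn(batch, embed_dim), dim=-1)
text_features = F.normalize(torch.randn(batch, embed_dim), dim=-1)
logit_scale = torch.ones([]) * (1 / 0.07)      # already exp()'ed, as in the model output dict

logits_text = torch.randn(batch, 76, 49408)    # [batch, text_ctx, text_vocab]
labels_text = torch.randint(0, 49408, (batch, 76))
logits_image = torch.randn(batch, 255, 1024)   # [batch, img_ctx, vqgan_vocab]
labels_image = torch.randint(0, 1024, (batch, 255))

losses = loss_fn(
    image_features, text_features,
    logits_text, labels_text, logit_scale,
    logits_image=logits_image, labels_image=labels_image,
    output_dict=True,
)
total_loss = sum(losses.values())   # contrastive_loss + caption_loss + image_gen_loss
print({k: float(v) for k, v in losses.items()})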
4 changes: 2 additions & 2 deletions src/open_clip/model_configs/coca_ViT-B-32.json
@@ -18,7 +18,7 @@
"embed_cls": true,
"output_tokens": true
},
"multimodal_cfg": {
"multimodal_txt_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 512,
Expand All @@ -27,4 +27,4 @@
"attn_pooler_heads": 8
},
"custom_text": true
}
}
39 changes: 39 additions & 0 deletions src/open_clip/model_configs/img_coca_ViT-B-32.json
@@ -0,0 +1,39 @@
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32,
"n_queries": 256,
"attentional_pool": true,
"attn_pooler_heads": 8,
"output_tokens": true
},
"text_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"embed_cls": true,
"output_tokens": true
},
"multimodal_txt_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"attn_pooler_heads": 8
},
"multimodal_img_cfg": {
"context_length": 255,
"width": 512,
"heads": 8,
"layers": 12,
"attn_pooler_heads": 8,
"vocab_size": 1024
},
"custom_text": true
}
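A quick sanity check of where the multimodal_img_cfg numbers likely come from; this is an assumption based on the f16/1024 VQGAN referenced elsewhere in this PR, not something stated in the diff. A 256x256 crop downsampled by a factor of 16 gives a 16x16 grid of codes, and 255 would be the shifted next-token context length.

image_size, downsample_factor, codebook_size = 256, 16, 1024  # vqgan_imagenet_f16_1024
grid = image_size // downsample_factor    # 16
num_codes = grid * grid                   # 256 VQGAN tokens per image
context_length = num_codes - 1            # 255, as in multimodal_img_cfg
vocab_size = codebook_size                # 1024, as in multimodal_img_cfg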
63 changes: 63 additions & 0 deletions src/open_clip/tokenizer.py
@@ -7,6 +7,7 @@
import os
from functools import lru_cache
from typing import Union, List
from torch import nn

import ftfy
import regex as re
@@ -212,3 +213,65 @@ def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> to
truncation=True,
).input_ids
return input_ids


# TEMPORARY WIP
import importlib
import torch.nn.functional as F
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel # I don't love this part
from einops import rearrange
from math import sqrt

def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)

def instantiate_from_config(config):
if not "target" in config:
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))


class VQGANTokenizer(nn.Module):
"""VQGAN image tokenizer"""

def __init__(self, config_path, model_path, split_batch):
super().__init__()
config = OmegaConf.load(config_path)
self.vqgan = VQModel(**config.model.params)
sd = torch.load(model_path, map_location="cpu")["state_dict"]
missing, unexpected = self.vqgan.load_state_dict(sd, strict=False)
self.num_tokens = config["model"]["params"]["n_embed"]
self.split_batch = split_batch

self.vqgan.eval()

def encode(self, image):
tot_indices = []
for img in torch.split(image, self.split_batch):
_, _, [_, _, indices] = self.vqgan.encode(img)
tot_indices.append(indices.reshape(img.shape[0], -1))
indices = torch.cat(tot_indices)
return indices # [bs, ctx_len]

def _get_embeddings(self, tokens):
one_hot_indices = F.one_hot(tokens, num_classes = self.num_tokens).float()
z = one_hot_indices @ self.vqgan.quantize.embedding.weight
return z

def decode(self, tokens):
b, n = tokens.shape
z = self._get_embeddings(tokens)

z = rearrange(z, 'b (h w) c -> b c h w', h = int(sqrt(n)))
img = self.vqgan.decode(z)

img = (img.clamp(-1., 1.) + 1) * 0.5
return img

def __call__(self, image):
return self.encode(image)
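A hedged usage sketch of the new VQGANTokenizer. The config and checkpoint paths are placeholders for a locally downloaded taming-transformers VQGAN (f16, 1024 codes); without that checkpoint this will not run.

import torch
from open_clip.tokenizer import VQGANTokenizer

tokenizer = VQGANTokenizer(
    "path/to/vqgan_imagenet_f16_1024/configs/model.yaml",     # placeholder path
    "path/to/vqgan_imagenet_f16_1024/checkpoints/last.ckpt",  # placeholder path
    split_batch=8,   # encode the batch in chunks of 8 images to bound memory
)

images = torch.randn(2, 3, 256, 256).clamp(-1.0, 1.0)  # VQGAN expects inputs in [-1, 1]
with torch.no_grad():
    codes = tokenizer(images)        # [2, 256] discrete code indices (16x16 grid at f16)
    recon = tokenizer.decode(codes)  # [2, 3, 256, 256] reconstructed images in [0, 1]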
22 changes: 19 additions & 3 deletions src/training/data.py
@@ -92,6 +92,22 @@ def expand_urls(urls, weights=None):
all_urls = list(urls)
return all_urls, weights

# TODO: remove temp, figure out better way
import PIL
import torchvision.transforms as T
import torchvision.transforms.functional as TF
def preprocess_vqgan(img):
target_img_size=256
s = min(img.size)
r = target_img_size / s
s = (round(r * img.size[1]), round(r * img.size[0]))
img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
img = TF.center_crop(img, output_size=2 * [target_img_size])
img = T.ToTensor()(img)

img = 2. * img - 1.
return img


def get_dataset_size(shards):
shards_list, _ = expand_urls(shards)
@@ -378,9 +394,9 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
pipeline.extend([
wds.select(filter_no_caption_or_no_image),
wds.decode("pilrgb", handler=log_and_continue),
wds.rename(image="jpg;png;jpeg;webp", text="txt"),
wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]),
wds.to_tuple("image", "text"),
wds.rename(image="jpg;png;jpeg;webp", image_vqgan="jpg;png;jpeg;webp", text="txt"),
wds.map_dict(image=preprocess_img, image_vqgan=preprocess_vqgan, text=lambda text: tokenizer(text)[0]),
wds.to_tuple("image", "image_vqgan", "text"),
wds.batched(args.batch_size, partial=not is_train)
])

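For reference, a small check of what preprocess_vqgan produces for a single PIL image (illustrative only; it assumes src/ is on PYTHONPATH so that training.data imports). Note that with the renamed pipeline each sample now yields (image, image_vqgan, text), so the same decoded image is preprocessed twice, once for CLIP and once for the VQGAN.

import numpy as np
from PIL import Image
from training.data import preprocess_vqgan

img = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))
tensor = preprocess_vqgan(img)
print(tensor.shape)                              # torch.Size([3, 256, 256])
print(tensor.min().item(), tensor.max().item())  # roughly -1.0 .. 1.0, as the VQGAN expects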
6 changes: 6 additions & 0 deletions src/training/params.py
@@ -384,6 +384,12 @@ def parse_args(args):
default=2.0,
help="Weight assigned to caption loss in CoCa."
)
parser.add_argument(
"--coca-image-generation-loss-weight",
type=float,
default=2.0,
help="Weight assigned to the image generation loss in CoCa."
)
parser.add_argument(
"--coca-contrastive-loss-weight",
type=float,
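A quick argparse check for the new flag (again assuming src/ is on PYTHONPATH; all other arguments keep their existing defaults):

from training.params import parse_args

args = parse_args(["--coca-image-generation-loss-weight", "1.5"])
print(args.coca_image_generation_loss_weight)  # 1.5
print(args.coca_caption_loss_weight)           # 2.0 (existing default)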
6 changes: 4 additions & 2 deletions src/training/train.py
@@ -88,16 +88,18 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
if not args.skip_scheduler:
scheduler(step)

images, texts = batch
images, images_vqgan, texts = batch

images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
images_vqgan = images_vqgan.to(device=device, dtype=cast_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)

data_time_m.update(time.time() - end)
optimizer.zero_grad()

if args.accum_freq == 1:
with autocast():
model_out = model(images, texts)
model_out = model(images, texts, image_tok=images_vqgan)
logit_scale = model_out["logit_scale"]
if args.distill:
with torch.no_grad():