forked from mlfoundations/open_clip
Commit dd76e74 (parent 41429e2)
Combining CLIPA-v2 and SigLIP (both big_vision based) models (mlfound…)
Showing 50 changed files with 1,761 additions and 478 deletions.
@@ -0,0 +1,6 @@
# eval on a single gpu
CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m training.main \
    --model ViT-L-16-CL32-GAP \
    --pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \
    --seed 0 \
    --imagenet-val '/path/to/ImageNet/val'
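For reference, a minimal sketch of loading the same checkpoint through the open_clip Python API instead of the training CLI; it assumes this branch is installed as the open_clip package, and the checkpoint path is the same placeholder used in the script above.

import open_clip

# Build the CLIPA-v2 model config from this branch and load the local checkpoint.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-L-16-CL32-GAP',
    pretrained='/path/to/clipa_vit_l16_i37_t8.pt',  # placeholder path
)
tokenizer = open_clip.get_tokenizer('ViT-L-16-CL32-GAP')
model.eval()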
scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh
10 changes: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
CUDA_VISIBLE_DEVICES=1 python3 -m training.main \
    --model ViT-H-14-CL32-GAP-BigVision \
    --pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \
    --force-image-size 336 \
    --square-resize-only \
    --interpolation 'bilinear' \
    --image-mean 0.485 0.456 0.406 \
    --image-std 0.229 0.224 0.225 \
    --seed 0 \
    --imagenet-val '/path/to/ImageNet/val'
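The extra flags override preprocessing to match big_vision training: 336px input with a plain square resize, bilinear interpolation, and ImageNet rather than OpenAI normalization. A rough sketch of the same overrides via the Python API follows; force_image_size, image_mean, and image_std are existing create_model_and_transforms arguments, while API equivalents of --square-resize-only and --interpolation on this branch are not shown in this diff.

import open_clip

model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-H-14-CL32-GAP-BigVision',
    pretrained='/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt',  # placeholder path
    force_image_size=336,                # mirrors --force-image-size 336
    image_mean=(0.485, 0.456, 0.406),    # IMAGENET_MEAN, mirrors --image-mean
    image_std=(0.229, 0.224, 0.225),     # IMAGENET_STD, mirrors --image-std
)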
@@ -0,0 +1,136 @@
import torch
import numpy as np

from .model import CustomTextCLIP
from .transformer import TextTransformer, Transformer


@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
    """ Load weights from .npz checkpoints for official Google big_vision image-text models.

    Currently the SigLIP source models are supported, targeting a CustomTextCLIP
    destination model with a timm image encoder.
    """
    from timm.layers import resample_patch_embed, resample_abs_pos_embed

    def _n2p(w, t=True):
        # Convert a big_vision (flax) numpy array to a torch tensor: squeeze
        # 1x1x1 kernels to vectors and, when t=True, transpose to PyTorch layout.
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    interpolation = 'bilinear'
    antialias = False

    def _convert_timm_img(module, prefix):
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
        if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
            embed_conv_w = resample_patch_embed(
                embed_conv_w,
                module.patch_embed.proj.weight.shape[-2:],
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.patch_embed.proj.weight.copy_(embed_conv_w)
        module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))

        if module.cls_token is not None:
            module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))

        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
        if pos_embed_w.shape != module.pos_embed.shape:
            assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'  # resampling path below is currently disabled; shapes must match
            num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
            pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
                pos_embed_w,
                new_size=module.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.pos_embed.copy_(pos_embed_w)

        mha_sub, b_sub, ln1_sub = (0, 0, 1)
        for i, block in enumerate(module.blocks.children()):
            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            # Fuse flax's separate q/k/v projections into the single torch qkv projection.
            block.attn.qkv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.qkv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            for r in range(2):
                getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
                getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))

        module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
        module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))

        if module.attn_pool is not None:
            # MAP head (attention pooling) used by the SigLIP image tower.
            block_prefix = f'{prefix}MAPHead_0/'
            mha_prefix = block_prefix + 'MultiHeadDotProductAttention_0/'
            module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
            module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
            module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
            module.attn_pool.kv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
            module.attn_pool.kv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
            module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            for r in range(2):
                getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
                getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

    def _convert_openclip_transformer(module: Transformer, prefix):
        for i, block in enumerate(module.resblocks.children()):
            block_prefix = f'{prefix}encoderblock_{i}/'
            mha_prefix = block_prefix + 'MultiHeadDotProductAttention_0/'
            block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            block.attn.in_proj_weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.in_proj_bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
            block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
            block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
            block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
            block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
            block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))

    def _convert_openclip_txt(module: TextTransformer, prefix):
        module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
        module.positional_embedding.copy_(pos_embed_w)
        _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
        module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
        module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
        module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

||
_convert_timm_img(model.visual.trunk, 'params/img/') | ||
_convert_openclip_txt(model.text, 'params/txt/') | ||
model.logit_bias.copy_(_n2p(w['params/b'])[0]) | ||
model.logit_scale.copy_(_n2p(w['params/t'])[0]) | ||
|
||
|
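A hypothetical usage sketch for the loader above; the import path and the SigLIP config name are assumptions, not taken from this diff, and the .npz path is a placeholder.

import open_clip
from open_clip.big_vision import load_big_vision_weights  # module path assumed

# Build a timm-backed CustomTextCLIP, then copy big_vision weights into it.
model = open_clip.create_model('ViT-B-16-SigLIP', pretrained=None)  # config name assumed
load_big_vision_weights(model, '/path/to/siglip_b16.npz')  # placeholder checkpoint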
@@ -1,2 +1,6 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)
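A small sketch of how these constants typically feed image normalization; the pairing of models to statistics follows the eval scripts above (the big_vision-based ports use ImageNet mean/std), and the import path assumes this file is open_clip's constants module.

from torchvision.transforms import Normalize
from open_clip.constants import (
    OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, IMAGENET_MEAN, IMAGENET_STD,
)

openai_norm = Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD)  # standard CLIP models
imagenet_norm = Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)  # big_vision-based ports, per the scripts above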