ViT pre-trained weights
1. Convert pre-trained weights from https://github.com/jeonsworld/ViT-pytorch

- clone the repo

```bash
git clone https://github.com/jeonsworld/ViT-pytorch
```

- download the weights

```bash
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16.npz
```

- convert the weights

```python
import numpy as np
import torch
from functools import partial

from hyperbox.networks.vit import ViT_B
# run this script inside the cloned ViT-pytorch repo so that `models.modeling` is importable
from models.modeling import VisionTransformer, CONFIGS

# the imagenet21k+imagenet2012 checkpoint is fine-tuned on ImageNet1k, i.e. it has a 1000-class head
ViT_B_16 = partial(VisionTransformer, config=CONFIGS['ViT-B_16'], num_classes=1000)


def sync_params(net1, net2):
    """Copy the weights of `net1` into the matching modules of `net2`.

    Args:
        net1: source net (jeonsworld VisionTransformer)
        net2: target net (hyperbox ViT)
    """
    count_size = lambda model: sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_params1 = count_size(net1)
    num_params2 = count_size(net2)
    print(f"ViT-B_16: {num_params1} | ViT-B: {num_params2}")
    num_sync_params = 0

    # patch/position embeddings
    net2.vit_embed.load_state_dict(net1.transformer.embeddings.state_dict())
    num_sync_params += count_size(net2.vit_embed)

    # transformer blocks
    assert len(net1.transformer.encoder.layer) == len(net2.vit_blocks)
    for i in range(len(net1.transformer.encoder.layer)):
        layer1 = net1.transformer.encoder.layer[i]
        layer2 = net2.vit_blocks[i]

        # attention
        ## norm
        attention_norm1 = layer1.attention_norm
        attention_norm2 = layer2.attn.block.norm
        attention_norm2.load_state_dict(attention_norm1.state_dict())
        num_sync_params += count_size(attention_norm2)
        ## qkv: fuse the separate query/key/value projections into the single to_qkv linear
        q = layer1.attn.query
        k = layer1.attn.key
        v = layer1.attn.value
        qkv = layer2.attn.block.fn.to_qkv
        qkv.weight.data.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0).data)
        if qkv.bias is not None:
            qkv.bias.data.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
        num_sync_params += count_size(qkv)
        ## output projection
        out1 = layer1.attn.out
        out2 = layer2.attn.block.fn.to_out[0]
        out2.load_state_dict(out1.state_dict())
        num_sync_params += count_size(out2)

        # feed-forward
        ## norm
        ffn_norm1 = layer1.ffn_norm
        ffn_norm2 = layer2.ff.block.norm
        ffn_norm2.load_state_dict(ffn_norm1.state_dict())
        num_sync_params += count_size(ffn_norm2)
        ## fc
        mlp_fc11 = layer1.ffn.fc1
        mlp_fc12 = layer1.ffn.fc2
        mlp_fc21 = layer2.ff.block.fn.net[0]
        mlp_fc22 = layer2.ff.block.fn.net[3]
        mlp_fc21.load_state_dict(mlp_fc11.state_dict())
        mlp_fc22.load_state_dict(mlp_fc12.state_dict())
        num_sync_params += count_size(mlp_fc21)
        num_sync_params += count_size(mlp_fc22)

    # classification head
    ## norm
    norm1 = net1.transformer.encoder.encoder_norm
    norm2 = net2.vit_cls_head.mlp_head[0]
    norm2.load_state_dict(norm1.state_dict())
    num_sync_params += count_size(norm2)
    ## fc
    fc1 = net1.head
    fc2 = net2.vit_cls_head.mlp_head[1]
    try:
        fc2.load_state_dict(fc1.state_dict())
        num_sync_params += count_size(fc2)
    except RuntimeError:
        # skip the head if the number of classes differs
        pass
    print(f"sync params: {num_sync_params}")


net1 = ViT_B_16()
net1.load_from(np.load('/path/to/ViT-B_16.npz'))
net2 = ViT_B()
sync_params(net1, net2)
torch.save(net2.state_dict(), 'vit_b.pth')
```
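
As a quick sanity check after the conversion, you can feed the same random batch to both networks and compare the logits. This is only a sketch, continuing from the script above; it assumes the jeonsworld `VisionTransformer` returns a `(logits, attn_weights)` tuple when called without labels and that both models share the same 1000-class head:

```python
# Optional sanity check: the converted hyperbox model should produce (nearly)
# identical logits to the source model on the same input.
net1.eval()
net2.eval()
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits1 = net1(x)[0]  # jeonsworld model returns (logits, attn_weights)
    logits2 = net2(x)
print((logits1 - logits2).abs().max())  # expect a value close to 0 (floating-point error only)
```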
2. Validate the pretrained weights

```python
import torch
import torchvision
import torchvision.transforms as transforms

from hyperbox.networks.vit import ViT_B, ViT_L


def testloader(data_path='/path/to/imagenet2012/val', batch_size=400, num_workers=4):
    """Create the ImageNet validation dataloader."""
    # Standard ImageNet evaluation transforms
    data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # Load the validation dataset
    test_dataset = torchvision.datasets.ImageFolder(root=data_path, transform=data_transforms)
    # Create the dataloader
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return test_loader


def validate(loader, model, criterion, device, verbose=True):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(loader):
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            output = model(x)
            loss = criterion(output, y)
            test_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
            correct += pred.eq(y.view_as(pred)).sum().item()
            if batch_idx % 10 == 0 and verbose:
                print('Test: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    batch_idx * len(x), len(loader.dataset),
                    100. * batch_idx / len(loader), loss.item()))
    test_loss /= len(loader.dataset)
    test_acc = 100. * correct / len(loader.dataset)
    if verbose:
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(loader.dataset), test_acc))
    return test_loss, test_acc


net = ViT_B()
net.load_state_dict(torch.load('vit_b.pth'))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')
loader = testloader()
criterion = torch.nn.CrossEntropyLoss()
validate(loader, net.to(device), criterion, device)
```
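
The same dataloader and `validate` function can be reused for the other converted checkpoints listed in the results below. For example, for the patch-size-32 variant (continuing from the script above; the `patch_size` argument mirrors the `ViT_H(patch_size=14, ...)` usage shown later):

```python
# Follow-up example: validate the converted ViT_B with patch size 32
net_b32 = ViT_B(patch_size=32)
net_b32.load_state_dict(torch.load('vit_b_32.pth'))
validate(loader, net_b32.to(device), criterion, device)
```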
3. Results

| Model | Checkpoint | Dataset | Acc@top1 |
|---|---|---|---|
| ViT_B | vit_b.pth | ImageNet1k | 75.66% |
| ViT_B (patch_size=32) | vit_b_32.pth | ImageNet1k | 64.44% |
| ViT_L | vit_L.pth | ImageNet1k | 79.25% |
| ViT_L (patch_size=32) | vit_L_32.pth | ImageNet1k | 65.19% |
| ViT_H (patch_size=14, num_classes=21843) | vit_H_14.pth | ImageNet21k | - |
The pretrained weights for ViT_H_14 (patch_size=14) are too large (~2.45 GB), so we split them into multiple smaller chunks, i.e. `vit_H_14.pth.parta*`. To use the full weights, concatenate the chunks into a single pth file after downloading them:

```bash
cat vit_H_14.pth.part* > vit_H_14.pth
```
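
If `cat` is not available (e.g. on Windows), the chunks can also be joined in Python; this sketch assumes the chunks follow the `vit_H_14.pth.parta*` naming above and sit in the current directory:

```python
import glob

# Concatenate the downloaded chunks in lexicographic order (same order as the shell glob)
with open('vit_H_14.pth', 'wb') as out:
    for part in sorted(glob.glob('vit_H_14.pth.part*')):
        with open(part, 'rb') as chunk:
            out.write(chunk.read())
```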
By default, the ViT models provided by hyperbox use a patch size of 16. To use vit_H_14.pth, build the model with `patch_size=14` and `num_classes=21843` (the model is pretrained on ImageNet21k):

```python
import torch
from hyperbox.networks.vit import ViT_H

vit_h_14 = ViT_H(patch_size=14, num_classes=21843)
vit_h_14.load_state_dict(torch.load('vit_H_14.pth'))
```
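
A quick way to confirm that the checkpoint loaded correctly is to run a forward pass and check the logit shape; the 224×224 input resolution below is an assumption about the default image size:

```python
# Sketch: the loaded ViT_H_14 should output logits over the 21843 ImageNet21k classes
vit_h_14.eval()
with torch.no_grad():
    logits = vit_h_14(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 21843])
```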