Skip to content

vit pre-trained weights

Compare
Choose a tag to compare
@marsggbo marsggbo released this 11 Apr 13:35
· 29 commits to master since this release

1. Convert pre-trained weights from https://github.com/jeonsworld/ViT-pytorch

  • git clone repo
git clone https://github.com/jeonsworld/ViT-pytorch
  • download weights
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16.npz
  • Convert the weights (run this from the root of the cloned ViT-pytorch repo so that `models.modeling` can be imported)
from functools import partial

import numpy as np
import torch

from hyperbox.networks.vit import ViT_B
from models.modeling import VisionTransformer, CONFIGS
# Factory for the ViT-pytorch ViT-B/16 model: the config is bound here and any
# further keyword arguments are forwarded to VisionTransformer.
ViT_B_16 = partial(
    VisionTransformer,
    config=CONFIGS['ViT-B_16'],
)

def sync_params(net1, net2):
    """Copy pre-trained weights from a ViT-pytorch model into a hyperbox ViT.

    The two implementations name their submodules differently, and the target
    uses one fused ``to_qkv`` projection where the source has separate
    query/key/value layers, so the weights are mapped module by module.

    Args:
        net1: src net (e.g. ``models.modeling.VisionTransformer``).
        net2: tgt net (e.g. ``hyperbox.networks.vit.ViT_B``); modified in place.

    Returns:
        int: number of trainable parameters of ``net2`` that were overwritten.
        The classification head is skipped (with a printed notice) when its
        state dict cannot be loaded, e.g. because ``num_classes`` differs.
    """
    def count_size(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def copy_module(src, tgt):
        # load_state_dict raises RuntimeError on missing keys / shape mismatch.
        tgt.load_state_dict(src.state_dict())
        return count_size(tgt)

    num_params1 = count_size(net1)
    num_params2 = count_size(net2)
    print(f"ViT-B_16: {num_params1} | ViT-B: {num_params2}")

    # patch / position embeddings
    num_sync_params = copy_module(net1.transformer.embeddings, net2.vit_embed)

    src_layers = net1.transformer.encoder.layer
    tgt_layers = net2.vit_blocks
    assert len(src_layers) == len(tgt_layers), (
        f"depth mismatch: {len(src_layers)} vs {len(tgt_layers)} blocks")
    for layer1, layer2 in zip(src_layers, tgt_layers):
        attn_block = layer2.attn.block
        ff_block = layer2.ff.block

        # attention: norm
        num_sync_params += copy_module(layer1.attention_norm, attn_block.norm)

        # attention: q, k, v are stacked along the output dim (in that order)
        # into the target's single fused projection.
        q, k, v = layer1.attn.query, layer1.attn.key, layer1.attn.value
        qkv = attn_block.fn.to_qkv
        with torch.no_grad():
            qkv.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
            if qkv.bias is not None:
                qkv.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
        num_sync_params += count_size(qkv)

        # attention: output projection
        num_sync_params += copy_module(layer1.attn.out, attn_block.fn.to_out[0])

        # feed-forward: norm, then fc1 -> net[0] and fc2 -> net[3]
        num_sync_params += copy_module(layer1.ffn_norm, ff_block.norm)
        num_sync_params += copy_module(layer1.ffn.fc1, ff_block.fn.net[0])
        num_sync_params += copy_module(layer1.ffn.fc2, ff_block.fn.net[3])

    # head: final encoder norm
    num_sync_params += copy_module(net1.transformer.encoder.encoder_norm,
                                   net2.vit_cls_head.mlp_head[0])

    # head: classifier. Best-effort only: a head trained for a different
    # number of classes cannot be copied, so keep the target's own head but
    # report the skip instead of silently swallowing every possible error.
    fc1 = net1.head
    fc2 = net2.vit_cls_head.mlp_head[1]
    try:
        num_sync_params += copy_module(fc1, fc2)
    except RuntimeError as err:
        print(f"skip classification head: {err}")
    print(f"sync params: {num_sync_params}")
    return num_sync_params

net1 = ViT_B_16()
# np.load on an .npz archive returns an NpzFile that keeps the file open;
# the context manager closes it once load_from has read the weights.
with np.load('/path/to/ViT-B_16.npz') as weights:
    net1.load_from(weights)

net2 = ViT_B()
sync_params(net1, net2)
torch.save(net2.state_dict(), 'vit_b.pth')

2. Validate the pretrained weights

import torch
import torchvision
import torchvision.transforms as transforms

from hyperbox.networks.vit import ViT_B, ViT_L

def testloader(data_path='/path/to/imagenet2012/val', batch_size=400, num_workers=4):
    """Build a deterministic evaluation DataLoader for an ImageNet folder.

    Images under ``data_path`` are read with ``ImageFolder``, resized to 256,
    center-cropped to 224x224, converted to tensors and normalized per channel.
    Batches are returned in a fixed order (``shuffle=False``).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    dataset = torchvision.datasets.ImageFolder(root=data_path, transform=preprocess)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

def validate(loader, model, criterion, device, verbose=True):
    """Evaluate ``model`` on ``loader`` and return its loss and top-1 accuracy.

    Args:
        loader: iterable of ``(inputs, targets)`` batches with a ``dataset``
            attribute whose ``len`` is the number of samples (e.g. a DataLoader).
        model: classifier whose output holds one score per class along dim 1.
            It is not moved here; place it on ``device`` beforehand.
        criterion: loss called as ``criterion(output, targets)``. A
            ``reduction`` attribute of ``'mean'`` (the torch default, also
            assumed when the attribute is absent) or ``'sum'`` is honoured.
        device: device each batch is moved to.
        verbose: print progress every 10 batches and a final summary.

    Returns:
        tuple: ``(test_loss, test_acc)`` -- mean loss per sample and top-1
        accuracy in percent.
    """
    model.eval()
    num_samples = len(loader.dataset)
    # A mean-reduced loss is averaged over its batch; scale it back to a batch
    # sum so dividing by the dataset size gives a true per-sample mean, even
    # when the last batch is smaller than the others.
    per_batch_mean = getattr(criterion, 'reduction', 'mean') == 'mean'
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(loader):
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            output = model(x)
            loss = criterion(output, y)
            test_loss += loss.item() * (x.size(0) if per_batch_mean else 1)

            pred = output.argmax(dim=1, keepdim=True)  # index of the highest score
            correct += pred.eq(y.view_as(pred)).sum().item()
            if batch_idx % 10 == 0 and verbose:
                print('Test: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    batch_idx * len(x), num_samples,
                    100. * batch_idx / len(loader), loss.item()))
    test_loss /= num_samples
    test_acc = 100. * correct / num_samples

    if verbose:
        # two decimals, matching the precision of the published results
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, num_samples, test_acc))
    return test_loss, test_acc

net = ViT_B()
# load_state_dict() takes a mapping of tensors, not a file path: read the
# checkpoint with torch.load first (onto CPU, so this also works without a GPU).
net.load_state_dict(torch.load('vit_b.pth', map_location='cpu'))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')
loader = testloader()
criterion = torch.nn.CrossEntropyLoss()
validate(loader, net.to(device), criterion, device)

3. Results

Model pth Dataset Acc@top1
ViT_B vit_b.pth ImageNet1k 75.66%
ViT_B(patch_size=32) vit_b_32.pth ImageNet1k 64.44%
ViT_L vit_L.pth ImageNet1k 79.25%
ViT_L(patch_size=32) vit_L_32.pth ImageNet1k 65.19%
ViT_H(patch_size=14, num_classes=21843) vit_H_14.pth ImageNet21k -

The pretrained weights for ViT_H_14 (patch size 14) are too large (~2.45 GB), so we split them into multiple smaller chunks (vit_H_14.pth.parta*). To use the full weights, concatenate the chunks into a single .pth file after downloading them:

cat vit_H_14.pth.part* > vit_H_14.pth

By default, the ViT models provided by hyperbox use a patch size of 16. To use vit_H_14.pth, build the model with the following settings:

  • patch_size=14
  • num_classes=21843 (the model is pretrained on ImageNet-21k)
import torch
from hyperbox.networks.vit import ViT_H
vit_h_14 = ViT_H(patch_size=14, num_classes=21843)
# map_location='cpu' lets the checkpoint load on machines without a GPU even if
# it was saved from GPU tensors; move the model afterwards with .to(device).
vit_h_14.load_state_dict(torch.load('vit_H_14.pth', map_location='cpu'))