Releases: marsggbo/hyperbox
hyperbox_proxyless_supernet_and_subnets
1. Load weights from the pretrained supernet

```python
import torch

from hyperbox.networks.proxylessnas.network import ProxylessNAS

supernet = ProxylessNAS(
    width_stages=[30, 40, 80, 96, 182, 320],
    n_cell_stages=[4, 4, 4, 4, 4, 1],
    stride_stages=[2, 2, 2, 1, 2, 1],
    width_mult=1.4,
    num_classes=1000,
    dropout_rate=0,
    bn_param=[0.1, 0.001],
)
ckpt = torch.load('/path/to/hyperbox_proxy_mobile_w1.4.pth', map_location='cpu')
supernet.load_state_dict(ckpt)
```
2. Load weights from a pretrained subnet (a validation sketch follows the two examples below)

- Subnet weight 1: test accuracy on ImageNet is 77.15%

```python
import torch

from hyperbox.networks.proxylessnas.network import ProxylessNAS

subnet = ProxylessNAS(
    width_stages=[30, 40, 80, 96, 182, 320],
    n_cell_stages=[4, 4, 4, 4, 4, 1],
    stride_stages=[2, 2, 2, 1, 2, 1],
    width_mult=1.4,
    num_classes=1000,
    dropout_rate=0,
    bn_param=[0.1, 0.001],
    mask='subnet_acc77.15.json',
)
ckpt = torch.load('/path/to/hyperbox_proxylessnas_w1.4_acc77.15_subnet.pth', map_location='cpu')
subnet.load_state_dict(ckpt)
```
- Subnet weight 2: test accuracy on ImageNet is 77.21%

```python
import torch

from hyperbox.networks.proxylessnas.network import ProxylessNAS

subnet = ProxylessNAS(
    width_stages=[30, 40, 80, 96, 182, 320],
    n_cell_stages=[4, 4, 4, 4, 4, 1],
    stride_stages=[2, 2, 2, 1, 2, 1],
    width_mult=1.4,
    num_classes=1000,
    dropout_rate=0,
    bn_param=[0.1, 0.001],
    mask='subnet_acc77.21.json',
)
ckpt = torch.load('/path/to/hyperbox_proxylessnas_w1.4_acc77.21_subnet.pth', map_location='cpu')
subnet.load_state_dict(ckpt)
```
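To check the reported accuracies yourself, the loaded subnet can be evaluated on the ImageNet validation set. The sketch below only uses standard torch/torchvision calls; the validation directory, batch size, and the usual ImageNet preprocessing (resize to 256, center-crop to 224, mean/std normalization) are assumptions, not values shipped with this release.

```python
import torch
import torchvision
import torchvision.transforms as transforms

# Evaluation sketch: top-1 accuracy of the loaded subnet on the ImageNet
# validation set. The data path and preprocessing below are assumptions.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
val_set = torchvision.datasets.ImageFolder('/path/to/imagenet2012/val', transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=256, shuffle=False, num_workers=4)

subnet = subnet.to(device).eval()
correct, total = 0, 0
with torch.no_grad():
    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        correct += (subnet(x).argmax(dim=1) == y).sum().item()
        total += y.numel()
print(f'top-1 accuracy: {100.0 * correct / total:.2f}%')
```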
hyperbox_OFA_MBV3_k357_d234_e346_w1.2.pth
```python
import torch

from hyperbox.networks.ofa.ofa_mbv3 import OFAMobileNetV3

supernet = OFAMobileNetV3(
    first_stride=2,
    kernel_size_list=[3, 5, 7],
    expand_ratio_list=[4, 6],
    depth_list=[3, 4],
    base_stage_width=[16, 16, 24, 40, 80, 112, 160, 960, 1280],
    stride_stages=[1, 2, 2, 2, 1, 2],
    act_stages=['relu', 'relu', 'relu', 'h_swish', 'h_swish', 'h_swish'],
    se_stages=[False, False, True, False, True, True],
    width_mult=1.2,
    num_classes=1000,
)
ckpt = torch.load('/path/to/hyperbox_OFA_MBV3_k357_d234_e346_w1.2.pth', map_location='cpu')
supernet.load_state_dict(ckpt)
```
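Once the supernet weights are loaded, a specific subnet can be extracted with the same `gen_mask`/`build_subnet` calls shown for OFA_MBV3_k357_d234_e46_w1.pth at the bottom of this page; the particular depth/expand_ratio/kernel_size values below are only an illustration.

```python
# Subnet-extraction sketch, using the same API as the
# OFA_MBV3_k357_d234_e46_w1.pth example below; the chosen
# depth/expand_ratio/kernel_size values are illustrative only.
mask = supernet.gen_mask(depth=4, expand_ratio=6, kernel_size=7)
subnet = supernet.build_subnet(mask).eval()
```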
ViT pre-trained weights
1. Convert pre-trained weights from https://github.com/jeonsworld/ViT-pytorch

- Clone the repo

```bash
git clone https://github.com/jeonsworld/ViT-pytorch
```

- Download the weights

```bash
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16.npz
```
- Convert the weights (a quick sanity check on the converted model follows the script)

```python
# Run this script from inside the cloned ViT-pytorch repo so that
# `models.modeling` is importable.
from functools import partial

import numpy as np
import torch

from hyperbox.networks.vit import ViT_B
from models.modeling import VisionTransformer, CONFIGS

ViT_B_16 = partial(VisionTransformer, config=CONFIGS['ViT-B_16'])


def sync_params(net1, net2):
    """Copy the pretrained parameters from net1 into net2.

    Args:
        net1: src net (ViT-pytorch VisionTransformer)
        net2: tgt net (hyperbox ViT_B)
    """
    count_size = lambda model: sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_params1 = count_size(net1)
    num_params2 = count_size(net2)
    print(f"ViT-B_16: {num_params1} | ViT-B: {num_params2}")
    num_sync_params = 0

    # patch/position embeddings
    net2.vit_embed.load_state_dict(net1.transformer.embeddings.state_dict())
    num_sync_params += count_size(net2.vit_embed)

    assert len(net1.transformer.encoder.layer) == len(net2.vit_blocks)
    for i in range(len(net1.transformer.encoder.layer)):
        layer1 = net1.transformer.encoder.layer[i]
        layer2 = net2.vit_blocks[i]

        # attn
        ## norm
        attention_norm1 = layer1.attention_norm
        attention_norm2 = layer2.attn.block.norm
        attention_norm2.load_state_dict(attention_norm1.state_dict())
        num_sync_params += count_size(attention_norm2)
        ## qkv: the source model keeps separate q/k/v projections,
        ## the hyperbox model uses a single fused projection
        q = layer1.attn.query
        k = layer1.attn.key
        v = layer1.attn.value
        qkv = layer2.attn.block.fn.to_qkv
        qkv.weight.data.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0).data)
        if qkv.bias is not None:
            qkv.bias.data.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
        num_sync_params += count_size(qkv)
        ## fc (output projection)
        out1 = layer1.attn.out
        out2 = layer2.attn.block.fn.to_out[0]
        out2.load_state_dict(out1.state_dict())
        num_sync_params += count_size(out2)

        # ff
        ## norm
        ffn_norm1 = layer1.ffn_norm
        ffn_norm2 = layer2.ff.block.norm
        ffn_norm2.load_state_dict(ffn_norm1.state_dict())
        num_sync_params += count_size(ffn_norm2)
        ## fc
        mlp_fc11 = layer1.ffn.fc1
        mlp_fc12 = layer1.ffn.fc2
        mlp_fc21 = layer2.ff.block.fn.net[0]
        mlp_fc22 = layer2.ff.block.fn.net[3]
        mlp_fc21.load_state_dict(mlp_fc11.state_dict())
        mlp_fc22.load_state_dict(mlp_fc12.state_dict())
        num_sync_params += count_size(mlp_fc21)
        num_sync_params += count_size(mlp_fc22)

    # head
    ## norm
    norm1 = net1.transformer.encoder.encoder_norm
    norm2 = net2.vit_cls_head.mlp_head[0]
    norm2.load_state_dict(norm1.state_dict())
    num_sync_params += count_size(norm2)
    ## fc (skipped if the classifier shapes do not match)
    fc1 = net1.head
    fc2 = net2.vit_cls_head.mlp_head[1]
    try:
        fc2.load_state_dict(fc1.state_dict())
        num_sync_params += count_size(fc2)
    except Exception:
        pass

    print(f"sync params: {num_sync_params}")


net1 = ViT_B_16()
net1.load_from(np.load('/path/to/ViT-B_16.npz'))
net2 = ViT_B()
sync_params(net1, net2)
torch.save(net2.state_dict(), 'vit_b.pth')
```
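Before running the full ImageNet validation in step 2, a cheap shape check on the converted hyperbox model can catch obvious conversion mistakes. This is only a sketch; it assumes the default 224x224 input resolution and a 1000-class head.

```python
# Quick sanity check on the converted model (sketch): assumes the default
# 224x224 input resolution and a 1000-class classification head.
net2.eval()
with torch.no_grad():
    logits = net2(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000])
```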
2. Validate the pretrained weights
```python
import torch
import torchvision
import torchvision.transforms as transforms

from hyperbox.networks.vit import ViT_B, ViT_L


def testloader(data_path='/path/to/imagenet2012/val', batch_size=400, num_workers=4):
    """Create a test dataloader for the ImageNet validation set."""
    # Define data transforms
    data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # Load the test dataset
    test_dataset = torchvision.datasets.ImageFolder(root=data_path, transform=data_transforms)
    # Create a test dataloader
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return test_loader


def validate(loader, model, criterion, device, verbose=True):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(loader):
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            output = model(x)
            loss = criterion(output, y)
            test_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(y.view_as(pred)).sum().item()
            if batch_idx % 10 == 0 and verbose:
                print('Test: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    batch_idx * len(x), len(loader.dataset),
                    100. * batch_idx / len(loader), loss.item()))
    test_loss /= len(loader.dataset)
    test_acc = 100. * correct / len(loader.dataset)
    if verbose:
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(loader.dataset), test_acc))
    return test_loss, test_acc


net = ViT_B()
net.load_state_dict(torch.load('vit_b.pth', map_location='cpu'))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')
loader = testloader()
criterion = torch.nn.CrossEntropyLoss()
validate(loader, net.to(device), criterion, device)
```
3. Results
Model | Weights (.pth) | Dataset | Acc@top1 |
---|---|---|---|
ViT_B | vit_b.pth | ImageNet1k | 75.66% |
ViT_B(patch_size=32) | vit_b_32.pth | ImageNet1k | 64.44% |
ViT_L | vit_L.pth | ImageNet1k | 79.25% |
ViT_L(patch_size=32) | vit_L_32.pth | ImageNet1k | 65.19% |
ViT_H(patch_size=14, num_classes=21843) | vit_H_14.pth | ImageNet21k | - |
The pretrained weights for ViT_H_14 (patch_size=14) are too large (~2.45 GB), so we split them into multiple smaller chunks, i.e., vit_H_14.pth.parta*. To recover the full weights, concatenate the chunks into a single pth file after downloading them:

```bash
cat vit_H_14.pth.part* > vit_H_14.pth
```
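If `cat` is not available (e.g., on Windows), the chunks can also be reassembled with a few lines of Python; this sketch assumes the chunk files sit in the current directory and sort correctly by name.

```python
import glob
import shutil

# Reassemble the split checkpoint without `cat`; assumes the chunk files are
# in the current directory and sort correctly by name.
with open('vit_H_14.pth', 'wb') as merged:
    for part in sorted(glob.glob('vit_H_14.pth.part*')):
        with open(part, 'rb') as chunk:
            shutil.copyfileobj(chunk, merged)
```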
By default, the ViT models provided by hyperbox use a patch size of 16. To use vit_H_14.pth, you need to build the model with patch_size=14 and num_classes=21843 (the model is pretrained on ImageNet21k):
```python
import torch
from hyperbox.networks.vit import ViT_H

vit_h_14 = ViT_H(patch_size=14, num_classes=21843)
vit_h_14.load_state_dict(torch.load('vit_H_14.pth'))
```
OFA_MBV3_k357_d234_e46_w1.pth
- Hyperbox-based OFA-MobileNetV3 (a dummy-input sanity check for the extracted subnet follows the two examples below)

```python
import torch
from hyperbox.networks.ofa import OFAMobileNetV3

device = 'cuda' if torch.cuda.is_available() else 'cpu'
weight = torch.load('path/to/hyperbox_OFA_MBV3_k357_d234_e46_w1.pth')
supernet = OFAMobileNetV3()
supernet.load_state_dict(weight)

# extract the subnet with depth=4, expand_ratio=6, kernel_size=7
mask = supernet.gen_mask(depth=4, expand_ratio=6, kernel_size=7)
net = supernet.build_subnet(mask).to(device)
```
- Official OFA-MobileNetV3

```python
# https://github.com/mit-han-lab/once-for-all
import torch
from ofa.imagenet_classification.elastic_nn.networks.ofa_mbv3 import OFAMobileNetV3

device = 'cuda' if torch.cuda.is_available() else 'cpu'
weight = torch.load('path/to/official_OFA_MBV3_k357_d234_e46_w1.pth')
supernet = OFAMobileNetV3(
    dropout_rate=0,
    width_mult=1.0,
    ks_list=[3, 5, 7],
    expand_ratio_list=[3, 4, 6],
    depth_list=[2, 3, 4],
)
supernet.load_state_dict(weight)

# activate the subnet with kernel_size=7, expand_ratio=6, depth=4 and extract it
supernet.set_active_subnet(ks=7, e=6, d=4)
net = supernet.get_active_subnet(preserve_weight=True).to(device)
```
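Either way, a dummy forward pass is a cheap check that the extracted subnet is wired correctly. This is only a sketch; the 224x224 input resolution is an assumption.

```python
# Sanity-check sketch: forward a dummy ImageNet-sized batch through the
# extracted subnet; the 224x224 input resolution is an assumption.
net.eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224).to(device))
print(out.shape)  # expected: torch.Size([1, 1000])
```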