Update checkpoint info to be compatible with detectron2
tkianai committed Apr 27, 2020
1 parent 8c68af9 commit c4bf7be
Showing 5 changed files with 24 additions and 8 deletions.
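
In short, `train_net.py` now writes checkpoints with the weights under a `'model'` key (instead of `'state_dict'`) and with a `'matching_heuristics': True` flag, which is the layout detectron2's checkpointer looks for when matching weight names. A quick way to verify a saved checkpoint (a minimal sketch; the expected keys are taken from the diff below, and `checkpoint.pth.tar` is the default filename used by `save_checkpoint()`):

```python
import torch

# Minimal sanity check of the checkpoint layout after this commit.
ckpt = torch.load("checkpoint.pth.tar", map_location="cpu")
print(sorted(ckpt.keys()))
# expected: ['best_acc1', 'epoch', 'matching_heuristics', 'model', 'optimizer']
print(ckpt["matching_heuristics"])  # True, so detectron2 matches keys by name
```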
5 changes: 5 additions & 0 deletions .gitignore
@@ -1,3 +1,8 @@
+# Custome
+.vscode
+datasets/
+out-of-date/
+

 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
2 changes: 1 addition & 1 deletion README.md
@@ -26,7 +26,7 @@ This provides a convenient way to initialize backbone in detectron2.
 3. Run: `python train_net_builtin.py --num-gpus <gpu number> --config-file configs/<your config file>`. For example: `sh scripts/train_net_builtin.sh`


-- Trained with pytorch formal imagenet trainer
+- Trained with pytorch formal imagenet trainer [**Recommend**]

 1. Read carefully with some arguments in `train_net.py`
 2. Run: `sh /scripts/train_net.sh`
6 changes: 3 additions & 3 deletions imgcls/modeling/meta_arch/clsnet.py
@@ -40,7 +40,7 @@ def __init__(self, cfg):

         self.num_classes = cfg.MODEL.CLSNET.NUM_CLASSES
         self.in_features = cfg.MODEL.CLSNET.IN_FEATURES
-        self.backbone = build_backbone(cfg)
+        self.bottom_up = build_backbone(cfg)
         self.criterion = nn.CrossEntropyLoss()

         self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1))
@@ -63,7 +63,7 @@ def forward_d2(self, batched_inputs):
         images = self.preprocess_image(batched_inputs)
         gt_labels = [x['label'] for x in batched_inputs]
         gt_labels = torch.as_tensor(gt_labels, dtype=torch.long).to(self.device)
-        features = self.backbone(images.tensor)
+        features = self.bottom_up(images.tensor)
         features = [features[f] for f in self.in_features]

         if self.training:
@@ -79,7 +79,7 @@ def forward_d2(self, batched_inputs):
             return processed_results

     def forward(self, images):
-        features = self.backbone(images)
+        features = self.bottom_up(images)
         return features["linear"]


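Renaming `self.backbone` to `self.bottom_up` mirrors detectron2's own naming: in an FPN-based detectron2 model the ResNet sits under `backbone.bottom_up.*`, so weights saved here as `bottom_up.*` line up with those keys through detectron2's suffix-matching heuristics. If the mapping ever needs to be done by hand, a hedged sketch (the `backbone.bottom_up.` target prefix assumes an FPN backbone; normally `matching_heuristics` makes this step unnecessary):

```python
import torch

# Hypothetical manual remap: prepend 'backbone.' so a classifier key such as
# 'bottom_up.stem.conv1.weight' becomes 'backbone.bottom_up.stem.conv1.weight'.
# Keys that do not start with 'bottom_up.' (e.g. the classification head) are
# kept as-is and would simply show up as unmatched keys when loading.
ckpt = torch.load("model_best.pth.tar", map_location="cpu")
ckpt["model"] = {
    ("backbone." + k if k.startswith("bottom_up.") else k): v
    for k, v in ckpt["model"].items()
}
torch.save(ckpt, "model_best_remapped.pth")
```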
2 changes: 1 addition & 1 deletion scripts/train_net.sh
@@ -8,4 +8,4 @@
 ###


-CUDA_VISIBLE_DEVICES=4,5,6,7 python train_net.py --config-file configs/Base_image_cls.yaml --batch-size 2048 --dist-url 'tcp://127.0.0.1:51151' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 datasets/ImageNet2012
+CUDA_VISIBLE_DEVICES=4,5,6,7 python train_net.py --config-file configs/Base_image_cls.yaml --batch-size 1024 --dist-url 'tcp://127.0.0.1:51151' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 datasets/ImageNet2012
17 changes: 14 additions & 3 deletions train_net.py
@@ -212,7 +212,7 @@ def main_worker(gpu, ngpus_per_node, args, cfg):
             if args.gpu is not None:
                 # best_acc1 may be from a checkpoint from a different GPU
                 best_acc1 = best_acc1.to(args.gpu)
-            model.load_state_dict(checkpoint['state_dict'])
+            model.load_state_dict(checkpoint['model'])
             optimizer.load_state_dict(checkpoint['optimizer'])
             print("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch']))
@@ -278,9 +278,11 @@ def main_worker(gpu, ngpus_per_node, args, cfg):

         if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                 and args.rank % ngpus_per_node == 0):
+
             save_checkpoint({
                 'epoch': epoch + 1,
-                'state_dict': model.state_dict(),
+                'model': model.state_dict(),
+                'matching_heuristics': True,
                 'best_acc1': best_acc1,
                 'optimizer': optimizer.state_dict(),
             }, is_best)
@@ -381,7 +383,16 @@ def validate(val_loader, model, criterion, args):
 def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
     torch.save(state, filename)
     if is_best:
-        shutil.copyfile(filename, 'model_best.pth.tar')
+        # strip 'module.'
+        from collections import OrderedDict
+        p = OrderedDict()
+        for key, value in state['model'].items():
+            if key.startswith('module.'):
+                key = key[7:]
+            p[key] = value
+        state['model'] = p
+        torch.save(state, "model_best.pth.tar")
+        # shutil.copyfile(filename, 'model_best.pth.tar')


 class AverageMeter(object):
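
The new `is_best` branch above strips the `module.` prefix before writing `model_best.pth.tar`. That prefix is added by `nn.DataParallel` / `DistributedDataParallel` (which this trainer presumably uses, given the `--multiprocessing-distributed` flag), so removing it makes the saved keys match an unwrapped model. A tiny standalone illustration (the toy model below is made up for demonstration):

```python
import torch.nn as nn

# DataParallel stores the wrapped network under 'module.', so every
# state_dict key gains a 'module.' prefix compared with the bare model.
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3))
wrapped = nn.DataParallel(net)

print(next(iter(net.state_dict())))      # '0.weight'
print(next(iter(wrapped.state_dict())))  # 'module.0.weight'
```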
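With the `'model'` key and the `matching_heuristics` flag in place, the best checkpoint can be pointed at from a detectron2 config. A minimal usage sketch of the detectron2 side (the config path is a placeholder, and this loading code is an assumption rather than part of the commit):

```python
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

# Build a detection model and initialize its backbone from the classifier checkpoint.
cfg = get_cfg()
cfg.merge_from_file("path/to/your_detectron2_config.yaml")  # placeholder config
cfg.MODEL.WEIGHTS = "model_best.pth.tar"                    # written by train_net.py

model = build_model(cfg)
# Because the checkpoint carries matching_heuristics=True, weight names are
# aligned by suffix matching (e.g. 'bottom_up.*' -> 'backbone.bottom_up.*').
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
```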