tsauri/soft-binary-mask

End-to-End Pruning with Soft-binary Weight Mask

  • Apply a trainable channel mask to the weights of WideResNet-34-9; the mask values are constrained to [0, 1] (see the sketch after this list)
  • Sparsity-regularize the mask by minimizing its mean (weight --wg) and maximizing its variance (weight --tw)
  • To prune parameters, multiply the mask by the weights
  • To prune FLOPs, prune dimensions by deleting the output channels whose mask equals 0
  • The target sparsity fraction --target takes a value in [0, 1]; a (reverse) mask weight decay --mwd controls un-sparsification of the mask when the target is missed
  • Warm up for several epochs before regularization starts, set with --wme
  • Stochastic Weight Averaging is used to boost accuracy; averaging starts at epoch --swa-start
  • Confidence Penalty --confp, Dropout --d, and Label Smoothing --lseps were tried but had negligible effect, so they are set to 0
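
A minimal sketch of the masking idea follows, assuming a per-output-channel mask clamped to [0, 1] that multiplies the convolution weights, plus a regularizer that lowers the mask mean and raises its variance. This is not the repository's mainmce.py implementation; the names MaskedConv2d and mask_regularizer and the exact regularizer form are illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv2d(nn.Conv2d):
    # Conv layer whose output channels are gated by a trainable soft-binary mask.
    def __init__(self, in_ch, out_ch, kernel_size, **kwargs):
        super().__init__(in_ch, out_ch, kernel_size, **kwargs)
        self.mask = nn.Parameter(torch.ones(out_ch))  # one mask value per output channel

    def forward(self, x):
        # Constrain the mask to [0, 1] and broadcast it over the weight tensor.
        m = self.mask.clamp(0.0, 1.0).view(-1, 1, 1, 1)
        return F.conv2d(x, self.weight * m, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

def mask_regularizer(model, wg=30.0, tw=10.0):
    # Sparsity loss: minimize the mask mean (weight wg), maximize its variance (weight tw).
    loss = 0.0
    for module in model.modules():
        if isinstance(module, MaskedConv2d):
            m = module.mask.clamp(0.0, 1.0)
            loss = loss + wg * m.mean() - tw * m.var()
    return loss

During training this term would be added to the cross-entropy loss once the warm-up epochs (--wme) are over; at export time, output channels whose mask is exactly 0 can be deleted to realize the parameter and FLOP savings.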

Pre-trained models are in chkpt/

To reproduce models

Run this command to train from scratch. Training takes almost 12 hours and supports a single GPU only.

python -u mainmce.py --dataset cifar100 --arch wideresnet \
    --epochs 300 --seed 1 --batch-size 64 --test-batch-size 64 \
    --lr 1e-2 --wd 1e-4 --momentum 0.9 \
    --depth 34 --wf 9 --d 0 --lseps 0.0 --confp 0.0 \
    --target 0.5 --wg 30 --tw 10 --wme 50 --mwd 5e-4 \
    --save log-w1 --swa-start 90

The pruned model pruned_model.pth.tar is stored in the --save folder (log-w1).

To evaluate models

  • Pre-trained models are in the chkpt/ folder.

  • eval.py is the old evaluation code, for the 16-bit --freebie option.

  • eval_quantize.py is the final evaluation code; it uses PyTorch 1.3 quantization.

  • Both run on the CPU.

  • compute_flops.py is the code that calculates FLOPs.

  • The baseline parameter count is base_params = sum(p.numel() for name, p in model.named_parameters()) (see the sketch after this list).
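
The parameter side of that accounting can be sketched as below; compute_flops.py handles the FLOPs side. The division by 4 reflects 8-bit storage versus 32-bit floats, matching the "params count /  4" line in the report further down. The function name count_params is illustrative, not taken from the repository.

def count_params(model):
    # Count 32-bit parameters, then divide by 4 for 8-bit storage ("params count / 4").
    base_params = sum(p.numel() for name, p in model.named_parameters())
    quantized_params = base_params / 4
    return base_params, quantized_params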

This is the final submission. If 8-bit quantization is accepted as valid, run:

python eval_quantize.py --pruned chkpt/pruned_model2.pth.tar

--------------------------------------------------------------------------------
params count /  4
original param count 20696798 @ 20.70 M
8-bit quantized param count 5174199 @ 5.17 M
original FLOPs count 10193691876.0 @ 10.19 GFLOPs
8-bit FLOPs count 2548422969.0 @ 2.55 GFLOPs
--------------------------------------------------------------------------------
Reference param count 36500000.0 @ 36.50 M
Reference FLOPs count 10490000000.0 @ 10.49 GFLOPs
--------------------------------------------------------------------------------
Score = Param/refParam + FLOPs/refFLOPs = 0.14176 + 0.24294 = 0.38470
--------------------------------------------------------------------------------
Test set: Average loss: 1.2146, Accuracy: 8019/10000 (80.19%)
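
As a quick check, the score line can be reproduced from the counts and reference values printed above (illustrative snippet, not part of the repository):

quantized_params = 5174199        # 8-bit quantized param count
quantized_flops = 2548422969.0    # 8-bit FLOPs count
ref_params = 36500000.0           # reference param count
ref_flops = 10490000000.0         # reference FLOPs count

score = quantized_params / ref_params + quantized_flops / ref_flops
print(f"Score = {score:.5f}")     # Score = 0.38470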

Otherwise, use the fallback evaluation:

python eval.py --pruned chkpt/pruned_model2.pth.tar  --freebie --half

About

MicroNet Challenge submission
