This repository contains implementations of neural network weight pruning methods.
Pruning has three stages:
- Training
- Pruning:
  - One-Shot Pruning (see the sketch after this list)
- Retraining:
  - No Retraining (none)
  - Fine-Tuning (fine-tuning)
  - Learning Rate Rewinding (lr-rewinding)
  - Weight Rewinding (weight-rewinding)
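To make the pruning stage concrete, here is a minimal sketch of one-shot global magnitude pruning built on torch.nn.utils.prune; the repository's own one_shot_pruning.py may differ in details such as the pruning criterion and scope.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

def one_shot_prune(model, amount=0.8):
    # collect the weight tensors of all conv and linear layers
    params = [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]
    # zero out the globally smallest-magnitude weights in a single step
    prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=amount)
    return model
```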
Install the following Python libraries:
$ pip install "torch>=1.6.0" "torchvision>=0.7.0" numpy tqdm livelossplot
Then create a directory for saved checkpoints:
$ mkdir saved_models
To get the pruned models you need:
- a model (torch.nn.Module); see the sketch after this list
  - whose forward() takes its input as a single variable, batch
  - which, by default, returns the loss when forward() is called, and returns predictions when the get_prediction argument is set to True
- PyTorch data loaders (train_dataloader, val_dataloader)
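As an illustration only (TinyNet is not part of this repository), a minimal model satisfying this interface might look like:

```python
import torch.nn as nn
import torch.nn.functional as F

class TinyNet(nn.Module):
    """Illustrative model matching the expected forward() contract."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(28 * 28, num_classes)

    def forward(self, batch, get_prediction=False):
        # batch is assumed to be an (inputs, targets) pair from the loader
        inputs, targets = batch
        logits = self.fc(inputs.view(inputs.size(0), -1))
        if get_prediction:
            # return class predictions instead of the loss
            return logits.argmax(dim=1)
        # by default, forward() returns the loss
        return F.cross_entropy(logits, targets)
```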
See the Jupyter notebooks (.ipynb files) for worked examples of the pruning workflow. This repository contains implementations of the following models:
lenet, resnet20, resnet32, resnet44, resnet56, resnet110, resnet1202
as well as functions that build data loaders for MNIST and CIFAR-10.
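For reference, train_dataloader and val_dataloader can also be built directly with torchvision; the helpers in utils/data_utils.py presumably do something similar, though their exact names are not shown here. A minimal MNIST sketch:

```python
import torch
from torchvision import datasets, transforms

transform = transforms.ToTensor()
train_set = datasets.MNIST("data", train=True, download=True, transform=transform)
val_set = datasets.MNIST("data", train=False, download=True, transform=transform)

train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=256, shuffle=False)
```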
Repository structure:
.
|--- models
| |--- __init__.py
| |--- lenet.py
| |--- resnet.py
|--- pruning
| |--- unstructured
| | |--- __init__.py
| | |--- one_shot_pruning.py
| | |--- iterative_pruning.py
| |--- structured
| | |--- __init__.py
|--- utils
| |--- __init__.py
| |--- data_utils.py
| |--- train_utils.py
|--- LICENSE
|--- README.md
TODO:
- Write code for weight-distribution visualizations for each layer (a possible starting point follows this list).
- Write code for iterative pruning.
- Write code for ADMM sparsity regularization.
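One possible starting point for the first TODO item: per-layer weight histograms can be drawn with matplotlib (an extra dependency, not in the install list above).

```python
import matplotlib.pyplot as plt
import torch.nn as nn

def plot_weight_distributions(model):
    # draw one histogram per conv/linear layer
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            weights = module.weight.detach().cpu().flatten().numpy()
            plt.figure()
            plt.hist(weights, bins=100)
            plt.title(name)
            plt.xlabel("weight value")
            plt.ylabel("count")
            plt.show()
```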