diff --git a/KD_Lib/utils/__init__.py b/KD_Lib/utils/__init__.py new file mode 100644 index 00000000..81c41531 --- /dev/null +++ b/KD_Lib/utils/__init__.py @@ -0,0 +1 @@ +from .pipeline import Pipeline diff --git a/KD_Lib/utils/pipeline.py b/KD_Lib/utils/pipeline.py new file mode 100644 index 00000000..b7a2d2f9 --- /dev/null +++ b/KD_Lib/utils/pipeline.py @@ -0,0 +1,117 @@ +from itertools import islice +from tqdm import tqdm +import time + +from KD_Lib.KD.common import BaseClass + + +class Pipeline: + """ + Pipeline of knowledge distillation, pruning and quantization methods + supported by KD_Lib. Sequentially applies a list of methods on the student model. + + All the elements in list must implement either train_student, prune or quantize + methods. + + :param: steps (list) list of KD_Lib.KD or KD_Lib.Pruning or KD_Lib.Quantization + :param: epochs (int) number of iterations through whole batch for each method in + list + :param: plot_losses (bool) Plot a graph of losses during training + :param: save_model (bool) Save model after performing the list methods + :param: save_model_pth (str) Path where model is saved if save_model is True + :param: verbose (int) Verbose + """ + + def __init__( + self, + steps, + epochs=5, + plot_losses=True, + save_model=True, + save_model_pth="./models/student.pt", + verbose=0, + ): + self.steps = steps + self.device = device + self.verbose = verbose + + self.plot_losses = plot_losses + self.save_model = save_model + self.save_model_path = save_model_pth + self._validate_steps() + self.epochs = epochs + + def _validate_steps(self): + name, process = zip(*self.steps) + + for t in process: + if not hasattr(t, ("train_student", "prune", "quantize")): + raise TypeError( + "All the steps must support at least one of " + "train_student, prune or quantize method, {} is not" + " supported yet".format(str(t)) + ) + + def get_steps(self): + return self.steps + + def _iter(self, num_steps=-1): + _length = len(self.steps) if num_steps == -1 else num_steps + + for idx, (name, process) in enumerate(islice(self.steps, 0, _length)): + yield idx, name, process + + def _fit(self): + + if self.verbose: + pbar = tqdm(total=len(self)) + + for idx, name, process in self._iter(): + print("Starting {}".format(name)) + if idx != 0: + if hasattr(process, "train_student"): + if hasattr(self.steps[idx - 1], "train_student"): + process.student_model = self.steps[idx - 1].student_model + else: + process.student_model = self.steps[idx - 1].model + t1 = time.time() + if hasattr(process, "train_student"): + process.train_student( + self.epochs, self.plot_losses, self.save_model, self.save_model_path + ) + elif hasattr(proces, "prune"): + process.prune() + elif hasattr(process, "quantize"): + process.quantize() + else: + raise TypeError( + "{} is not supported by the pipeline yet.".format(process) + ) + + t2 = time.time() - t1 + print( + "{} completed in {}hr {}min {}s".format( + name, t2 // (60 * 60), t2 // 60, t2 % 60 + ) + ) + + if self.verbose: + pbar.update(1) + + if self.verbose: + pbar.close() + + def train(self): + """ + Train the (student) model sequentially through the list. + """ + self._validate_steps() + + t1 = time.time() + self._fit() + t2 = time.time() - t1 + print( + "Pipeline execution completed in {}hr {}min {}s".format( + t2 // (60 * 60), t2 // 60, t2 % 60 + ) + ) diff --git a/setup.py b/setup.py index dd5add73..d840bb84 100755 --- a/setup.py +++ b/setup.py @@ -13,67 +13,74 @@ LONG_DESCRIPTION = f.read() # Define the keywords -KEYWORDS = ["Knowledge Distillation", "Pruning", "Quantization", "pytorch", "machine learning", "deep learning"] +KEYWORDS = [ + "Knowledge Distillation", + "Pruning", + "Quantization", + "pytorch", + "machine learning", + "deep learning", +] REQUIRE_PATH = "requirements.txt" PROJECT = os.path.abspath(os.path.dirname(__file__)) -setup_requirements = ['pytest-runner'] +setup_requirements = ["pytest-runner"] -test_requirements = ['pytest', 'pytest-cov'] +test_requirements = ["pytest", "pytest-cov"] requirements = [ -'pip==19.3.1', -'transformers==4.6.1', -'sacremoses', -'tokenizers==0.10.1', -'huggingface-hub==0.0.8', -'torchtext==0.9.1', -'bumpversion==0.5.3', -'wheel==0.32.1', -'watchdog==0.9.0', -'flake8==3.5.0', -'tox==3.5.2', -'coverage==4.5.1', -'Sphinx==1.8.1', -'twine==1.12.1', -'pytest==3.8.2', -'pytest-runner==4.2', -'pytest-cov==2.6.1', -'matplotlib==3.2.1', -'torch==1.8.1', -'torchvision==0.9.1', -'tensorboard==2.2.1', -'contextlib2==0.6.0.post1', -'pandas==1.0.1', -'tqdm==4.42.1', -'numpy==1.18.1', -'sphinx-rtd-theme==0.5.0', + "pip==19.3.1", + "transformers==4.6.1", + "sacremoses", + "tokenizers==0.10.1", + "huggingface-hub==0.0.8", + "torchtext==0.9.1", + "bumpversion==0.5.3", + "wheel==0.32.1", + "watchdog==0.9.0", + "flake8==3.5.0", + "tox==3.5.2", + "coverage==4.5.1", + "Sphinx==1.8.1", + "twine==1.12.1", + "pytest==3.8.2", + "pytest-runner==4.2", + "pytest-cov==2.6.1", + "matplotlib==3.2.1", + "torch==1.8.1", + "torchvision==0.9.1", + "tensorboard==2.2.1", + "contextlib2==0.6.0.post1", + "pandas==1.0.1", + "tqdm==4.42.1", + "numpy==1.18.1", + "sphinx-rtd-theme==0.5.0", ] if __name__ == "__main__": setup( - author="Het Shah", - author_email='divhet163@gmail.com', - classifiers=[ - 'Development Status :: 2 - Pre-Alpha', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Natural Language :: English', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - ], - description="A Pytorch Library to help extend all Knowledge Distillation works", - install_requires=requirements, - license="MIT license", - long_description=LONG_DESCRIPTION, - include_package_data=True, - keywords=KEYWORDS, - name='KD_Lib', - packages=find_packages(where=PROJECT), - setup_requires=setup_requirements, - test_suite="tests", - tests_require=test_requirements, - url="https://github.com/SforAiDL/KD_Lib", - version='0.0.29', - zip_safe=False, -) + author="Het Shah", + author_email="divhet163@gmail.com", + classifiers=[ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + ], + description="A Pytorch Library to help extend all Knowledge Distillation works", + install_requires=requirements, + license="MIT license", + long_description=LONG_DESCRIPTION, + include_package_data=True, + keywords=KEYWORDS, + name="KD_Lib", + packages=find_packages(where=PROJECT), + setup_requires=setup_requirements, + test_suite="tests", + tests_require=test_requirements, + url="https://github.com/SforAiDL/KD_Lib", + version="0.0.29", + zip_safe=False, + ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 00000000..55d954db --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,54 @@ +from KD_Lib.utils import Pipeline +from KD_Lib.KD import VanillaKD +from KD_Lib.Pruning import Lottery_Tickets_Pruner +from KD_Lib.Quantization import Dynamic_Quantizer +from KD_Lib.models import Shallow + +import torch +from torchvision import datasets, transforms +import torch.optim as optim + + +train_loader = torch.utils.data.DataLoader( + datasets.MNIST( + "mnist_data", + train=True, + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ), + batch_size=32, + shuffle=True, +) + +test_loader = torch.utils.data.DataLoader( + datasets.MNIST( + "mnist_data", + train=False, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ), + batch_size=32, + shuffle=True, +) + + +def test_Pipeline(): + teacher = Shallow(hidden_size=400) + student = Shallow(hidden_size=100) + + t_optimizer = optim.SGD(teacher.parameters(), 0.01) + s_optimizer = optim.SGD(student.parameters(), 0.01) + + distiller = VanillaKD( + teacher, student, train_loader, test_loader, t_optimizer, s_optimizer + ) + + pruner = Lottery_Tickets_Pruner(student, train_loader, test_loader) + + quantizer = Dynamic_Quantizer(student, test_loader, {torch.nn.Linear}) + + pipe = Pipeline([distiller, pruner, quantizer], 1) + pipe.train()