"""
torchgpipe
==========

A GPipe_ implementation in PyTorch_.

.. _GPipe: https://arxiv.org/abs/1811.06965
.. _PyTorch: https://pytorch.org/

.. sourcecode:: python

   from torch import nn
   from torchgpipe import GPipe

   model = nn.Sequential(a, b, c, d)
   model = GPipe(model, balance=[1, 1, 1, 1], chunks=8)

   for input in data_loader:
       output = model(input)

What is GPipe?
~~~~~~~~~~~~~~

GPipe is a scalable pipeline parallelism library published by Google Brain,
which allows efficient training of large, memory-consuming models. According to
the paper, GPipe can train a 25x larger model by using 8x devices (TPU), and
train a model 3.5x faster by using 4x devices.

`GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism
<https://arxiv.org/abs/1811.06965>`_

Google trained AmoebaNet-B with 557M parameters with GPipe. This model
achieved 84.3% top-1 and 97.0% top-5 accuracy on the ImageNet classification
benchmark (the state-of-the-art performance as of May 2019).

Links
~~~~~

- Source Code: https://github.com/kakaobrain/torchgpipe
- Documentation: https://torchgpipe.readthedocs.io/
- Original Paper: https://arxiv.org/abs/1811.06965

"""
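
from setuptools import setup

# Read the version string by executing __version__.py directly, so that
# setup does not need to import the torchgpipe package itself.
about = {}  # type: ignore
with open('torchgpipe/__version__.py') as f:
    exec(f.read(), about)  # pylint: disable=W0122
version = about['__version__']
del about
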
setup(
    name='torchgpipe',

    version=version,

    license='BSD-3-Clause',
    url='https://github.com/kakaobrain/torchgpipe',
    author='Kakao Brain',
    maintainer='Heungsub Lee, Myungryong Jeong, Chiheon Kim',

    description='GPipe for PyTorch',
    long_description=__doc__,
    keywords='pytorch gpipe',

    zip_safe=False,

    packages=['torchgpipe', 'torchgpipe.balance', 'torchgpipe.skip'],
    package_data={'torchgpipe': ['py.typed']},
    py_modules=['torchgpipe_balancing'],

    install_requires=['torch>=1.1'],
    setup_requires=['pytest-runner'],
    tests_require=['pytest>=4'],

    classifiers=[
        'Development Status :: 3 - Alpha',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: BSD License',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3 :: Only',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Typing :: Typed',
    ],
)