forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
141 lines (113 loc) · 4.23 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
from datetime import datetime
import subprocess
from setuptools import find_packages, setup
current_date = datetime.now().strftime("%Y.%m.%d")
def get_git_commit_id():
try:
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip()
except Exception:
return ""
def read_requirements(file_path):
with open(file_path, "r") as file:
return file.read().splitlines()
def read_version(file_path="version.txt"):
with open(file_path, "r") as file:
return file.readline().strip()
# Determine the package name based on the presence of an environment variable
package_name = "torchao-nightly" if os.environ.get("TORCHAO_NIGHTLY") else "torchao"
# Use Git commit ID if VERSION_SUFFIX is not set
version_suffix = os.getenv("VERSION_SUFFIX")
if version_suffix is None:
version_suffix = f"+git{get_git_commit_id()}"
use_cpp = os.getenv('USE_CPP')
# Version is year.month.date if using nightlies
version = current_date if package_name == "torchao-nightly" else read_version()
import torch
from torch.utils.cpp_extension import (
CppExtension,
CUDAExtension,
BuildExtension,
CUDA_HOME,
IS_WINDOWS
)
def get_extensions():
debug_mode = os.getenv('DEBUG', '0') == '1'
if debug_mode:
print("Compiling in debug mode")
if not torch.cuda.is_available():
print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions")
if CUDA_HOME is None and torch.cuda.is_available():
print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit")
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
extension = CUDAExtension if use_cuda else CppExtension
if not IS_WINDOWS:
extra_link_args = []
extra_compile_args = {
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
]
}
if debug_mode:
extra_compile_args["cxx"].append("-g")
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])
else:
extra_link_args = []
extra_compile_args = {
"cxx": [
"/O2" if not debug_mode else "/Od",
"/permissive-"
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
]
}
if debug_mode:
extra_compile_args["cxx"].append("/ZI")
extra_compile_args["nvcc"].append("-g")
extra_link_args.append("/DEBUG")
this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True))
if use_cuda:
sources += cuda_sources
ext_modules = [
extension(
"torchao._C",
sources,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
]
return ext_modules
setup(
name=package_name,
version=version+version_suffix,
packages=find_packages(),
include_package_data=True,
package_data={
"torchao.kernel.configs": ["*.pkl"],
},
ext_modules=get_extensions() if use_cpp != "0" else None,
extras_require={"dev": read_requirements("dev-requirements.txt")},
description="Package for applying ao techniques to GPU models",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/pytorch-labs/ao",
cmdclass={"build_ext": BuildExtension},
)