forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_cpp_extensions_mtia_backend.py
145 lines (125 loc) · 5.6 KB
/
test_cpp_extensions_mtia_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Owner(s): ["module: mtia"]
import os
import shutil
import tempfile
import unittest

import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_LINUX,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_PRIVATEUSE1,
    TEST_XPU,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
# define TEST_ROCM before changing TEST_CUDA
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
@unittest.skipIf(
    IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM or TEST_XPU,
    "Only on linux platform and mutual exclusive to other backends",
)
@common.markDynamoStrictTest
class TestCppExtensionMTIABackend(common.TestCase):
    """Tests the ``torch.mtia`` device/stream APIs against a fake MTIA backend.

    The backend (a fake device guard implementation) is JIT-compiled from
    ``cpp_extensions/mtia_extension.cpp`` with
    ``torch.utils.cpp_extension.load`` and loaded once per class in
    :meth:`setUpClass`. The class is skipped on ARM64 / non-Linux hosts and
    whenever another backend (CUDA, ROCm, XPU, PrivateUse1) is testable, as
    the skip message states the fake backend is mutually exclusive with them.
    """

    # Return value of cpp_extension.load(); assigned once in setUpClass.
    module = None
    # Temporary build directory created in setUpClass; removed by a class
    # cleanup registered there.
    build_dir = None
    # Directory containing this test file; the extension sources live under
    # it, so paths are resolved against it instead of the current directory.
    _test_dir = os.path.dirname(os.path.abspath(__file__))

    def setUp(self):
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily.
        self.old_working_dir = os.getcwd()
        os.chdir(self._test_dir)

    def tearDown(self):
        try:
            super().tearDown()
        finally:
            # Restore the working directory changed in setUp even if the
            # base-class teardown raises, so later tests are not affected.
            os.chdir(self.old_working_dir)

    @classmethod
    def tearDownClass(cls):
        common.remove_cpp_extensions_build_root()
        super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        common.remove_cpp_extensions_build_root()
        cls.build_dir = tempfile.mkdtemp()
        # Class cleanups run even when setUpClass fails (tearDownClass does
        # not), so the temporary build directory is never leaked.
        cls.addClassCleanup(shutil.rmtree, cls.build_dir, ignore_errors=True)
        # Load the fake device guard impl. setUpClass runs before any setUp,
        # i.e. before the chdir above, so the extension paths are resolved
        # against this file's directory rather than the current directory.
        cpp_ext_dir = os.path.join(cls._test_dir, "cpp_extensions")
        cls.module = torch.utils.cpp_extension.load(
            name="mtia_extension",
            sources=[os.path.join(cpp_ext_dir, "mtia_extension.cpp")],
            build_directory=cls.build_dir,
            extra_include_paths=[
                cpp_ext_dir,
                # Deliberately awkward include paths (spaces, a quote);
                # presumably they exercise flag quoting and need not exist.
                "path / with spaces in it",
                "path with quote'",
            ],
            is_python_module=False,
            verbose=True,
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_get_device_module(self):
        """torch.get_device_module resolves "mtia" and yields an MTIA stream."""
        device = torch.device("mtia:0")
        default_stream = torch.get_device_module(device).current_stream()
        self.assertEqual(
            default_stream.device_type, int(torch._C._autograd.DeviceType.MTIA)
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_basic(self):
        """Default vs. user streams: identity, ids, context switch, sync."""
        default_stream = torch.mtia.current_stream()
        user_stream = torch.mtia.Stream()
        self.assertEqual(torch.mtia.current_stream(), default_stream)
        self.assertNotEqual(default_stream, user_stream)
        # Check mtia_extension.cpp, default stream id starts from 0.
        self.assertEqual(default_stream.stream_id, 0)
        self.assertNotEqual(user_stream.stream_id, 0)
        with torch.mtia.stream(user_stream):
            self.assertEqual(torch.mtia.current_stream(), user_stream)
        self.assertTrue(user_stream.query())
        default_stream.synchronize()
        self.assertTrue(default_stream.query())

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_context(self):
        """torch.mtia.stream makes the given same-device stream current."""
        mtia_stream_0 = torch.mtia.Stream(device="mtia:0")
        mtia_stream_1 = torch.mtia.Stream(device="mtia:0")
        with torch.mtia.stream(mtia_stream_0):
            current_stream = torch.mtia.current_stream()
            msg = f"current_stream {current_stream} should be {mtia_stream_0}"
            self.assertEqual(current_stream, mtia_stream_0, msg=msg)
        with torch.mtia.stream(mtia_stream_1):
            current_stream = torch.mtia.current_stream()
            msg = f"current_stream {current_stream} should be {mtia_stream_1}"
            self.assertEqual(current_stream, mtia_stream_1, msg=msg)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_context_different_device(self):
        """A stream context switches to the stream's device and restores it."""
        device_0 = torch.device("mtia:0")
        device_1 = torch.device("mtia:1")
        mtia_stream_0 = torch.mtia.Stream(device=device_0)
        mtia_stream_1 = torch.mtia.Stream(device=device_1)
        orig_current_device = torch.mtia.current_device()
        with torch.mtia.stream(mtia_stream_0):
            current_stream = torch.mtia.current_stream()
            self.assertEqual(torch.mtia.current_device(), device_0.index)
            msg = f"current_stream {current_stream} should be {mtia_stream_0}"
            self.assertEqual(current_stream, mtia_stream_0, msg=msg)
        self.assertEqual(torch.mtia.current_device(), orig_current_device)
        with torch.mtia.stream(mtia_stream_1):
            current_stream = torch.mtia.current_stream()
            self.assertEqual(torch.mtia.current_device(), device_1.index)
            msg = f"current_stream {current_stream} should be {mtia_stream_1}"
            self.assertEqual(current_stream, mtia_stream_1, msg=msg)
        self.assertEqual(torch.mtia.current_device(), orig_current_device)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_device_context(self):
        """torch.mtia.device sets the current device, including when nested."""
        device_0 = torch.device("mtia:0")
        device_1 = torch.device("mtia:1")
        with torch.mtia.device(device_0):
            self.assertEqual(torch.mtia.current_device(), device_0.index)
            with torch.mtia.device(device_1):
                self.assertEqual(torch.mtia.current_device(), device_1.index)
# Entry point when this file is executed directly: run its tests through
# PyTorch's shared runner from torch.testing._internal.common_utils.
if __name__ == "__main__":
    common.run_tests()