Skip to content

Commit

Permalink
Merge commit '0ecf368a01fb0f0a412802072415fb2f36b26561'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Jul 24, 2024
2 parents 45cf3db + 0ecf368 commit 4ced588
Show file tree
Hide file tree
Showing 15 changed files with 116 additions and 72 deletions.
9 changes: 2 additions & 7 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,8 @@ jobs:
echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
fi
cd python/test/unit
python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=runtime --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
python3 -m pytest -s -n 8 language/test_subprocess.py
# Run runtime tests serially to avoid race condition with cache handling
python3 -m pytest -s runtime/
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
# Run hopper/test_flashattention.py separately to avoid out of gpu memory
Expand Down Expand Up @@ -385,17 +383,14 @@ jobs:
cd python/test/unit
## test_subprocess.py is flaky on the AMD CI.
## TODO (lixun) find a solution and re-enable it.
pytest --capture=tee-sys -rfs -n 16 language \
pytest --capture=tee-sys -rfs -n 16 language runtime \
--ignore=language/test_line_info.py \
--ignore=language/test_subprocess.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
# Run runtime tests serially to avoid race condition with cache handling
python3 -m pytest -s runtime
- name: Run Proton tests
run: |
cd third_party/proton
Expand Down
9 changes: 2 additions & 7 deletions .github/workflows/integration-tests.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,8 @@ jobs:
echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
fi
cd python/test/unit
python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=runtime --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
python3 -m pytest -s -n 8 language/test_subprocess.py
# Run runtime tests serially to avoid race condition with cache handling
python3 -m pytest -s runtime/
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
# Run hopper/test_flashattention.py separately to avoid out of gpu memory
Expand Down Expand Up @@ -390,7 +388,7 @@ jobs:
cd python/test/unit
## test_subprocess.py is flaky on the AMD CI.
## TODO (lixun) find a solution and re-enable it.
pytest --capture=tee-sys -rfs -n 16 language \
pytest --capture=tee-sys -rfs -n 16 language runtime \
--ignore=language/test_line_info.py \
--ignore=language/test_subprocess.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
Expand All @@ -399,9 +397,6 @@ jobs:
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py

# Run runtime tests serially to avoid race condition with cache handling
python3 -m pytest -s runtime

- name: Run Proton tests
run: |
cd third_party/proton
Expand Down
4 changes: 4 additions & 0 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,10 @@ def build_extension(self, ext):
else:
cmake_args += ["-DTRITON_BUILD_PROTON=OFF"]

cmake_args_append = os.getenv("TRITON_APPEND_CMAKE_ARGS")
if cmake_args_append is not None:
cmake_args += cmake_args_append.split(" ")

env = os.environ.copy()
cmake_dir = get_cmake_dir()
subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env)
Expand Down
17 changes: 17 additions & 0 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,23 @@ void init_triton_llvm(py::module &&m) {
},
py::keep_alive<0, 2>());

m.def("attach_datalayout", [](llvm::Module *mod, const std::string triple,
const std::string proc,
const std::string features) {
std::string error;
auto target = llvm::TargetRegistry::lookupTarget(triple, error);
if (!target) {
throw std::runtime_error("target lookup error: " + error);
}
llvm::TargetOptions opt;
// Target machine is only used to create the data layout.
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
triple, proc, features, opt, llvm::Reloc::PIC_, std::nullopt,
llvm::CodeGenOptLevel::None)};
// set data layout
mod->setDataLayout(machine->createDataLayout());
});

m.def("optimize_module", [](llvm::Module *mod,
const llvm::OptimizationLevel &opt) {
if (mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT"))
Expand Down
14 changes: 12 additions & 2 deletions python/test/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# content of conftest.py

import pytest
import os
import tempfile


def pytest_addoption(parser):
Expand All @@ -10,3 +10,13 @@ def pytest_addoption(parser):
@pytest.fixture
def device(request):
return request.config.getoption("--device")


@pytest.fixture
def fresh_triton_cache():
with tempfile.TemporaryDirectory() as tmpdir:
try:
os.environ["TRITON_CACHE_DIR"] = tmpdir
yield tmpdir
finally:
os.environ.pop("TRITON_CACHE_DIR", None)
12 changes: 11 additions & 1 deletion python/test/unit/language/print_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ def kernel_print(X, Y, BLOCK: tl.constexpr):
tl.store(Y + tl.arange(0, BLOCK), x)


@triton.jit
def kernel_device_print_scalar(SCALAR):
x = tl.load(SCALAR)
# Triton should add a space after this prefix.
print("x:", x)


@triton.jit
def kernel_device_print_large(
BLOCK_M: tl.constexpr,
Expand Down Expand Up @@ -95,6 +102,9 @@ def test_print(func: str, data_type: str, device: str):
y = torch.zeros((N, ), dtype=x.dtype, device=device)
if func == "device_print":
kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N)
elif func == "device_print_scalar":
scalar = torch.tensor(42, dtype=x.dtype, device=device)
kernel_device_print_scalar[(1, )](scalar, num_warps=num_warps)
elif func == "print":
kernel_print[(1, )](x, y, num_warps=num_warps, BLOCK=N)
elif func == "device_print_large":
Expand Down Expand Up @@ -123,7 +133,7 @@ def test_print(func: str, data_type: str, device: str):

if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \
func != "print_multiple_args" and func != "device_print_multiple_args" and \
func != "device_print_pointer":
func != "device_print_pointer" and func != "device_print_scalar":
assert_close(y, x)


Expand Down
36 changes: 23 additions & 13 deletions python/test/unit/language/test_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,21 @@ def is_interpreter():


@pytest.mark.interpreter
@pytest.mark.parametrize("func_type, data_type", [("device_print", data_type) for data_type in torch_types] + [
("print", "int32"),
("static_print", "int32"),
("no_arg_print", "int32"),
("print_no_arg", "int32"),
("device_print_large", "int32"),
("print_multiple_args", "int32"),
("device_print_multiple_args", "int32"),
("device_print_hex", "int16"),
("device_print_hex", "int32"),
("device_print_hex", "int64"),
("device_print_pointer", "int32"),
])
@pytest.mark.parametrize("func_type, data_type", [(fn, data_type)
for fn in ["device_print", "device_print_scalar"]
for data_type in torch_types] + [
("print", "int32"),
("static_print", "int32"),
("no_arg_print", "int32"),
("print_no_arg", "int32"),
("device_print_large", "int32"),
("print_multiple_args", "int32"),
("device_print_multiple_args", "int32"),
("device_print_hex", "int16"),
("device_print_hex", "int32"),
("device_print_hex", "int64"),
("device_print_pointer", "int32"),
])
def test_print(func_type: str, data_type: str, device: str):
if device == "xpu" and data_type == "float64" and not tr.driver.active.get_current_target().arch['has_fp64']:
pytest.xfail("float64 not supported on current xpu hardware")
Expand All @@ -59,6 +61,9 @@ def test_print(func_type: str, data_type: str, device: str):
# The total number of elements in the 1-D tensor to print.
N = 128

# Constant for testing the printing of scalar values
SCALAR_VAL = 42

# Format is
# pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) <prefix> (operand <n>) <elem>
expected_lines = Counter()
Expand All @@ -68,6 +73,11 @@ def test_print(func_type: str, data_type: str, device: str):
if data_type.startswith("float"):
line += ".000000"
expected_lines[line] = 1
elif func_type == "device_print_scalar":
line = f"pid (0, 0, 0) idx () x: {SCALAR_VAL}"
if data_type.startswith("float"):
line += ".000000"
expected_lines[line] = N
elif func_type == "device_print_hex":
for i in range(N):
line = f"pid (0, 0, 0) idx ({i:3}) x: 0x"
Expand Down
20 changes: 4 additions & 16 deletions python/test/unit/runtime/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import importlib.util
import itertools
import os
import shutil
import tempfile

Expand All @@ -12,8 +11,6 @@
import triton.language as tl
from triton.runtime.jit import JITFunction

tmpdir = ".tmp"


@triton.jit
def function_0(i):
Expand Down Expand Up @@ -151,38 +148,29 @@ def test_kernel(i):
assert orig_cache_key != updated_cache_key


def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
# https://stackoverflow.com/questions/303200/how-do-i-remove-delete-a-folder-that-is-not-empty
shutil.rmtree(tmpdir, ignore_errors=True)


def test_reuse(device):
def test_reuse(device, fresh_triton_cache):
counter = 0

def inc_counter(*args, **kwargs):
nonlocal counter
counter += 1

JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device=device)
for i in range(10):
kernel[(1, )](x, 1, BLOCK=1024)
assert counter == 1


@pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode, device):
def test_specialize(mode, device, fresh_triton_cache):
counter = 0

def inc_counter(*args, **kwargs):
nonlocal counter
counter += 1

JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device=device)
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
target = {'enable': 3, 'disable': 1}[mode]
Expand Down Expand Up @@ -494,7 +482,7 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)


def test_preload() -> None:
def test_preload(fresh_triton_cache) -> None:

@triton.jit
def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr):
Expand Down Expand Up @@ -523,7 +511,7 @@ def cache_hook(*args, **kwargs):
assert specialization_data is not None

# clear the cache
reset_tmp_dir()
shutil.rmtree(fresh_triton_cache)
kernel_add.cache[device].clear()

# preload the kernel
Expand Down
21 changes: 3 additions & 18 deletions python/test/unit/runtime/test_subproc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import multiprocessing
import os
import shutil

import torch
Expand All @@ -9,17 +8,9 @@
import triton.language as tl
from triton.compiler import ASTSource

tmpdir = ".tmp"

target = triton.runtime.driver.active.get_current_target()


def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir, ignore_errors=True)


def compile_fn(attrs, capability):

@triton.jit
Expand Down Expand Up @@ -66,15 +57,12 @@ def kernel_dot(Z):
triton.compile(src=src, target=target)


def test_compile_in_forked_subproc() -> None:
reset_tmp_dir()
capability = 0
def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability(0)
capability = major * 10 + minor
elif torch.xpu.is_available():
capability = torch.xpu.get_device_capability(0)

config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())

assert multiprocessing.get_start_method() == 'fork'
Expand All @@ -96,7 +84,7 @@ def empty_kernel():
triton.compile(src=src, target=target)


def test_compile_in_forked_subproc_with_forced_gc() -> None:
def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
'''
Tests that compilation artifacts can safely live in forked process.
Expand All @@ -109,7 +97,6 @@ def test_compile_in_forked_subproc_with_forced_gc() -> None:
This is a regression test that ensures thread pool in MLIRContext is released
safely after compilation.
'''
reset_tmp_dir()
import gc
old_gc_state = gc.isenabled()
# disable GC to manage resources manually in the manner described in comment above
Expand All @@ -120,7 +107,7 @@ def test_compile_in_forked_subproc_with_forced_gc() -> None:
compile_empty_kernel_with_gc(config)

# stage 2.p
reset_tmp_dir()
shutil.rmtree(fresh_triton_cache)
assert multiprocessing.get_start_method() == 'fork'
proc = multiprocessing.Process(target=compile_empty_kernel_with_gc, args=(config, ))

Expand All @@ -133,5 +120,3 @@ def test_compile_in_forked_subproc_with_forced_gc() -> None:
if old_gc_state:
gc.enable()
assert proc.exitcode == 0

reset_tmp_dir()
3 changes: 3 additions & 0 deletions scripts/skiplist/a770/subprocess.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
test/unit/language/test_subprocess.py::test_print[device_print-float16]
test/unit/language/test_subprocess.py::test_print[device_print-float32]
test/unit/language/test_subprocess.py::test_print[device_print-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32]
3 changes: 3 additions & 0 deletions scripts/skiplist/conda-basekit/subprocess.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
test/unit/language/test_subprocess.py::test_print[device_print-float16]
test/unit/language/test_subprocess.py::test_print[device_print-float32]
test/unit/language/test_subprocess.py::test_print[device_print-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32]
3 changes: 3 additions & 0 deletions scripts/skiplist/conda/subprocess.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
test/unit/language/test_subprocess.py::test_print[device_print-float16]
test/unit/language/test_subprocess.py::test_print[device_print-float32]
test/unit/language/test_subprocess.py::test_print[device_print-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32]
3 changes: 3 additions & 0 deletions scripts/skiplist/default/subprocess.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
test/unit/language/test_subprocess.py::test_print[device_print-float16]
test/unit/language/test_subprocess.py::test_print[device_print-float32]
test/unit/language/test_subprocess.py::test_print[device_print-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32]
3 changes: 3 additions & 0 deletions scripts/skiplist/lts/subprocess.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
test/unit/language/test_subprocess.py::test_print[device_print-float16]
test/unit/language/test_subprocess.py::test_print[device_print-float32]
test/unit/language/test_subprocess.py::test_print[device_print-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32]
Loading

0 comments on commit 4ced588

Please sign in to comment.