Commit

Merge remote-tracking branch 'origin/llvm-target' into whitneywhtsang/splitsimdblock
whitneywhtsang committed Sep 12, 2024
2 parents 0cf05a9 + d7fd027 commit 71afe7a
Showing 63 changed files with 1,822 additions and 160 deletions.
@@ -1,13 +1,13 @@
name: Build and test A770
name: Build and test GPU
run-name: ${{ inputs.run_name }}

on:
workflow_dispatch:
inputs:
runner_label:
description: Runner label, keep empty for default
description: Runner label or GPU
type: string
default: ""
required: true
pytorch_ref:
description: PyTorch ref, keep empty for default
type: string
@@ -23,11 +23,11 @@ on:
skip_list:
description: Skip list
type: string
default: ""
required: true
run_name:
description: Custom run name
type: string
default: "Build and test A770"
default: ""
enable_unskip:
description: Ignore pytest.skip
type: boolean
@@ -43,12 +43,12 @@ jobs:
python: ["3.9"]
uses: ./.github/workflows/build-test-reusable.yml
with:
device: a770
device: ${{ inputs.runner_label }}
runner_label: ${{ inputs.runner_label }}
pytorch_ref: ${{ inputs.pytorch_ref }}
python_version: ${{ matrix.python }}
upload_test_reports: ${{ inputs.upload_test_reports }}
ignore_errors: ${{ inputs.ignore_errors }}
skip_list: ${{ inputs.skip_list || 'a770' }}
run_name: ${{ inputs.run_name }}
skip_list: ${{ inputs.skip_list }}
run_name: ${{ inputs.run_name || format('Build and test {0}', inputs.runner_label) }}
enable_unskip: ${{ inputs.enable_unskip }}
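
The workflow is no longer hard-wired to the A770: the device, skip list, and default run name are all derived from the runner_label input. As a minimal illustration (not part of the commit), the run_name fallback expression behaves like the following Python, assuming GitHub Actions' || returns the first truthy operand:

def effective_run_name(run_name: str, runner_label: str) -> str:
    # Mirrors ${{ inputs.run_name || format('Build and test {0}', inputs.runner_label) }}:
    # an empty run_name falls back to a name derived from the runner label.
    return run_name or f"Build and test {runner_label}"

assert effective_run_name("", "a770") == "Build and test a770"
assert effective_run_name("my custom run", "a770") == "my custom run"
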
2 changes: 2 additions & 0 deletions .github/workflows/build-test.yml
@@ -36,9 +36,11 @@ on:
pull_request:
branches:
- llvm-target
- main
push:
branches:
- llvm-target
- main

permissions: read-all

3 changes: 2 additions & 1 deletion .github/workflows/triton-benchmarks.yml
@@ -22,6 +22,7 @@ permissions: read-all

env:
PYTHON_VERSION: "3.10"
USE_IPEX: ${{ github.event_name == 'schedule' && '1' || inputs.install_ipex && '1' || '0' }}

jobs:
build:
@@ -51,7 +52,7 @@ jobs:

- name: Install Python build dependencies
run: |
pip install wheel
pip install wheel cmake
- name: Setup PyTorch with IPEX
if: ${{ github.event_name == 'schedule' || inputs.install_ipex }}
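
The new USE_IPEX environment variable is computed from a chained &&/|| expression. A small Python sketch (illustrative only, not code from the commit) of the same decision, assuming GitHub Actions operators short-circuit and return operand values:

def use_ipex(event_name: str, install_ipex: bool) -> str:
    # Mirrors ${{ github.event_name == 'schedule' && '1' || inputs.install_ipex && '1' || '0' }}:
    # scheduled runs always enable IPEX, manual runs follow the install_ipex input.
    return "1" if event_name == "schedule" or install_ipex else "0"

assert use_ipex("schedule", False) == "1"
assert use_ipex("workflow_dispatch", True) == "1"
assert use_ipex("workflow_dispatch", False) == "0"
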
8 changes: 7 additions & 1 deletion benchmarks/CMakeLists.txt
@@ -4,14 +4,20 @@ set(CMAKE_CXX_STANDARD 20)

project(TritonBenchmark)

option(USE_IPEX "Use IPEX" ON)

if(NOT WIN32)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
endif()

find_package(Python3 COMPONENTS Interpreter)
find_package(Torch REQUIRED)
find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib")
find_package(IPEX REQUIRED)

if(USE_IPEX)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_IPEX")
find_package(IPEX REQUIRED)
endif()


# add the XeTLA kernel.
@@ -1,9 +1,12 @@
import torch
import intel_extension_for_pytorch # type: ignore # noqa: F401
import os

import torch
import triton
import triton.language as tl

if os.getenv('USE_IPEX', '1') == '1':
import intel_extension_for_pytorch # type: ignore # noqa: F401


@triton.jit
def float_trunc_kernel(
10 changes: 8 additions & 2 deletions benchmarks/setup.py
@@ -8,7 +8,12 @@
from setuptools import setup

import torch
import intel_extension_for_pytorch

ipex_cmake_prefix_path = ""
USE_IPEX_OPTION = os.getenv("USE_IPEX", "1")
if USE_IPEX_OPTION == "1":
import intel_extension_for_pytorch
ipex_cmake_prefix_path = f";{intel_extension_for_pytorch.cmake_prefix_path}"


class CMakeBuild():
@@ -43,7 +48,8 @@ def build_extension(self):
"Ninja", # Ninja is much faster than make
"-DCMAKE_MAKE_PROGRAM=" +
ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path
f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path};{intel_extension_for_pytorch.cmake_prefix_path}",
f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}{ipex_cmake_prefix_path}",
f"-DUSE_IPEX={USE_IPEX_OPTION}",
"-DCMAKE_EXPORT_COMPILE_COMMANDS=ON",
"-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY=" + self.extdir,
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + self.extdir,
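
setup.py now appends the IPEX cmake_prefix_path only when USE_IPEX is "1" and forwards the option to CMake. A condensed sketch of that logic (the function name is illustrative, not from the commit):

import os

def cmake_ipex_args(torch_prefix: str, ipex_prefix: str) -> list:
    # Reads the same environment variable as benchmarks/setup.py; IPEX defaults to on.
    use_ipex = os.getenv("USE_IPEX", "1")
    ipex_part = f";{ipex_prefix}" if use_ipex == "1" else ""
    return [
        f"-DCMAKE_PREFIX_PATH={torch_prefix}{ipex_part}",
        f"-DUSE_IPEX={use_ipex}",
    ]
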
11 changes: 6 additions & 5 deletions benchmarks/triton_kernels_benchmark/__init__.py
@@ -1,6 +1,7 @@
from triton.runtime import driver
from . import benchmark_driver
from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark # type: ignore # noqa: F401
from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, USE_IPEX_OPTION # type: ignore # noqa: F401

# replace the launcher with the profilier hook.
driver.active.launcher_cls = benchmark_driver.XPULauncher
if USE_IPEX_OPTION:
from triton.runtime import driver
from . import benchmark_driver
# replace the launcher with the profilier hook.
driver.active.launcher_cls = benchmark_driver.XPULauncher
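
Both __init__.py and benchmark_testing.py read USE_IPEX at import time, so the variable has to be set before the package is imported. A short usage sketch under that assumption:

import os
os.environ["USE_IPEX"] = "0"  # must be set before importing the benchmark package

import triton_kernels_benchmark as benchmark_suit

# With USE_IPEX=0 the IPEX profiling launcher is not installed and
# do_bench resolves to the plain Triton-based implementation.
print(benchmark_suit.USE_IPEX_OPTION)  # False
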
76 changes: 60 additions & 16 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -3,6 +3,8 @@
import os
from typing import Any, Dict, List

USE_IPEX_OPTION = os.getenv("USE_IPEX", "1") == "1"


def synchronize():
import torch
@@ -12,8 +14,26 @@ def synchronize():
torch.xpu.synchronize()


def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
device="xpu", sync_submitting=True):
def _summarize_statistics(times, quantiles, return_mode):
import torch
if quantiles is not None:
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
if times.numel() > 2:
# exclude max and min times
times = torch.sort(times).values[1:-1]
# add coefficient of the variance.
std = torch.std(times)
mean = torch.mean(times)
cv = std / mean
ret.extend([mean.tolist(), cv.tolist()])
if len(ret) == 1:
ret = ret[0]
return ret
return getattr(torch, return_mode)(times).item()


def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
device="xpu", sync_submitting=True):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -31,6 +51,9 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
"""
# TODO: remove this function and switch to `do_bench_no_ipex` after
# `XPUEvent.elapsed_time` stops introducing regressions into the results.

assert return_mode in ["min", "max", "mean", "median"]
import torch
from torch.autograd.profiler import record_function
@@ -96,20 +119,41 @@ def extract_kernels(funcs):
assert len(kernels) == n_repeat, "the profiling number not match"
# Make the time to the milliseconds.
times = torch.tensor([sum([k.duration for k in ks]) * 1e-3 for ks in kernels], dtype=torch.float)
if quantiles is not None:
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
if times.numel() > 2:
# exclude max and min times
times = torch.sort(times).values[1:-1]
# add coefficient of the variance.
std = torch.std(times)
mean = torch.mean(times)
cv = std / mean
ret.extend([mean.tolist(), cv.tolist()])
if len(ret) == 1:
ret = ret[0]
return ret
return getattr(torch, return_mode)(times).item()
return _summarize_statistics(times, quantiles, return_mode)


def do_bench_no_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
device="xpu"):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
:type warmup: int
:param rep: Repetition time (in ms)
:type rep: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
:param quantiles: Performance percentile to return in addition to the median.
:type quantiles: list[float]
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
"""
assert return_mode in ["min", "max", "mean", "median"]
import torch
from triton.testing import do_bench as triton_do_bench

times = triton_do_bench(fn, warmup=warmup, rep=rep, grad_to_none=grad_to_none, fast_flush=fast_flush,
return_mode="all", device_type=device)
times = torch.tensor(times, dtype=torch.float)
return _summarize_statistics(times, quantiles, return_mode)


do_bench = do_bench_no_ipex
if USE_IPEX_OPTION:
do_bench = do_bench_ipex


def assert_close(x, y, atol=None, rtol=None, err_msg=""):
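
Both do_bench variants now delegate aggregation to _summarize_statistics, which appends a trimmed mean and the coefficient of variation after the requested quantiles. A standalone sketch of what a caller gets back, assuming the suite's usual quantiles of [0.5, 0.0, 1.0] (the times below are synthetic):

import torch

times = torch.tensor([1.05, 0.98, 1.02, 1.10, 4.80], dtype=torch.float)  # kernel times in ms
quantiles = [0.5, 0.0, 1.0]  # median, min, max

# Equivalent to _summarize_statistics(times, quantiles, return_mode="mean"):
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
trimmed = torch.sort(times).values[1:-1]       # drop the min and max samples
cv = torch.std(trimmed) / torch.mean(trimmed)  # coefficient of variation
ret.extend([torch.mean(trimmed).tolist(), cv.tolist()])
# ret == [median, min, max, trimmed mean, cv]; the benchmarks unpack it as
#   _, min_ms, max_ms, mean, cv = do_bench(fn, quantiles=quantiles, ...)
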
@@ -1,13 +1,12 @@
import torch
import intel_extension_for_pytorch # type: ignore # noqa: F401

import triton
import triton.language as tl

import triton_kernels_benchmark
import triton_kernels_benchmark as benchmark_suit
import xetla_kernel

benchmark_suit = triton_kernels_benchmark # triton.testing
if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401


# pylint: disable=unused-argument
@@ -226,10 +225,12 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):

elif provider == 'triton':
triton_fn = lambda: forward(q, k, v, causal, sm_scale)
torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
if benchmark_suit.USE_IPEX_OPTION:
# FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)

7 changes: 3 additions & 4 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -8,16 +8,15 @@
"""

import torch
import intel_extension_for_pytorch # type: ignore # noqa: F401

import triton
import triton.language as tl
from triton.runtime import driver

import triton_kernels_benchmark
import triton_kernels_benchmark as benchmark_suit
import xetla_kernel

benchmark_suit = triton_kernels_benchmark # triton.testing
if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401


@torch.jit.script
5 changes: 3 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -8,14 +8,15 @@
"""

import torch
import intel_extension_for_pytorch # type: ignore # noqa: F401

import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401


@triton.autotune(
configs=[
@@ -7,13 +7,14 @@
"""

import torch
import intel_extension_for_pytorch # type: ignore # noqa: F401

import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401


@triton.autotune(
configs=[
@@ -7,14 +7,16 @@
"""

import math
import torch
import intel_extension_for_pytorch # type: ignore # noqa: F401

import torch
import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401

kAlpha = tl.constexpr(math.sqrt(2.0 / math.pi))


@@ -8,13 +8,14 @@
"""

import torch
import intel_extension_for_pytorch # type: ignore # noqa: F401

import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401


@triton.autotune(
configs=[
7 changes: 3 additions & 4 deletions benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -1,12 +1,11 @@
import torch
import intel_extension_for_pytorch # type: ignore # noqa: F401

import triton
import triton.language as tl

import triton_kernels_benchmark
import triton_kernels_benchmark as benchmark_suit

benchmark_suit = triton_kernels_benchmark # triton.testing
if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401


@triton.autotune(