Skip to content

Commit

Permalink
Merge branch 'main' into amyachev/issue2979
Browse files Browse the repository at this point in the history
  • Loading branch information
anmyachev authored Dec 19, 2024
2 parents eda8285 + 13725c1 commit 5b7374b
Show file tree
Hide file tree
Showing 53 changed files with 1,035 additions and 815 deletions.
3 changes: 2 additions & 1 deletion .github/actions/setup-pytorch/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ runs:
- name: Generate PyTorch cache key
shell: bash
run: |
ONEAPI_LINK=$(readlink /opt/intel/oneapi || true)
ONEAPI_KEY=$(sha256sum /opt/intel/installed.txt 2> /dev/null | cut -d\ -f1 || true)
PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} ${{ inputs.mode }}$ONEAPI_KEY | sha256sum - | cut -d\ -f1)
PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} ${{ inputs.mode }}${ONEAPI_KEY}${ONEAPI_LINK} | sha256sum - | cut -d\ -f1)
echo "PYTORCH_CACHE_KEY=$PYTORCH_CACHE_KEY" | tee -a "$GITHUB_ENV"
- name: Load PyTorch from a cache
Expand Down
21 changes: 18 additions & 3 deletions .github/workflows/triton-benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ on:
- PYTORCH_LEGACY_PROFILER_USING_IPEX
- ELAPSED_TIME
- UPSTREAM_PYTORCH_PROFILER
default: PYTORCH_LEGACY_PROFILER_USING_IPEX
default: UPSTREAM_PYTORCH_PROFILER
run_name:
description: Run name
type: string
Expand All @@ -32,6 +32,13 @@ on:
description: Use Python built with pyenv
type: boolean
default: false
oneapi_bundle:
description: oneAPI bundle
type: choice
options:
- PTDB
- DLE
default: DLE

schedule:
- cron: "5 23 * * *"
Expand All @@ -46,8 +53,8 @@ permissions: read-all

env:
PYTHON_VERSION: "3.10"
BENCHMARKING_METHOD: ${{ inputs.benchmarking_method || 'PYTORCH_LEGACY_PROFILER_USING_IPEX' }}
USE_IPEX: ${{ github.event_name != 'workflow_dispatch' && '1' || inputs.benchmarking_method == 'PYTORCH_LEGACY_PROFILER_USING_IPEX' && '1' || '0' }}
BENCHMARKING_METHOD: ${{ inputs.benchmarking_method || 'UPSTREAM_PYTORCH_PROFILER' }}
USE_IPEX: ${{ (inputs.benchmarking_method || 'UPSTREAM_PYTORCH_PROFILER') == 'PYTORCH_LEGACY_PROFILER_USING_IPEX' && '1' || '0' }}
TAG: ${{ inputs.tag || (github.event_name == 'pull_request' && format('pr-{0}', github.event.number)) || (github.event_name == 'schedule' && 'ci') || 'test' }}

jobs:
Expand All @@ -66,6 +73,14 @@ jobs:
${{ toJSON(inputs) }}
EOF
- name: Use DLE
if: ${{ (github.oneapi_bundle || 'DLE') == 'DLE' }}
shell: bash
run: |
if [[ -e /opt/intel/dle ]]; then
sudo ln -sfT /opt/intel/dle /opt/intel/oneapi
fi
- name: Checkout repository
uses: actions/checkout@v4

Expand Down
32 changes: 16 additions & 16 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _summarize_statistics(times, quantiles, return_mode):


def do_bench_ipex(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
sync_submitting=True, kernel_name=None): # pylint: disable=unused-argument
sync_submitting=True):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
Expand Down Expand Up @@ -108,7 +108,7 @@ def extract_kernels(funcs):


def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean",
device="xpu", kernel_name=None): # pylint: disable=unused-argument
device="xpu"):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
Expand Down Expand Up @@ -159,7 +159,7 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan


def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None,
return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None):
return_mode="mean", device="xpu", sync_submitting=True):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
Expand All @@ -178,7 +178,7 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no

assert return_mode in ["min", "max", "mean", "median"]
import torch
from torch.profiler import profile, ProfilerActivity
from torch.profiler import profile, ProfilerActivity, record_function

fn()
synchronize()
Expand Down Expand Up @@ -206,24 +206,24 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no
if sync_submitting:
synchronize()
# record time of `fn`
fn()
with record_function("__profile_kernel_of_func"):
fn()
# Record clocks
synchronize()

function_events = prof.events()
profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), prof.events())
functions = list(profiling_func_filter)

all_functions = []
if isinstance(kernel_name, str):
kernel_name = [kernel_name]
for ker_name in kernel_name:
functions = list(filter(lambda x: x.name.startswith(ker_name), function_events)) # pylint: disable=cell-var-from-loop
assert len(functions) == n_repeat, f"the profiling number for kernel: '{ker_name}' not match, {len(functions)}"
all_functions.append(functions)
# profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)
def extract_kernels(funcs):
kernels = []
kernels += list(itertools.chain.from_iterable(map(lambda func: extract_kernels(func.cpu_children), funcs)))
kernels += list(itertools.chain.from_iterable([func.kernels for func in funcs]))
return kernels

kernels = [extract_kernels(func.cpu_children) for func in functions]
assert len(kernels) == n_repeat, "the profiling number not match"
# Make the time to the milliseconds.
times = torch.tensor([sum(map(lambda elem: elem.self_device_time_total, f)) * 1e-3 for f in zip(*all_functions)],
dtype=torch.float)
times = torch.tensor([sum([k.duration for k in ks]) * 1e-3 for ks in kernels], dtype=torch.float)
return _summarize_statistics(times, quantiles, return_mode)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
kernel_name='_attn_fwd')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)

elif provider == 'xetla':
module_name = f'flash_attn_causal_{CAUSAL}'.lower()
Expand All @@ -281,8 +280,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
kernel_name='gpu::xetla::fmha::FmhaForwardKernel<')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)

else:
raise NotImplementedError(f'Unsupported provider {provider}')
Expand Down
15 changes: 2 additions & 13 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ def benchmark(M, N, provider):
triton_fn = lambda: softmax(x, out)
torch_fn = lambda: torch.softmax(x, axis=-1)
benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10,
kernel_name="softmax_kernel")
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10)

elif provider == "torch-jit":
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles,
Expand All @@ -145,17 +144,7 @@ def benchmark(M, N, provider):
xetla_fn = lambda: func(x, out, 0)
torch_fn = lambda: torch.softmax(x, axis=-1)
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
kernels_name = {
"softmax_shape_4096_256": "mat1_4096x256_bf16_cfg0",
"softmax_shape_4096_1024": "mat1_4096x1024_bf16_cfg0",
"softmax_shape_4096_2048": "mat1_4096x2048_bf16_cfg0",
"softmax_shape_4096_4096": "mat1_4096x4096_bf16_cfg0",
"softmax_shape_4096_8192": "mat1_4096x8k_bf16_cfg0",
"softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0",
"softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0",
}
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10,
kernel_name=kernels_name[name])
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10)

else:
raise NotImplementedError(f"Unsupported provider {provider}")
Expand Down
35 changes: 3 additions & 32 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def benchmark(B, M, N, K, provider):
# Legacy profiler shows ~6000TFLOPS GeoMean for onednn measurements, so use more reliable method
do_bench = do_bench_elapsed_time
_, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name='gemm_kernel')
quantiles=quantiles)
elif provider == 'triton':
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
Expand All @@ -301,8 +301,7 @@ def benchmark(B, M, N, K, provider):
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles,
kernel_name='matmul_kernel_with_block_pointers')
quantiles=quantiles)
elif provider == 'xetla':
if B == 1:
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
Expand All @@ -329,37 +328,9 @@ def xetla_func_with_acc_allocation():
xetla_fn = xetla_func_with_acc_allocation
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

kernels_name = {
'gemm_shape_1_1024_1024_1024': 'Test_1x1024x1024x1024_row_row',
'gemm_shape_1_2048_2048_2048': 'Test_1x2048x2048x2048_row_row',
'gemm_shape_1_4096_4096_4096': 'Test_1x4096x4096x4096_row_row',
'gemm_shape_1_8192_8192_8192': 'Test_1x8192x8192x8192_row_row',
'gemm_shape_1_1_5120_13824': 'Test_1x1x5120x13824_row_row',
'gemm_shape_1_4_4096_12288': 'Test_1x4x4096x12288_row_row',
'gemm_shape_1_512_8192_8192': 'Test_1x512x8192x8192_row_row',
'gemm_shape_1_512_8192_32768': 'Test_1x512x8192x32768_row_row',
'gemm_shape_1_512_32768_8192': 'Test_1x512x32768x8192_row_row',
'gemm_shape_1_1024_16384_8192': 'Test_1x1024x16384x8192_row_row',
'gemm_shape_1_1024_28672_8192': 'Test_1x1024x28672x8192_row_row',
'gemm_shape_1_3072_4096_3072': 'Test_1x3072x4096x3072_row_row',
'gemm_shape_1_4096_16384_8192': 'Test_1x4096x16384x8192_row_row',
'gemm_shape_1_8192_16384_1024': 'Test_1x8192x16384x1024_row_row',
'gemm_shape_1_8192_16384_4096': 'Test_1x8192x16384x4096_row_row',
'gemm_shape_1_16384_1024_8192': 'Test_1x16384x1024x8192_row_row',
'gemm_shape_1_16384_4096_8192': 'Test_1x16384x4096x8192_row_row',
'gemm_shape_1_16384_8192_1024': 'Test_1x16384x8192x1024_row_row',
'gemm_shape_1_16384_8192_4096': 'Test_1x16384x8192x4096_row_row',
'gemm_shape_4_32768_128_4096': 'Test_4x32768x128x4096_row_row',
'gemm_shape_4_32768_4096_128': 'Test_4x32768x4096x128_row_row',
'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row',
'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row',
'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row',
'gemm_streamk_shape_3072_4096_3072': 'stream_k_gemm_run',
}

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name=kernels_name[name])
quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,15 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, d, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name=kernel_name)
quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,17 +268,15 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name=kernel_name)
quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,17 +256,15 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name=kernel_name)
quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def benchmark(M, N, K, provider):
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name='_kernel')
quantiles=quantiles)
elif provider == 'xetla':
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
Expand All @@ -172,7 +172,7 @@ def benchmark(M, N, K, provider):

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name='split_k_gemm_run')
quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Expand Down
5 changes: 2 additions & 3 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,7 @@ def benchmark(M, N, K, provider):
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles,
kernel_name=['first_wave', 'full_tiles'])
quantiles=quantiles)
elif provider == 'xetla':
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
Expand All @@ -294,7 +293,7 @@ def benchmark(M, N, K, provider):

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name='stream_k_gemm_run')
quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Expand Down
3 changes: 1 addition & 2 deletions benchmarks/triton_kernels_benchmark/prefix_sums.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def benchmark(M, N, AXIS, provider):

if provider == 'triton':
triton_fn = lambda: scan_kernel[(1, )](x, BLOCK_SIZE_M=M, BLOCK_SIZE_N=N, AXIS=AXIS)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles,
kernel_name='scan_kernel')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Expand Down
Loading

0 comments on commit 5b7374b

Please sign in to comment.