diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index dcfa347d..17250050 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -30,7 +30,7 @@ gemm:
   - triton_tma_persistent_matmul
   - triton_tma_persistent_cached_matmul
   - hstu_triton_matmul
-  - colfax_gemm
+  - colfax_cutlass_matmul
 jagged_layer_norm:
 jagged_mean:
 jagged_softmax:
diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index 835b3a44..3293f31a 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -2,6 +2,7 @@
 import csv
 import os
 from typing import Any, Callable, Generator, List, Optional, Tuple
+
 import torch
 import torch._inductor.config as inductor_config
 import triton
@@ -39,7 +40,10 @@
 )
 
 if IS_FBCODE:
-    from hammer.ops.triton.triton_matmul import triton_matmul as hstu_triton_matmul_kernel
+    from hammer.ops.triton.triton_matmul import (
+        triton_matmul as hstu_triton_matmul_kernel,
+    )
+
     HAS_HAMMER = True
 else:
     HAS_HAMMER = False