diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index 9ffa185..c7855cf 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -904,7 +904,7 @@ def run_and_capture_jit(self, *args, **kwargs): return None def kernel_hash(self, fn): - AST = triton.compiler.ASTSource(fn=fn, signature={}, constants={}) + AST = triton.compiler.ASTSource(fn=fn, signature={}) sorted_sig = [v for k, v in sorted(AST.signature.items())] key = f"{AST.attrs.hash()}-{sorted_sig}" hashed = hashlib.sha256(key.encode("utf-8")).hexdigest() @@ -1200,7 +1200,7 @@ def _init_extra_metrics() -> Dict[str, Any]: metrics.best_config = self.best_config(fn) if "all_configs" in self.required_metrics: metrics.all_configs = self.all_configs(fn) - if "kernel_source_hash" in self.required_metrics or "gemm" in self.name: + if "kernel_source_hash" in self.required_metrics: metrics.kernel_source_hash = self.kernel_hash(fn) # run the hidden metric "_compile_time_in_task" # to get the compile time in parent process