From 2133c0a0214339f1e5f26d1135c964348f4df070 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Mon, 18 Nov 2024 18:29:36 -0500 Subject: [PATCH] Fix test modes --- tritonbench/utils/triton_op.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index a1b9b5d5..b70a03dc 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -568,7 +568,6 @@ def __init__( self.use_cuda_graphs = ( self.tb_args.cudagraph if self.tb_args.cudagraph else self.use_cuda_graphs ) - # we accept both "fwd" and "eval" _translate_mode(self.tb_args) if self.tb_args.mode == "fwd": self.mode = Mode.FWD @@ -578,8 +577,8 @@ def __init__( self.mode = Mode.FWD_NO_GRAD else: assert ( - self.tb_args.mode == "bwd" or self.tb_args.bwd - ), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd." + self.tb_args.mode == "bwd" + ), "We only accept test modes: fwd, bwd, fwd_bwd, or fwd_no_grad." self.mode = Mode.BWD self.device = tb_args.device self.required_metrics = (