diff --git a/tritonbench/utils/parser.py b/tritonbench/utils/parser.py index 22cfcb6c..29723e60 100644 --- a/tritonbench/utils/parser.py +++ b/tritonbench/utils/parser.py @@ -1,5 +1,5 @@ import argparse -from typing import List +from typing import List, Optional from tritonbench.utils.env_utils import AVAILABLE_PRECISIONS from tritonbench.utils.triton_op import DEFAULT_RUN_ITERS, DEFAULT_WARMUP, IS_FBCODE @@ -182,14 +182,17 @@ def _find_param_loc(params, key: str) -> int: def _remove_params(params, loc): if loc == -1: return params + if loc == len(params) - 1: + return params[:loc] if params[loc + 1].startswith("--"): return params[:loc] + params[loc + 1 :] return params[:loc] + params[loc + 2 :] -def add_cmd_parameter(args: List[str], name: str, value: str) -> List[str]: +def add_cmd_parameter(args: List[str], name: str, value: Optional[str]=None) -> List[str]: args.append(name) - args.append(value) + if value: + args.append(value) return args