diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index 6c7d40c..92b9925 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -719,13 +719,13 @@ def run( for _dryrun_input_id in range(self._input_id): self.example_inputs = self.get_example_inputs() for input_id in input_id_range: + self.example_inputs = self.get_example_inputs() x_val = self.get_x_val(self.example_inputs) if "proton" in self.required_metrics: proton.activate(self._proton_session_id) proton.enter_scope(f"x_val_{x_val}") proton.deactivate(self._proton_session_id) self._cur_input_id = input_id - self.example_inputs = self.get_example_inputs() if self.example_inputs is None: logger.warn( f"The input generator get_input_iter() has depleted at id {input_id}. Available number of "