diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py
index c19578a3a7..96d356b8b3 100644
--- a/userbenchmark/dynamo/dynamobench/common.py
+++ b/userbenchmark/dynamo/dynamobench/common.py
@@ -707,6 +707,7 @@ def speedup_experiment_onnx(
     onnx_model_cls: Type[OnnxModelFromTorchScript],
     args,
     model_iter_fn,
+    onnx_model: OnnxModel,
     model,
     example_inputs,
     **kwargs,
@@ -715,9 +716,8 @@ def speedup_experiment_onnx(
     Measure speedups over eager.

     This function is responsible for the following:
-        1. Creation of OnnxModel, which handles export, ort initialization.
-        2. Creating iobinding with OnnxModel if device is CUDA, which is essential for perf measurement.
-        3. Running ORT with OnnxModel.
+        1. Creating iobinding with OnnxModel if device is CUDA, which is essential for perf measurement.
+        2. Running ORT with OnnxModel.

     Writes to ./{output_filename}, which should be
         `pathlib.Path(self.output_dir) / f"{self.compiler}_{suite}_{self.dtype}_{self.mode}_{self.device}_{self.testing}.csv".
@@ -729,16 +729,7 @@ def speedup_experiment_onnx(
     should_randomize_input = args.randomize_input
     times = args.iterations_per_run

-    onnx_model = onnx_model_cls(
-        args.output_directory or ".",
-        model,
-        copy.deepcopy(example_inputs),
-        dynamic_shapes=args.dynamic_shapes,
-    )
-
-    def create_onnx_input_binded_fn(
-        onnx_model: OnnxModelFromTorchScript, pt_inputs, example_outputs
-    ):
+    def create_onnx_input_binded_fn(onnx_model: OnnxModel, pt_inputs, example_outputs):
         # Goal is to move the iobinding creation outside of the timer function.
         iobinding, outputs = onnx_model.create_iobinding(pt_inputs, example_outputs)

@@ -749,7 +740,7 @@ def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):

         return onnxrt_model_iter_fn

-    def create_onnx_fn(onnx_model: OnnxModelFromTorchScript, pt_inputs):
+    def create_onnx_fn(onnx_model: OnnxModel, pt_inputs):
         # NOTE: Making perf comparison fair by moving out the i/o adapting part.
         # 1. Pre-adapt `pt_inputs` to `onnx_inputs` here.
         # 2. Drop `onnx_outputs` to `pt_outputs` adapting. Output comparison is not part of perf measurement.
@@ -760,7 +751,7 @@ def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):

         return onnxrt_model_iter_fn

-    def timed_onnx(model, onnx_model: OnnxModelFromTorchScript, inputs):
+    def timed_onnx(model, onnx_model: OnnxModel, inputs):
         if current_device == "cpu" or onnx_model.is_cpu():
             onnxrt_model_iter_fn = create_onnx_fn(onnx_model, inputs)
         else:
@@ -1526,9 +1517,14 @@ def parse_exception(self, exception: Exception) -> OnnxExportErrorRow:
         )


+@dataclasses.dataclass
+class OnnxContext:
+    onnx_model: Optional[OnnxModel] = None
+
+
 def optimize_onnx_ctx(
     output_directory: str,
-    onnx_model_cls: Type[OnnxModelFromTorchScript],
+    onnx_model_cls: Type[OnnxModel],
     run_n_iterations: Callable,
     dynamic_shapes: bool = False,
 ) -> Callable:
@@ -1537,7 +1533,8 @@ def optimize_onnx_ctx(
     # 1. Export and cache model.
     # 2. Create iobinding for ORT.
     # 3. Run ORT for n iterations.
-    onnx_model: Optional[OnnxModelFromTorchScript] = None
+    # The cached model is stored in 'context' under the returned callable.
+    context = OnnxContext()
     test_data_dumped = False

     def run_n_iterations_onnx(model, inputs, n=2):
@@ -1553,14 +1550,15 @@ def run_n_iterations_onnx(model, inputs, n=2):
         output_error_filename = output_filename[:-4] + "_export_error.csv"
         parser = OnnxExportErrorParser(current_device, current_name, current_batch_size)
         try:
-            nonlocal onnx_model
-            if onnx_model is None:
-                onnx_model = onnx_model_cls(
+            nonlocal context
+            if context.onnx_model is None:
+                context.onnx_model = onnx_model_cls(
                     output_directory,
                     model,
                     copy.deepcopy(inputs),
                     dynamic_shapes=dynamic_shapes,
                 )
+            onnx_model = context.onnx_model

             for _ in range(n):
                 nonlocal test_data_dumped
@@ -1598,6 +1596,8 @@ def run_n_iterations_onnx(model, inputs, n=2):
             output_csv(output_error_filename, parsed_error.headers, parsed_error.row)
             raise

+    run_n_iterations_onnx.context = context
+
     return run_n_iterations_onnx


@@ -2473,6 +2473,11 @@ def warmup(fn, model, example_inputs, mode, niters=5):
                 f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s"
             )

+        if experiment.func is speedup_experiment_onnx:
+            experiment = functools.partial(
+                experiment, optimized_model_iter_fn.context.onnx_model
+            )
+
         if not hasattr(model, name):
             model.name = name
         results.append(experiment(model, example_inputs, **experiment_kwargs))
@@ -3356,9 +3361,7 @@ def run(runner, args, original_dir=None):
         optimize_ctx = functools.partial(
             optimize_onnx_ctx, args.output_directory or ".", OnnxModelFromTorchScript
         )
-        experiment = functools.partial(
-            speedup_experiment_onnx, OnnxModelFromTorchScript
-        )
+        experiment = speedup_experiment_onnx
         output_filename = "torchscript_onnx.csv"
         current_onnx_compiler = "torchscript"
     elif args.dynamo_onnx:
@@ -3368,7 +3371,7 @@ def run(runner, args, original_dir=None):
             OnnxModelFromDynamo,
             dynamic_shapes=args.dynamic_shapes,
         )
-        experiment = functools.partial(speedup_experiment_onnx, OnnxModelFromDynamo)
+        experiment = speedup_experiment_onnx
         output_filename = "dynamo_onnx.csv"
         current_onnx_compiler = "dynamo"
     elif args.dynamo_onnx_aot_inline:
@@ -3378,9 +3381,7 @@ def run(runner, args, original_dir=None):
             OnnxModelFromDynamoAotInline,
             dynamic_shapes=args.dynamic_shapes,
         )
-        experiment = functools.partial(
-            speedup_experiment_onnx, OnnxModelFromDynamoAotInline
-        )
+        experiment = speedup_experiment_onnx
         output_filename = "dynamo_onnx_aot_inline.csv"
         current_onnx_compiler = "dynamo"
     elif args.speedup_dynamo_ts:
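
The core pattern this patch introduces is worth spelling out: optimize_onnx_ctx caches lazily created state (the exported ONNX model) in a small mutable dataclass, hangs that dataclass off the returned closure as a function attribute, and the call site later reads it back and binds it into the experiment with functools.partial, detecting the ONNX path via the partial's .func. Below is a minimal, self-contained sketch of that pattern, not the benchmark code itself; Context, make_runner, and experiment are hypothetical stand-ins for OnnxContext, optimize_onnx_ctx, and speedup_experiment_onnx.

import dataclasses
import functools
from typing import Any, Callable, Optional


@dataclasses.dataclass
class Context:
    # Hypothetical stand-in for OnnxContext; populated lazily on first call.
    model: Optional[Any] = None


def make_runner(model_factory: Callable[[], Any]) -> Callable:
    # Stand-in for optimize_onnx_ctx: returns a closure with a cache attached.
    context = Context()

    def run(inputs):
        # Build the model only once; subsequent calls reuse the cached one.
        if context.model is None:
            context.model = model_factory()
        return context.model, inputs

    # Expose the cache to callers, mirroring
    # `run_n_iterations_onnx.context = context` in the patch.
    run.context = context
    return run


def experiment(model, inputs):
    # Stand-in for speedup_experiment_onnx, which now receives the model.
    return f"measured {type(model).__name__} on {inputs!r}"


runner = make_runner(object)
runner([1, 2, 3])  # first call populates runner.context.model

# Mirrors the hook at the warmup call site: once the model exists, bind it
# as the next positional argument of the experiment.
bound = functools.partial(experiment, runner.context.model)
assert bound.func is experiment  # partial exposes the wrapped callable via .func
print(bound([1, 2, 3]))

This works because the closure and its caller share the same Context instance; reassigning a bare nonlocal variable, as the old code did, kept the exported model trapped inside the closure where the call site could not reach it.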