Remove double export and session init in perf test (#114907)
Summary:
Previously, both the `optimize_ctx` call and the `experiment` call performed export and session creation, doubling the resource cost. This PR makes the `experiment` call reuse the ONNX model created by `optimize_ctx`.
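As a minimal sketch of the caching pattern this change adopts (illustrative only; `export_fn` and the `.run()` call stand in for the real export and ORT invocation):

    import dataclasses
    from typing import Any, Callable, Optional

    @dataclasses.dataclass
    class OnnxContext:
        onnx_model: Optional[Any] = None

    def optimize_onnx_ctx(export_fn: Callable) -> Callable:
        # One context per returned callable; the first call populates it.
        context = OnnxContext()

        def run_n_iterations_onnx(model, inputs):
            if context.onnx_model is None:
                # Export and ORT session init happen exactly once.
                context.onnx_model = export_fn(model, inputs)
            return context.onnx_model.run(inputs)

        # Expose the cache on the returned callable so a later stage
        # (the `experiment` call) can reuse the already-exported model.
        run_n_iterations_onnx.context = context
        return run_n_iterations_onnx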

X-link: pytorch/pytorch#114907
Approved by: https://github.com/thiagocrepaldi
ghstack dependencies: #110178

Reviewed By: atalman

Differential Revision: D51778269

fbshipit-source-id: ea29466a2df8f4801cb9f0a7870d3a8c6eda0a2b
BowenBao authored and facebook-github-bot committed Dec 4, 2023
1 parent 8f7ee82 commit 7de2aed
Showing 1 changed file with 28 additions and 27 deletions.
userbenchmark/dynamo/dynamobench/common.py
@@ -707,6 +707,7 @@ def speedup_experiment_onnx(
     onnx_model_cls: Type[OnnxModelFromTorchScript],
     args,
     model_iter_fn,
+    onnx_model: OnnxModel,
     model,
     example_inputs,
     **kwargs,
@@ -715,9 +716,8 @@ def speedup_experiment_onnx(
     Measure speedups over eager.
 
     This function is responsible for the following:
-        1. Creation of OnnxModel, which handles export, ort initialization.
-        2. Creating iobinding with OnnxModel if device is CUDA, which is essential for perf measurement.
-        3. Running ORT with OnnxModel.
+        1. Creating iobinding with OnnxModel if device is CUDA, which is essential for perf measurement.
+        2. Running ORT with OnnxModel.
 
     Writes to ./{output_filename}, which should be
         `pathlib.Path(self.output_dir) / f"{self.compiler}_{suite}_{self.dtype}_{self.mode}_{self.device}_{self.testing}.csv".
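Why iobinding matters for the measurement: it pins inputs and outputs to the device so the timed region covers inference only, not host/device copies. A minimal onnxruntime sketch (the model path and tensor names are assumptions, not this file's code):

    import numpy as np
    import onnxruntime as ort
    import torch

    session = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
    binding = session.io_binding()

    x = torch.randn(1, 3, 224, 224, device="cuda").contiguous()
    # Bind the existing GPU buffer; no copy happens inside the timed loop.
    binding.bind_input(
        name="input",
        device_type="cuda",
        device_id=0,
        element_type=np.float32,
        shape=tuple(x.shape),
        buffer_ptr=x.data_ptr(),
    )
    # Let ORT allocate the output on the same device.
    binding.bind_output("output", device_type="cuda", device_id=0)

    session.run_with_iobinding(binding)  # the part worth timing
    outputs = binding.copy_outputs_to_cpu()  # done outside the measurement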
@@ -729,16 +729,7 @@ def speedup_experiment_onnx(
     should_randomize_input = args.randomize_input
     times = args.iterations_per_run
 
-    onnx_model = onnx_model_cls(
-        args.output_directory or ".",
-        model,
-        copy.deepcopy(example_inputs),
-        dynamic_shapes=args.dynamic_shapes,
-    )
-
-    def create_onnx_input_binded_fn(
-        onnx_model: OnnxModelFromTorchScript, pt_inputs, example_outputs
-    ):
+    def create_onnx_input_binded_fn(onnx_model: OnnxModel, pt_inputs, example_outputs):
         # Goal is to move the iobinding creation outside of the timer function.
         iobinding, outputs = onnx_model.create_iobinding(pt_inputs, example_outputs)
 
@@ -749,7 +740,7 @@ def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):
 
         return onnxrt_model_iter_fn
 
-    def create_onnx_fn(onnx_model: OnnxModelFromTorchScript, pt_inputs):
+    def create_onnx_fn(onnx_model: OnnxModel, pt_inputs):
         # NOTE: Making perf comparison fair by moving out the i/o adapting part.
         # 1. Pre-adapt `pt_inputs` to `onnx_inputs` here.
         # 2. Drop `onnx_outputs` to `pt_outputs` adapting. Output comparison is not part of perf measurement.
@@ -760,7 +751,7 @@ def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):
 
         return onnxrt_model_iter_fn
 
-    def timed_onnx(model, onnx_model: OnnxModelFromTorchScript, inputs):
+    def timed_onnx(model, onnx_model: OnnxModel, inputs):
         if current_device == "cpu" or onnx_model.is_cpu():
             onnxrt_model_iter_fn = create_onnx_fn(onnx_model, inputs)
         else:
@@ -1526,9 +1517,14 @@ def parse_exception(self, exception: Exception) -> OnnxExportErrorRow:
         )
 
 
+@dataclasses.dataclass
+class OnnxContext:
+    onnx_model: Optional[OnnxModel] = None
+
+
 def optimize_onnx_ctx(
     output_directory: str,
-    onnx_model_cls: Type[OnnxModelFromTorchScript],
+    onnx_model_cls: Type[OnnxModel],
     run_n_iterations: Callable,
     dynamic_shapes: bool = False,
 ) -> Callable:
@@ -1537,7 +1533,8 @@ def optimize_onnx_ctx(
     # 1. Export and cache model.
     # 2. Create iobinding for ORT.
     # 3. Run ORT for n iterations.
-    onnx_model: Optional[OnnxModelFromTorchScript] = None
+    # The cached model is stored in 'context' under the returned callable.
+    context = OnnxContext()
     test_data_dumped = False
 
     def run_n_iterations_onnx(model, inputs, n=2):
@@ -1553,14 +1550,15 @@ def run_n_iterations_onnx(model, inputs, n=2):
         output_error_filename = output_filename[:-4] + "_export_error.csv"
         parser = OnnxExportErrorParser(current_device, current_name, current_batch_size)
         try:
-            nonlocal onnx_model
-            if onnx_model is None:
-                onnx_model = onnx_model_cls(
+            nonlocal context
+            if context.onnx_model is None:
+                context.onnx_model = onnx_model_cls(
                     output_directory,
                     model,
                     copy.deepcopy(inputs),
                     dynamic_shapes=dynamic_shapes,
                 )
+            onnx_model = context.onnx_model
 
             for _ in range(n):
                 nonlocal test_data_dumped
@@ -1598,6 +1596,8 @@ def run_n_iterations_onnx(model, inputs, n=2):
             output_csv(output_error_filename, parsed_error.headers, parsed_error.row)
             raise
 
+    run_n_iterations_onnx.context = context
+
     return run_n_iterations_onnx


@@ -2473,6 +2473,11 @@ def warmup(fn, model, example_inputs, mode, niters=5):
                 f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s"
             )
 
+            if experiment.func is speedup_experiment_onnx:
+                experiment = functools.partial(
+                    experiment, optimized_model_iter_fn.context.onnx_model
+                )
+
             if not hasattr(model, name):
                 model.name = name
             results.append(experiment(model, example_inputs, **experiment_kwargs))
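The `.func` check above works because `experiment` has already been wrapped in `functools.partial` by this point, and CPython flattens nested partials: `.func` keeps pointing at the underlying function while each rebinding appends to the bound arguments. A small sketch with illustrative names:

    import functools

    def speedup_fn(args, model_iter_fn, onnx_model, model, example_inputs):
        return onnx_model, model, example_inputs

    experiment = functools.partial(speedup_fn, "args", "iter_fn")
    assert experiment.func is speedup_fn  # identity survives the first wrap

    # Nested partials flatten in CPython: .func stays speedup_fn and the
    # cached ONNX model is appended to the already-bound arguments.
    experiment = functools.partial(experiment, "cached_onnx_model")
    assert experiment.func is speedup_fn
    assert experiment.args == ("args", "iter_fn", "cached_onnx_model")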
@@ -3356,9 +3361,7 @@ def run(runner, args, original_dir=None):
         optimize_ctx = functools.partial(
             optimize_onnx_ctx, args.output_directory or ".", OnnxModelFromTorchScript
         )
-        experiment = functools.partial(
-            speedup_experiment_onnx, OnnxModelFromTorchScript
-        )
+        experiment = speedup_experiment_onnx
         output_filename = "torchscript_onnx.csv"
         current_onnx_compiler = "torchscript"
     elif args.dynamo_onnx:
@@ -3368,7 +3371,7 @@ def run(runner, args, original_dir=None):
             OnnxModelFromDynamo,
             dynamic_shapes=args.dynamic_shapes,
         )
-        experiment = functools.partial(speedup_experiment_onnx, OnnxModelFromDynamo)
+        experiment = speedup_experiment_onnx
         output_filename = "dynamo_onnx.csv"
         current_onnx_compiler = "dynamo"
     elif args.dynamo_onnx_aot_inline:
@@ -3378,9 +3381,7 @@ def run(runner, args, original_dir=None):
             OnnxModelFromDynamoAotInline,
             dynamic_shapes=args.dynamic_shapes,
         )
-        experiment = functools.partial(
-            speedup_experiment_onnx, OnnxModelFromDynamoAotInline
-        )
+        experiment = speedup_experiment_onnx
         output_filename = "dynamo_onnx_aot_inline.csv"
         current_onnx_compiler = "dynamo"
     elif args.speedup_dynamo_ts:
