From 06e6703aaafe2a7246930b3c1f627a251181bc8f Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Wed, 27 Nov 2024 02:21:30 -0800
Subject: [PATCH] Introducing TensorRT lazy export and caching option with
 trt_compile() (#11266)

* Introducing trt_compile()

Signed-off-by: Boris Fomitchev

* Adding tests for dynamic axes

Signed-off-by: Boris Fomitchev

* Fixed folding constants and style check

Signed-off-by: Boris Fomitchev

---------

Signed-off-by: Boris Fomitchev
---
 nemo/export/__init__.py               |   2 +
 nemo/export/tensorrt_lazy_compiler.py | 714 ++++++++++++++++++++++++++
 tests/export/test_trt_compile.py      | 139 +++++
 3 files changed, 855 insertions(+)
 create mode 100644 nemo/export/tensorrt_lazy_compiler.py
 create mode 100644 tests/export/test_trt_compile.py

diff --git a/nemo/export/__init__.py b/nemo/export/__init__.py
index d9155f923f18..6b1f8c90aa8f 100644
--- a/nemo/export/__init__.py
+++ b/nemo/export/__init__.py
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from nemo.export.tensorrt_lazy_compiler import trt_compile
diff --git a/nemo/export/tensorrt_lazy_compiler.py b/nemo/export/tensorrt_lazy_compiler.py
new file mode 100644
index 000000000000..ab40278efa94
--- /dev/null
+++ b/nemo/export/tensorrt_lazy_compiler.py
@@ -0,0 +1,714 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import inspect
+import os
+import tempfile
+import threading
+from collections import OrderedDict
+from logging import getLogger
+from pathlib import Path
+from types import MethodType
+from typing import Any, Dict, List, Sequence, Tuple, Union
+
+import torch
+
+from nemo.utils.export_utils import add_casts_around_norms, replace_for_export
+from nemo.utils.import_utils import safe_import
+
+polygraphy, polygraphy_imported = safe_import("polygraphy")
+if polygraphy_imported:
+    from polygraphy.backend.common import bytes_from_path
+    from polygraphy.backend.trt import (
+        CreateConfig,
+        Profile,
+        engine_bytes_from_network,
+        engine_from_bytes,
+        network_from_onnx_path,
+    )
+
+trt, trt_imported = safe_import("tensorrt")
+torch_tensorrt, _ = safe_import("torch_tensorrt")
+cudart, _ = safe_import("cuda.cudart")
+
+lock_sm = threading.Lock()
+
+
+def trt_to_torch_dtype_dict():
+    """
+    Map of TRT dtype -> Torch dtype
+    """
+    return {
+        trt.int32: torch.int32,
+        trt.float32: torch.float32,
+        trt.float16: torch.float16,
+        # Note: bfloat16 outputs are allocated as same-width float16 buffers here
+        trt.bfloat16: torch.float16,
+        trt.int64: torch.int64,
+        trt.int8: torch.int8,
+        trt.bool: torch.bool,
+    }
+
+
+def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None):
+    """
+    Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize.
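+    Args:
+        input_shape: shape of a sample input; dim 0 is treated as the batch dimension.
+        dynamic_batchsize: optional [MIN_BATCH, OPT_BATCH, MAX_BATCH] sequence;
+            if None, the static input_shape is used for all three profile shapes.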
+    """
+
+    def scale_batch_size(input_shape: Sequence[int], scale_num: int):
+        scale_shape = [*input_shape]
+        scale_shape[0] = scale_num
+        return scale_shape
+
+    # Use the dynamic batchsize range to generate the min, opt and max model input shapes
+    if dynamic_batchsize:
+        min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
+        opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
+        max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
+    else:
+        min_input_shape = opt_input_shape = max_input_shape = input_shape
+    return min_input_shape, opt_input_shape, max_input_shape
+
+
+def get_dynamic_axes(profiles):
+    """
+    Calculates dynamic_axes to use in onnx.export().
+    Args:
+        profiles: list of profile maps {input_name: [min_shape, opt_shape, max_shape], ...};
+            an axis is considered dynamic if its min and max extents differ.
+    """
+    dynamic_axes: dict[str, list[int]] = {}
+    if not profiles:
+        return dynamic_axes
+    for profile in profiles:
+        for key in profile:
+            axes = []
+            vals = profile[key]
+            for i in range(len(vals[0])):
+                if vals[0][i] != vals[2][i]:
+                    axes.append(i)
+            if len(axes) > 0:
+                dynamic_axes[key] = axes
+    return dynamic_axes
+
+
+def cuassert(cuda_ret):
+    """
+    Error reporting method for CUDA calls.
+    Args:
+        cuda_ret: tuple returned by a cudart call: (error code, optional result).
+    """
+    err = cuda_ret[0]
+    if err != 0:
+        raise RuntimeError(f"CUDA ERROR: {err}")
+    if len(cuda_ret) > 1:
+        return cuda_ret[1]
+    return None
+
+
+class ShapeError(Exception):
+    """
+    Exception class to report errors from setting TRT plan input shapes
+    """
+
+    pass
+
+
+class TRTEngine:
+    """
+    An auxiliary class that loads and runs serialized TRT engines.
+    """
+
+    def __init__(self, plan_path, logger=None):
+        """
+        Loads a serialized engine, creates the execution context and activates it.
+        Args:
+            plan_path: path to the serialized TRT engine.
+            logger: optional logger object
+        """
+        self.plan_path = plan_path
+        self.logger = logger or getLogger("trt_compile")
+        self.logger.info(f"Loading TensorRT engine: {self.plan_path}")
+        self.engine = engine_from_bytes(bytes_from_path(self.plan_path))
+        self.tensors = OrderedDict()
+        self.cuda_graph_instance = None  # cuda graph
+        self.context = self.engine.create_execution_context()
+        self.input_names = []
+        self.output_names = []
+        self.dtypes = []
+        self.cur_profile = 0
+        self.input_table = {}
+        dtype_dict = trt_to_torch_dtype_dict()
+        for idx in range(self.engine.num_io_tensors):
+            binding = self.engine[idx]
+            if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
+                self.input_names.append(binding)
+            elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT:
+                self.output_names.append(binding)
+                dtype = dtype_dict[self.engine.get_tensor_dtype(binding)]
+                self.dtypes.append(dtype)
+        self.logger.info(
+            f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}"
+        )
+
+    def allocate_buffers(self, device):
+        """
+        Allocates output buffers for the TRT engine.
+        Args:
+            device: GPU device to allocate memory on
+        """
+        ctx = self.context
+
+        for i, binding in enumerate(self.output_names):
+            shape = list(ctx.get_tensor_shape(binding))
+            if binding not in self.tensors or list(self.tensors[binding].shape) != shape:
+                t = torch.empty(shape, dtype=self.dtypes[i], device=device).contiguous()
+                self.tensors[binding] = t
+                ctx.set_tensor_address(binding, t.data_ptr())
+
+    def set_inputs(self, feed_dict, stream):
+        """
+        Sets input bindings for the TRT engine according to feed_dict.
+        Args:
+           feed_dict: a dictionary [str->Tensor]
+           stream: CUDA stream to use
+        """
+        e = self.engine
+        ctx = self.context
+
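+        # Try the current optimization profile first; on ShapeError, cycle through
+        # the engine's other profiles until one accepts the given input shapes.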
+        last_profile = self.cur_profile
+
+        def try_set_inputs():
+            for binding in self.input_names:
+                t = feed_dict.get(self.input_table[binding], None)
+                if t is not None:
+                    t = t.contiguous()
+                    shape = t.shape
+                    ctx.set_input_shape(binding, shape)
+                    ctx.set_tensor_address(binding, t.data_ptr())
+
+        while True:
+            try:
+                try_set_inputs()
+                break
+            except ShapeError:
+                next_profile = (self.cur_profile + 1) % e.num_optimization_profiles
+                if next_profile == last_profile:
+                    raise
+                self.cur_profile = next_profile
+                ctx.set_optimization_profile_async(self.cur_profile, stream)
+        left = ctx.infer_shapes()
+        assert len(left) == 0, f"Inputs with unresolved shapes: {left}"
+
+    def infer(self, stream, use_cuda_graph=False):
+        """
+        Runs the TRT engine.
+        Args:
+            stream: CUDA stream to run on
+            use_cuda_graph: use a CUDA graph. Note: requires all inputs to stay at the same
+                GPU memory addresses between calls.
+        """
+        if use_cuda_graph:
+            if self.cuda_graph_instance is not None:
+                cuassert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))
+                cuassert(cudart.cudaStreamSynchronize(stream))
+            else:
+                # do inference before CUDA graph capture
+                noerror = self.context.execute_async_v3(stream)
+                if not noerror:
+                    raise ValueError("ERROR: inference failed.")
+                # capture cuda graph
+                cuassert(
+                    cudart.cudaStreamBeginCapture(
+                        stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
+                    )
+                )
+                self.context.execute_async_v3(stream)
+                graph = cuassert(cudart.cudaStreamEndCapture(stream))
+                self.cuda_graph_instance = cuassert(cudart.cudaGraphInstantiate(graph, 0))
+                self.logger.info("CUDA Graph captured!")
+        else:
+            noerror = self.context.execute_async_v3(stream)
+            cuassert(cudart.cudaStreamSynchronize(stream))
+            if not noerror:
+                raise ValueError("ERROR: inference failed.")
+
+        return self.tensors
+
+
+def make_tensor(d):
+    """
+    Returns d if d is already a tensor; otherwise creates a CUDA tensor from it.
+    """
+    return d if isinstance(d, torch.Tensor) else torch.tensor(d).cuda()
+
+
+def unroll_input(input_names, input_example):
+    """
+    Simulates the list/tuple unrolling that happens during ONNX export.
+    """
+    unrolled_input = {}
+    for name in input_names:
+        val = input_example[name]
+        if val is not None:
+            if isinstance(val, (list, tuple)):
+                for i in range(len(val)):
+                    unrolled_input[f"{name}_{i}"] = make_tensor(val[i])
+            else:
+                unrolled_input[name] = make_tensor(val)
+    return unrolled_input
+
+
+def parse_groups(
+    ret: List[torch.Tensor], output_lists: List[List[int]]
+) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]:
+    """
+    Implements parsing of the 'output_lists' arg of trt_compile().
+
+    Args:
+        ret: plain list of Tensors.
+        output_lists: list of output group sizes used to re-form Lists/Tuples out of the flat 'ret' list,
+            e.g. [[], [5], [-1]] for a single Tensor, a list of 5 items and a dynamic-length list.
+            Format: [[group_n] | [], ...]
+              [] or group_n == 0: next output from ret is a single Tensor
+              group_n > 0:        next output from ret is a list of group_n Tensors
+              group_n == -1:      next output is a dynamic-length list. This entry can be at any
+                                  position in output_lists, but can appear only once.
+    Returns:
+        Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists.
+    """
+    groups: Tuple[Union[torch.Tensor, List[torch.Tensor]], ...] = tuple()
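+    # Walk output_lists left to right; when the -1 (dynamic) group is reached, parse the
+    # remaining groups right to left and give the dynamic group whatever Tensors are left.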
+    cur = 0
+    for i in range(len(output_lists)):
+        gl = output_lists[i]
+        assert len(gl) == 0 or len(gl) == 1
+        if len(gl) == 0 or gl[0] == 0:
+            groups = (*groups, ret[cur])
+            cur = cur + 1
+        elif gl[0] > 0:
+            groups = (*groups, ret[cur : cur + gl[0]])
+            cur = cur + gl[0]
+        elif gl[0] == -1:
+            rev_groups: Tuple[Union[torch.Tensor, List[torch.Tensor]], ...] = tuple()
+            rcur = len(ret)
+            for rl in range(len(output_lists) - 1, i, -1):
+                rgl = output_lists[rl]
+                assert len(rgl) == 0 or len(rgl) == 1
+                if len(rgl) == 0 or rgl[0] == 0:
+                    rcur = rcur - 1
+                    rev_groups = (*rev_groups, ret[rcur])
+                elif rgl[0] > 0:
+                    rcur = rcur - rgl[0]
+                    rev_groups = (*rev_groups, ret[rcur : rcur + rgl[0]])
+                else:
+                    raise ValueError("Two -1 lists in output")
+            groups = (*groups, ret[cur:rcur], *rev_groups[::-1])
+            break
+    return groups
+
+
+class TrtCompiler:
+    """
+    This class implements:
+      - TRT lazy persistent export
+      - Running TRT with optional fallback to Torch
+        (for TRT engines with limited profiles)
+    """
+
+    def __init__(
+        self,
+        model,
+        plan_path,
+        precision="fp16",
+        method="onnx",
+        input_names=None,
+        output_names=None,
+        output_lists=None,
+        export_args=None,
+        build_args=None,
+        input_profiles=None,
+        dynamic_batchsize=None,
+        use_cuda_graph=False,
+        timestamp=None,
+        fallback=False,
+        forward_override=None,
+        logger=None,
+    ):
+        """
+        Initialization method:
+         Tries to load a persistent serialized TRT engine
+         Saves its arguments for lazy TRT build on the first forward() call
+        Args:
+            model: Model to "wrap".
+            plan_path: Path where the persistent serialized TRT engine is saved.
+            precision: TRT builder precision of the engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'.
+            method: One of 'onnx'|'torch_trt'.
+                    Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option.
+                    'torch_trt' may not work for some nets. Also, AMP must be turned off for it to work.
+            input_names: Optional list of input names. If None, will be read from the function signature.
+            output_names: Optional list of output names. Note: if not None, the patched forward() will return a dictionary.
+            output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list
+                          of their dimensions, like [[], [5], [-1]] for a Tensor, a list of 5 items and a dynamic list.
+            export_args: Optional args to pass to the export method. See onnx.export() and Torch-TensorRT docs for details.
+            build_args: Optional args to pass to the TRT builder. See polygraphy.Config for details.
+            input_profiles: Optional list of profiles for the TRT builder and ONNX export.
+                            Each profile is a map of the form: {"input id": [min_shape, opt_shape, max_shape], ...}.
+            dynamic_batchsize: A sequence with three elements to define the input batch size range for the model to be
+                               converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH].
+            [note]: If neither input_profiles nor dynamic_batchsize is specified, static shapes will be used.
+            use_cuda_graph: Use CUDA Graph for inference. Note: inputs have to stay at the same GPU memory addresses between calls!
+            timestamp: Optional timestamp to rebuild the TRT engine (e.g. if the config file changes).
+            fallback: Allow falling back to PyTorch when TRT inference fails (e.g., shapes exceed the max profile).
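+            forward_override: Optional replacement for forward(); reserved, not used by the current implementation.
+            logger: Optional logger object.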
+        """
+
+        method_vals = ["onnx", "torch_trt"]
+        if method not in method_vals:
+            raise ValueError(f"trt_compile(): 'method' should be one of {method_vals}, got: {method}.")
+        precision_vals = ["fp32", "tf32", "fp16", "bf16"]
+        if precision not in precision_vals:
+            raise ValueError(f"trt_compile(): 'precision' should be one of {precision_vals}, got: {precision}.")
+
+        self.plan_path = plan_path
+        self.precision = precision
+        self.method = method
+        self.return_dict = output_names is not None
+        self.output_names = output_names or []
+        self.output_lists = output_lists or []
+        self.profiles = input_profiles or []
+        self.dynamic_batchsize = dynamic_batchsize
+        self.export_args = export_args or {}
+        self.build_args = build_args or {}
+        self.engine: TRTEngine | None = None
+        self.use_cuda_graph = use_cuda_graph
+        self.fallback = fallback
+        self.disabled = False
+
+        self.logger = logger or getLogger("trt_compile")
+        self.argspec = inspect.getfullargspec(model.forward)
+        # Normally input_names are read from the forward() signature, but they can be overridden
+        if input_names is None:
+            input_names = self.argspec.args[1:]
+        self.defaults = {}
+        if self.argspec.defaults is not None:
+            for i in range(len(self.argspec.defaults)):
+                d = self.argspec.defaults[-i - 1]
+                if d is not None:
+                    d = make_tensor(d)
+                self.defaults[self.argspec.args[-i - 1]] = d
+
+        self.input_names = input_names
+        self.old_forward = model.forward
+
+        # Force engine rebuild if the plan file is older than the timestamp
+        if timestamp is not None and os.path.exists(self.plan_path) and os.path.getmtime(self.plan_path) < timestamp:
+            os.remove(self.plan_path)
+
+    def _inputs_to_dict(self, input_example):
+        """
+        Maps positional arguments to named inputs using self.input_names.
+        """
+        trt_inputs = {}
+        for i, inp in enumerate(input_example):
+            input_name = self.input_names[i]
+            trt_inputs[input_name] = inp
+        return trt_inputs
+
+    def _load_engine(self):
+        """
+        Loads the TRT plan from disk and activates its execution context.
+        """
+        try:
+            self.engine = TRTEngine(self.plan_path, self.logger)
+            # Map engine tensor names back to the original input names
+            # (a leading '__' is stripped if the name is not among self.input_names)
+            input_table = {}
+            for name in self.engine.input_names:
+                if name.startswith("__") and name not in self.input_names:
+                    orig_name = name[2:]
+                else:
+                    orig_name = name
+                input_table[name] = orig_name
+            self.engine.input_table = input_table
+            self.logger.info(f"Engine loaded, inputs: {self.engine.input_table}")
+        except Exception as e:
+            self.logger.info(f"Exception while loading the engine:\n{e}")
+
+    def forward(self, model, argv, kwargs):
+        """
+        Main forward method:
+        Builds the TRT engine if not available yet.
+        Tries to run the TRT engine.
+        If an exception is thrown and self.fallback==True, falls back to the original PyTorch forward().
+
+        Args: Passes through whatever args the wrapped module's forward() has.
+        Returns: Passes through the wrapped module's forward() return value(s).
+
+        """
+        # copy() so that self.defaults is not mutated by the updates below
+        args = self.defaults.copy()
+        args.update(kwargs)
+        if len(argv) > 0:
+            args.update(self._inputs_to_dict(argv))
+
+        if self.engine is None and not self.disabled:
+            # Restore original forward for export
+            new_forward = model.forward
+            model.forward = self.old_forward
+            try:
+                self._load_engine()
+                if self.engine is None:
+                    build_args = args.copy()
+                    with torch.no_grad():
+                        self._build_and_save(model, build_args)
+                    # This will reassign input_names from the engine
+                    self._load_engine()
+                    assert self.engine is not None
+            except Exception as e:
+                if self.fallback:
+                    self.logger.info(f"Failed to build engine: {e}")
+                    self.disabled = True
+                else:
+                    raise
+            if not self.disabled and not self.fallback:
+                # Drop local references to the parameters; the TRT engine now holds the weights
+                for param in model.parameters():
+                    del param
+                # Call empty_cache to release GPU memory
+                torch.cuda.empty_cache()
+            # restore TRT hook
+            model.forward = new_forward
+        # Run the engine
+        try:
+            if self.engine is not None:
+                # forward_trt is not thread safe as we do not use per-thread execution contexts
+                with lock_sm:
+                    device = torch.cuda.current_device()
+                    stream = torch.cuda.Stream(device=device)
+                    self.engine.set_inputs(unroll_input(self.input_names, args), stream.cuda_stream)
+                    self.engine.allocate_buffers(device=device)
+                    # Need this to synchronize with the Torch stream
+                    stream.wait_stream(torch.cuda.current_stream())
+                    ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph)
+                    # if output_names is not None, return a dictionary
+                    if not self.return_dict:
+                        ret = list(ret.values())
+                        if self.output_lists:
+                            ret = parse_groups(ret, self.output_lists)
+                        elif len(ret) == 1:
+                            ret = ret[0]
+                    return ret
+        except Exception as e:
+            if self.fallback:
+                self.logger.info(f"Exception: {e}\nFalling back to PyTorch ...")
+            else:
+                raise
+        return self.old_forward(*argv, **kwargs)
+
+    def _onnx_to_trt(self, onnx_path):
+        """
+        Builds a TRT engine from the ONNX file at onnx_path and returns the serialized engine bytes.
+        """
+
+        profiles = []
+        for profile in self.profiles:
+            p = Profile()
+            for name, val in profile.items():
+                p.add(name, min=val[0], opt=val[1], max=val[2])
+            profiles.append(p)
+
+        build_args = self.build_args.copy()
+        build_args["tf32"] = self.precision != "fp32"
+        if self.precision == "fp16":
+            build_args["fp16"] = True
+        elif self.precision == "bf16":
+            build_args["bf16"] = True
+
+        self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}")
+        network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
+        return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args))
+
+    def _build_and_save(self, model, input_example):
+        """
+        If the TRT engine is not ready, exports the model to ONNX,
+        builds the TRT engine, and saves the serialized engine to disk.
+        Args:
+            model: model to export and convert.
+            input_example: passed to onnx.export().
+        """
+
+        if self.engine is not None:
+            return
+
+        export_args = self.export_args
+        engine_bytes = None
+
+        add_casts_around_norms(model)
+        replace_for_export(model)
+
+        if self.method == "torch_trt":
+            enabled_precisions = [torch.float32]
+            if self.precision == "fp16":
+                enabled_precisions.append(torch.float16)
+            elif self.precision == "bf16":
+                enabled_precisions.append(torch.bfloat16)
+            inputs = list(input_example.values())
+
+            def get_torch_trt_input(input_shape, dynamic_batchsize):
+                min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)
+                return torch_tensorrt.Input(
+                    min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape
+                )
+
+            tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs]
+            engine_bytes = torch_tensorrt.convert_method_to_trt_engine(
+                model,
+                "forward",
+                arg_inputs=tt_inputs,
+                enabled_precisions=enabled_precisions,
+                **export_args,
+            )
+        else:
+            dbs = self.dynamic_batchsize
+            if dbs:
+                if len(self.profiles) > 0:
+                    raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!")
+                if len(dbs) != 3:
+                    raise ValueError("dynamic_batchsize must be a sequence of length 3")
+                profile = {}
+
+                def add_profile(name, val):
+                    sh = val.shape
+                    if len(sh) > 0:
+                        sh = sh[1:]
+                        profile[name] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]
+
+                for name, val in input_example.items():
+                    if isinstance(val, (list, tuple)):
+                        for i in range(len(val)):
+                            add_profile(f"{name}_{i}", val[i])
+                    elif isinstance(val, torch.Tensor):
+                        add_profile(name, val)
+                self.profiles = [profile]
+
+            self.dynamic_axes = get_dynamic_axes(self.profiles)
+
+            if len(self.dynamic_axes) > 0:
+                export_args.update({"dynamic_axes": self.dynamic_axes})
+
+            # Use a temporary directory for easy cleanup in case of external weights
+            with tempfile.TemporaryDirectory() as tmpdir:
+                if export_args.get("dynamo", False):
+                    input_names = None
+                else:
+                    input_names = list(unroll_input(self.input_names, input_example).keys())
+                onnx_path = str(Path(tmpdir) / "model.onnx")
+                self.logger.info(
+                    f"Exporting to {onnx_path}:\n"
+                    + f"output_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}"
+                )
+                torch.onnx.export(
+                    model,
+                    (input_example,),
+                    onnx_path,
+                    input_names=input_names,
+                    output_names=self.output_names,
+                    **export_args,
+                )
+                if polygraphy_imported:
+                    from polygraphy.backend.onnx.loader import fold_constants, onnx_from_path, save_onnx
+
+                    onnx_model = fold_constants(onnx_from_path(onnx_path), size_threshold=16 * 1000 * 1000)
+                    save_onnx(onnx_model, onnx_path)
+                self.logger.info("Export to ONNX successful.")
+                engine_bytes = self._onnx_to_trt(onnx_path)
+        if engine_bytes:
+            # use a context manager so the plan file is flushed and closed
+            with open(self.plan_path, "wb") as f:
+                f.write(engine_bytes)
+
+
+def trt_forward(self, *argv, **kwargs):
+    """
+    Patch function that replaces the original model's forward().
+    Redirects to TrtCompiler.forward().
+    """
+    return self._trt_compiler.forward(self, argv, kwargs)
+
+
+def trt_compile(
+    model: torch.nn.Module,
+    base_path: str,
+    args: Dict[str, Any] | None = None,
+    submodule: Union[str, List[str]] | None = None,
+    logger: Any | None = None,
+) -> torch.nn.Module:
+    """
+    Instruments a model or submodule(s) with TrtCompiler and replaces its forward() with a TRT hook.
+    Note: TRT 10.3 is recommended for best performance; some networks may fail to work with TRT 8.x.
+    Args:
+        model: module to patch with TrtCompiler object.
+        base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path.
+                   dirname(base_path) must exist; base_path itself does not have to.
+                   If base_path points to an existing file (e.g. an associated checkpoint),
+                   that file becomes a dependency - its mtime is added to args["timestamp"].
+        args: Optional dict, unpacked and passed to TrtCompiler(); see TrtCompiler above for details.
+        submodule: Optional hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder'].
+                   If None, the TrtCompiler patch is applied to the whole model.
+                   Otherwise, the listed submodule(s) are patched.
+        logger: Optional logger for diagnostics.
+    Returns:
+        Always returns the same model that was passed in, for ease of use in configs.
+    """
+
+    default_args: Dict[str, Any] = {
+        "method": "onnx",
+        "precision": "fp16",
+        "build_args": {"builder_optimization_level": 5, "precision_constraints": "obey"},
+    }
+
+    default_args.update(args or {})
+    args = default_args
+
+    if trt_imported and polygraphy_imported and torch.cuda.is_available():
+        # if base_path points to an existing file (e.g. a checkpoint),
+        # it is also treated as a dependency
+        if os.path.exists(base_path):
+            timestamp = int(os.path.getmtime(base_path))
+            if "timestamp" in args:
+                timestamp = max(int(args["timestamp"]), timestamp)
+            args["timestamp"] = timestamp
+
+        def wrap(model, path):
+            if not hasattr(model, "_trt_compiler"):
+                model.orig_forward = model.forward
+                wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args)
+                model._trt_compiler = wrapper
+                model.forward = MethodType(trt_forward, model)
+
+        def find_sub(parent, submodule):
+            idx = submodule.find(".")
+            # if there is a "." in the name, recurse into the child module
+            if idx != -1:
+                parent_name = submodule[:idx]
+                parent = getattr(parent, parent_name)
+                submodule = submodule[idx + 1 :]
+                return find_sub(parent, submodule)
+            return parent, submodule
+
+        if submodule is not None:
+            if isinstance(submodule, str):
+                submodule = [submodule]
+            for s in submodule:
+                parent, sub = find_sub(model, s)
+                wrap(getattr(parent, sub), base_path + "." + s)
+        else:
+            wrap(model, base_path)
+    else:
+        logger = logger or getLogger("trt_compile")
+        logger.warning("TensorRT and/or polygraphy packages are not available! trt_compile() has no effect.")
+
+    return model
diff --git a/tests/export/test_trt_compile.py b/tests/export/test_trt_compile.py
new file mode 100644
index 000000000000..09a77004678f
--- /dev/null
+++ b/tests/export/test_trt_compile.py
@@ -0,0 +1,139 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
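+
+# Example usage of trt_compile() (a minimal sketch; MyModel and x are hypothetical
+# placeholders, and TensorRT, polygraphy and a CUDA device are assumed to be
+# available - otherwise trt_compile() logs a warning and returns the model unchanged):
+#
+#     from nemo.export import trt_compile
+#
+#     model = MyModel().cuda().eval()
+#     trt_compile(model, "/tmp/my_model", args={"precision": "fp16"})
+#     out = model(x)   # first call exports ONNX, builds and caches "/tmp/my_model.plan"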
+
+from __future__ import annotations
+
+import tempfile
+import unittest
+from typing import List
+
+import torch
+
+from nemo.export import trt_compile
+
+TEST_CASE_1 = ["fp32"]
+TEST_CASE_2 = ["fp16"]
+
+
+class ListAdd(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: List[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = 0.1):
+        y1 = y.clone()
+        x1 = x.copy()
+        z1 = z + y
+        for xi in x:
+            y1 = y1 + xi + bs
+        return x1, [y1, z1], y1 + z1
+
+
+@unittest.skip("Requires TensorRT, polygraphy and a CUDA GPU")
+class TestTRTCompile(unittest.TestCase):
+
+    def setUp(self):
+        self.gpu_device = torch.cuda.current_device()
+
+    def tearDown(self):
+        current_device = torch.cuda.current_device()
+        if current_device != self.gpu_device:
+            torch.cuda.set_device(self.gpu_device)
+
+    def test_torch_trt(self):
+
+        model = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()])
+        data1 = model.state_dict()
+        data1["0.weight"] = torch.tensor([0.1])
+        data1["1.weight"] = torch.tensor([0.2])
+        model.load_state_dict(data1)
+        model.cuda()
+        x = torch.randn(1, 16).to("cuda")
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            args = {
+                "method": "torch_trt",
+                "dynamic_batchsize": [1, 4, 8],
+            }
+            input_example = (x,)
+            output_example = model(*input_example)
+            trt_compile(
+                model,
+                f"{tempdir}/test_torch_trt",
+                args=args,
+            )
+            self.assertIsNone(model._trt_compiler.engine)
+            trt_output = model(*input_example)
+            # Check that lazy TRT build succeeded
+            self.assertIsNotNone(model._trt_compiler.engine)
+            torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
+
+    def test_profiles(self):
+        model = ListAdd().cuda()
+
+        with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
+            args = {
+                "export_args": {
+                    "dynamo": False,
+                },
+                "input_profiles": [
+                    {
+                        "x_0": [[1, 8], [2, 16], [2, 32]],
+                        "x_1": [[1, 8], [2, 16], [2, 32]],
+                        "x_2": [[1, 8], [2, 16], [2, 32]],
+                        "y": [[1, 8], [2, 16], [2, 32]],
+                        "z": [[1, 8], [1, 16], [1, 32]],
+                    }
+                ],
+                "output_lists": [[-1], [2], []],
+            }
+            x = torch.randn(1, 16).to("cuda")
+            y = torch.randn(1, 16).to("cuda")
+            z = torch.randn(1, 16).to("cuda")
+            input_example = ([x, y, z], y.clone(), z.clone())
+            output_example = model(*input_example)
+            trt_compile(
+                model,
+                f"{tmpdir}/test_dynamo_trt",
+                args=args,
+            )
+            self.assertIsNone(model._trt_compiler.engine)
+            trt_output = model(*input_example)
+            # Check that lazy TRT build succeeded
+            self.assertIsNotNone(model._trt_compiler.engine)
+            torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
+
+    def test_lists(self):
+        model = ListAdd().cuda()
+
+        with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
+            args = {
+                "export_args": {
+                    "dynamo": True,
+                },
+                "output_lists": [[-1], [2], []],
+            }
+            x = torch.randn(1, 16).to("cuda")
+            y = torch.randn(1, 16).to("cuda")
+            z = torch.randn(1, 16).to("cuda")
+            input_example = ([x, y, z], y.clone(), z.clone())
+            output_example = model(*input_example)
+            trt_compile(
+                model,
+                f"{tmpdir}/test_lists",
+                args=args,
+            )
+            self.assertIsNone(model._trt_compiler.engine)
+            trt_output = model(*input_example)
+            # Check that lazy TRT build succeeded
+            self.assertIsNotNone(model._trt_compiler.engine)
+            torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
+
+
+if __name__ == "__main__":
+    unittest.main()