diff --git a/src/server/package/src/model_explorer/config.py b/src/server/package/src/model_explorer/config.py index f995c5c9..529f3964 100644 --- a/src/server/package/src/model_explorer/config.py +++ b/src/server/package/src/model_explorer/config.py @@ -95,6 +95,13 @@ def add_model_from_pytorch( exported_program: the ExportedProgram from torch.export.export. settings: The settings that config the visualization. """ + + if torch is None: + raise ImportError( + '`torch` not found. Please install it via `pip install torch`, ' + 'and restart the Model Explorer server.' + ) + # Convert the given model to model explorer graphs. print('Converting pytorch model to model explorer graphs...') adapter = PytorchExportedProgramAdapterImpl(exported_program, settings) diff --git a/src/server/package/src/model_explorer/extension_manager.py b/src/server/package/src/model_explorer/extension_manager.py index cc023b3c..4cb223a4 100644 --- a/src/server/package/src/model_explorer/extension_manager.py +++ b/src/server/package/src/model_explorer/extension_manager.py @@ -19,6 +19,11 @@ from importlib import import_module from typing import Any, Dict, Union +try: + import torch +except ImportError: + torch = None + from .adapter_runner import AdapterRunner from .consts import MODULE_NAME from .extension_class_processor import ExtensionClassProcessor @@ -29,15 +34,19 @@ class ExtensionManager(object, metaclass=Singleton): - BUILTIN_ADAPTER_MODULES: list[str] = [ - '.builtin_tflite_flatbuffer_adapter', - '.builtin_tflite_mlir_adapter', - '.builtin_tf_mlir_adapter', - '.builtin_tf_direct_adapter', - '.builtin_graphdef_adapter', - '.builtin_pytorch_exportedprogram_adapter', - '.builtin_mlir_adapter', - ] + BUILTIN_ADAPTER_MODULES: list[str] = ( + [ + '.builtin_tflite_flatbuffer_adapter', + '.builtin_tflite_mlir_adapter', + '.builtin_tf_mlir_adapter', + '.builtin_tf_direct_adapter', + '.builtin_graphdef_adapter', + ] + + (['.builtin_pytorch_exportedprogram_adapter'] if torch else []) + + [ + '.builtin_mlir_adapter', + ] + ) CACHED_REGISTERED_EXTENSIONS: Dict[str, RegisteredExtension] = {}