From dc2bba2c3a1d7b34c0b67269209553159fe34f58 Mon Sep 17 00:00:00 2001 From: Sankalp Sanand Date: Wed, 6 Dec 2023 14:11:41 -0500 Subject: [PATCH] added support for directly adding modules --- covalent/__init__.py | 1 + covalent/_workflow/__init__.py | 1 + covalent/_workflow/depsmodule.py | 20 ++++++++++++++------ covalent/_workflow/electron.py | 22 +++++++++++++++++----- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/covalent/__init__.py b/covalent/__init__.py index 524952776..6d88af00d 100644 --- a/covalent/__init__.py +++ b/covalent/__init__.py @@ -40,6 +40,7 @@ from ._workflow import ( # nopycln: import DepsBash, DepsCall, + DepsModule, DepsPip, Lepton, TransportableObject, diff --git a/covalent/_workflow/__init__.py b/covalent/_workflow/__init__.py index 98da8e676..64ce547c8 100644 --- a/covalent/_workflow/__init__.py +++ b/covalent/_workflow/__init__.py @@ -18,6 +18,7 @@ from .depsbash import DepsBash from .depscall import DepsCall +from .depsmodule import DepsModule from .depspip import DepsPip from .electron import electron from .lattice import lattice diff --git a/covalent/_workflow/depsmodule.py b/covalent/_workflow/depsmodule.py index 1c4fc9bfb..a25f0786d 100644 --- a/covalent/_workflow/depsmodule.py +++ b/covalent/_workflow/depsmodule.py @@ -15,27 +15,30 @@ # limitations under the License. import importlib +from types import ModuleType +from typing import Union import cloudpickle as pickle from .depscall import DepsCall -def _client_side_pickle_module(module_name: str): +def _client_side_pickle_module(module: Union[str, ModuleType]): """ Pickle a module by value on the client side and return the pickled bytes. Args: - module_name: The name of the module to pickle. + module: The name of the module to pickle, can also be a module. This module must be importable on the client side. Returns: The pickled bytes of the module. """ - # Import the module on the client side - module = importlib.import_module(module_name) + if isinstance(module, str): + # Import the module on the client side + module = importlib.import_module(module) # Register the module with cloudpickle by value pickle.register_pickle_by_value(module) @@ -43,6 +46,9 @@ def _client_side_pickle_module(module_name: str): # Pickle the module pickled_module = pickle.dumps(module) + # Unregister the module with cloudpickle + # pickle.unregister_pickle_by_value(module) + return pickled_module @@ -79,9 +85,11 @@ class DepsModule(DepsCall): module_name: A string containing the name of the module to be imported. """ - def __init__(self, module_name: str): + def __init__(self, module: Union[str, ModuleType]): + module_name = module if isinstance(module, str) else module.__name__ + # Pickle the module by value on the client side - module_pickle = _client_side_pickle_module(module_name) + module_pickle = _client_side_pickle_module(module) # Pass the pickled module to the server side func = _server_side_import_module diff --git a/covalent/_workflow/electron.py b/covalent/_workflow/electron.py index 97f32a787..909c5eabf 100644 --- a/covalent/_workflow/electron.py +++ b/covalent/_workflow/electron.py @@ -24,6 +24,7 @@ from builtins import list from dataclasses import asdict from functools import wraps +from types import ModuleType from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union from covalent._dispatcher_plugins.local import LocalDispatcher @@ -753,12 +754,23 @@ def electron( if deps_module: if isinstance(deps_module, list): - deps_module = [ - DepsModule(module_name=module) if isinstance(module, str) else module - for module in deps_module - ] + # Convert to DepsModule objects + converted_deps = [] + for dep in deps_module: + if isinstance(dep, str): + converted_deps.append(DepsModule(dep)) + elif isinstance(dep, ModuleType): + converted_deps.append(DepsModule(dep)) + else: + converted_deps.append(dep) + deps_module = converted_deps + elif isinstance(deps_module, str): - deps_module = [DepsModule(module_name=deps_module)] + deps_module = [DepsModule(deps_module)] + + elif isinstance(deps_module, ModuleType): + deps_module = [DepsModule(deps_module)] + elif isinstance(deps_module, DepsModule): deps_module = [deps_module]