diff --git a/python/sdk/merlin/requirements.py b/python/sdk/merlin/requirements.py index fac6db295..d1f58186f 100644 --- a/python/sdk/merlin/requirements.py +++ b/python/sdk/merlin/requirements.py @@ -2,7 +2,7 @@ from collections import namedtuple from itertools import filterfalse from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import yaml from packaging.requirements import Requirement @@ -18,8 +18,8 @@ def get_default_merlin_requirements(): def process_conda_env( - conda_env: Dict = None, - python_version: str = "3.10.*", + conda_env: Any, + python_version: Optional[str] = "3.10.*", additional_merlin_reqs: Optional[List[str]] = None, ): """ @@ -54,17 +54,18 @@ def process_conda_env( pip_reqs = _get_pip_deps(conda_env) pip_reqs, constraints = _parse_pip_requirements(pip_reqs) - for additional_merlin_req in additional_merlin_reqs: - exist = False - for pip_req in pip_reqs: - pip_req_obj = Requirement(pip_req) - additional_merlin_req_obj = Requirement(additional_merlin_req) - if pip_req_obj.name.lower() == additional_merlin_req_obj.name.lower(): - exist = True - break - - if not exist: - pip_reqs.insert(0, additional_merlin_req) + if additional_merlin_reqs: + for additional_merlin_req in additional_merlin_reqs: + exist = False + for pip_req in pip_reqs: + pip_req_obj = Requirement(pip_req) + additional_merlin_req_obj = Requirement(additional_merlin_req) + if pip_req_obj.name.lower() == additional_merlin_req_obj.name.lower(): + exist = True + break + + if not exist: + pip_reqs.insert(0, additional_merlin_req) if constraints: pip_reqs.append(f"-c {_CONSTRAINTS_FILE_NAME}")