Skip to content

Commit

Permalink
feat: Pulling in dependencies (in_process mode) using conda environme…
Browse files Browse the repository at this point in the history
…nt (#4807)

* InferenceSpec support for HF

* feat: InferenceSpec support for MMS and testing

* Introduce changes for InProcess Mode

* mb_inprocess updates

* In_Process mode for TGI transformers, edits

* Remove InfSpec from branch

* changes to support in_process

* changes to get pre-checks passing

* pylint fix

* unit test, test mb

* period missing, added

* suggestions and test added

* pre-push fix

* missing an @

* fixes to test, added stubbing

* removing for fixes

* variable fixes

* init fix

* tests for in process mode

* prepush fix

* deps and mb

* changes

* fixing pkl

* testing

* save pkl debug

* changes

* conda create

* Conda fixes

* random dep

* subproces

* requirementsmanager.py script

* requires manag

* changing command

* changing command

* print

* shell=true

* minor fix

* changes

* check=true

* unit test

* testing

* unit test for requirementsmanager

* removing in_process and minor edits

* format

* .txt file

* renaming functions

* fix path

* making .txt evaluate to true

---------

Co-authored-by: Bryannah Hernandez <[email protected]>
Co-authored-by: sage-maker <[email protected]>
  • Loading branch information
3 people authored Aug 7, 2024
1 parent 97a6be3 commit bceefd1
Show file tree
Hide file tree
Showing 8 changed files with 398 additions and 5 deletions.
1 change: 0 additions & 1 deletion src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,6 @@ def _overwrite_mode_in_deploy(self, overwrite_mode: str):
s3_upload_path, env_vars_sagemaker = self._prepare_for_mode()
self.pysdk_model.model_data = s3_upload_path
self.pysdk_model.env.update(env_vars_sagemaker)

elif overwrite_mode == Mode.LOCAL_CONTAINER:
self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
self._prepare_for_mode()
Expand Down
100 changes: 100 additions & 0 deletions src/sagemaker/serve/builder/requirements_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Requirements Manager class to pull in client dependencies from a .txt or .yml file"""
from __future__ import absolute_import
import logging
import os
import subprocess

from typing import Optional

logger = logging.getLogger(__name__)


class RequirementsManager:
    """Installs client dependencies from a pip requirements file or a conda env file.

    The kind of file is decided purely from the path suffix: ``.txt`` files are
    handed to ``pip install -r`` and ``.yml`` files to ``conda env update -f``.
    """

    def capture_and_install_dependencies(self, dependencies: Optional[str] = None) -> str:
        """Install dependencies from the given file, or from an auto-detected one.

        Args:
            dependencies (str): Local path of a ``.txt`` requirements file or a
                ``.yml`` conda environment file. When omitted (or empty), the path
                is chosen by ``_detect_conda_env_and_local_dependencies``.

        Returns:
            str: The path of the dependencies file the install was run against.

        Raises:
            ValueError: If the path is not a string ending in ``.txt`` or ``.yml``.
            subprocess.CalledProcessError: If the pip/conda command exits non-zero.
        """
        _dependencies = dependencies or self._detect_conda_env_and_local_dependencies()

        # Reject non-string input up front so callers get the documented
        # ValueError rather than an AttributeError from ``.endswith``.
        if not isinstance(_dependencies, str):
            raise ValueError(f'Invalid dependencies provided: "{_dependencies}"')

        # Dependencies specified as either req.txt or conda_env.yml
        if _dependencies.endswith(".txt"):
            self._install_requirements_txt(_dependencies)
        elif _dependencies.endswith(".yml"):
            self._update_conda_env_in_path(_dependencies)
        else:
            raise ValueError(f'Invalid dependencies provided: "{_dependencies}"')

        return _dependencies

    def _install_requirements_txt(self, requirements_path: str = "in_process_requirements.txt"):
        """Install a requirements file using pip.

        Args:
            requirements_path (str): Path of the requirements file. Defaults to
                ``in_process_requirements.txt`` relative to the working directory.

        Raises:
            subprocess.CalledProcessError: If pip exits with a non-zero status.
        """
        logger.info("Running command to pip install")
        # Argument list + no shell: paths containing spaces or shell
        # metacharacters are passed to pip verbatim.
        subprocess.run(["pip", "install", "-r", requirements_path], check=True)
        logger.info("Command ran successfully")

    def _update_conda_env_in_path(self, conda_yml_path: str = "conda_in_process.yml"):
        """Update the conda env using a conda yml file.

        Args:
            conda_yml_path (str): Path of the conda environment file. Defaults to
                ``conda_in_process.yml`` relative to the working directory.

        Raises:
            subprocess.CalledProcessError: If conda exits with a non-zero status.
        """
        logger.info("Updating conda env")
        subprocess.run(["conda", "env", "update", "-f", conda_yml_path], check=True)
        logger.info("Conda env updated successfully")

    def _get_active_conda_env_name(self) -> Optional[str]:
        """Returns the value of the ``CONDA_DEFAULT_ENV`` variable. None if unset."""
        return os.getenv("CONDA_DEFAULT_ENV")

    def _get_active_conda_env_prefix(self) -> Optional[str]:
        """Returns the value of the ``CONDA_PREFIX`` variable. None if unset."""
        return os.getenv("CONDA_PREFIX")

    def _detect_conda_env_and_local_dependencies(self) -> str:
        """Pick the default dependencies file based on the local runtime.

        When neither ``CONDA_DEFAULT_ENV`` nor ``CONDA_PREFIX`` is set, the pip
        requirements file is selected; otherwise the conda environment file is.
        Both paths are built relative to the current working directory and are
        not checked for existence here.

        Returns:
            str: ``<cwd>/in_process_requirements.txt`` or ``<cwd>/conda_in_process.yml``.
        """
        # Try to capture dependencies from the conda environment, if any.
        conda_env_name = self._get_active_conda_env_name()
        logger.info("Found conda_env_name: '%s'", conda_env_name)

        # The prefix is only consulted when no environment name is exported.
        conda_env_prefix = None
        if conda_env_name is None:
            conda_env_prefix = self._get_active_conda_env_prefix()

        if conda_env_name is None and conda_env_prefix is None:
            file_name = "in_process_requirements.txt"
        else:
            if conda_env_name == "base":
                logger.warning(
                    "We recommend using an environment other than base to "
                    "isolate your project dependencies from conda dependencies"
                )
            file_name = "conda_in_process.yml"

        local_dependencies_path = os.path.join(os.getcwd(), file_name)
        logger.info(local_dependencies_path)

        return local_dependencies_path
14 changes: 14 additions & 0 deletions src/sagemaker/serve/builder/transformers_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from abc import ABC, abstractmethod
from typing import Type
from pathlib import Path
import subprocess
from packaging.version import Version

from sagemaker.model import Model
Expand All @@ -41,6 +42,8 @@
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
from sagemaker.base_predictor import PredictorBase
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
from sagemaker.serve.builder.requirements_manager import RequirementsManager


logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 1800
Expand Down Expand Up @@ -376,6 +379,9 @@ def _build_for_transformers(self):
save_pkl(code_path, (self.inference_spec, self.schema_builder))
logger.info("PKL file saved to file: %s", code_path)

if self.mode == Mode.IN_PROCESS:
self._create_conda_env()

self._auto_detect_container()

self.secret_key = prepare_for_mms(
Expand All @@ -394,3 +400,11 @@ def _build_for_transformers(self):
if self.sagemaker_session:
self.pysdk_model.sagemaker_session = self.sagemaker_session
return self.pysdk_model

def _create_conda_env(self):
    """Install the in-process dependencies via ``RequirementsManager``.

    No path is passed, so the manager chooses the dependencies file itself.
    Installation is best-effort: a failing pip/conda command is logged and
    does not abort the build.
    """
    try:
        # Previously the builder instance itself was passed as the
        # ``dependencies`` path, which is not a file path.
        RequirementsManager().capture_and_install_dependencies()
    except subprocess.CalledProcessError as err:
        logger.error("Failed to create and activate conda environment. %s", err)
6 changes: 3 additions & 3 deletions src/sagemaker/serve/model_server/multi_model_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _start_serving(
secret_key: str,
env_vars: dict,
):
"""Placeholder docstring"""
"""Initializes the start of the server"""
env = {
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
Expand Down Expand Up @@ -59,7 +59,7 @@ def _start_serving(
)

def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
"""Placeholder docstring"""
"""Invokes MMS server by hitting the docker host"""
try:
response = requests.post(
f"http://{get_docker_host()}:8080/invocations",
Expand All @@ -73,7 +73,7 @@ def _invoke_multi_model_server_serving(self, request: object, content_type: str,
raise Exception("Unable to send request to the local container server") from e

def _multi_model_server_deep_ping(self, predictor: PredictorBase):
"""Placeholder docstring"""
"""Deep ping in order to ensure prediction"""
response = None
try:
response = predictor.predict(self.schema_builder.sample_input)
Expand Down
113 changes: 113 additions & 0 deletions src/sagemaker/serve/utils/conda_in_process.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
name: conda_env
channels:
- defaults
dependencies:
- accelerate>=0.24.1,<=0.27.0
- sagemaker_schema_inference_artifacts>=0.0.5
- uvicorn>=0.30.1
- fastapi>=0.111.0
- nest-asyncio
- pip>=23.0.1
- attrs>=23.1.0,<24
- boto3>=1.34.142,<2.0
- cloudpickle==2.2.1
- google-pasta
- numpy>=1.9.0,<2.0
- protobuf>=3.12,<5.0
- smdebug_rulesconfig==1.0.1
- importlib-metadata>=1.4.0,<7.0
- packaging>=20.0
- pandas
- pathos
- schema
- PyYAML~=6.0
- jsonschema
- platformdirs
- tblib>=1.7.0,<4
- urllib3>=1.26.8,<3.0.0
- requests
- docker
- tqdm
- psutil
- pip:
- altair>=4.2.2
- anyio>=3.6.2
- awscli>=1.27.114
- blinker>=1.6.2
- botocore>=1.29.114
- cachetools>=5.3.0
- certifi==2022.12.7
- charset-normalizer>=3.1.0
- click>=8.1.3
- cloudpickle>=2.2.1
- colorama>=0.4.4
- contextlib2>=21.6.0
- decorator>=5.1.1
- dill>=0.3.6
- docutils>=0.16
- entrypoints>=0.4
- filelock>=3.11.0
- gitdb>=4.0.10
- gitpython>=3.1.31
- gunicorn>=20.1.0
- h11>=0.14.0
- huggingface-hub>=0.13.4
- idna>=3.4
- importlib-metadata>=4.13.0
- jinja2>=3.1.2
- jmespath>=1.0.1
- jsonschema>=4.17.3
- markdown-it-py>=2.2.0
- markupsafe>=2.1.2
- mdurl>=0.1.2
- mpmath>=1.3.0
- multiprocess>=0.70.14
- networkx>=3.1
- packaging>=23.1
- pandas>=1.5.3
- pathos>=0.3.0
- pillow>=9.5.0
- platformdirs>=3.2.0
- pox>=0.3.2
- ppft>=1.7.6.6
- protobuf>=3.20.3
- protobuf3-to-dict>=0.1.5
- pyarrow>=11.0.0
- pyasn1>=0.4.8
- pydantic>=1.10.7
- pydeck>=0.8.1b0
- pygments>=2.15.1
- pympler>=1.0.1
- pyrsistent>=0.19.3
- python-dateutil>=2.8.2
- pytz>=2023.3
- pytz-deprecation-shim>=0.1.0.post0
- pyyaml>=5.4.1
- regex>=2023.3.23
- requests>=2.28.2
- rich>=13.3.4
- rsa>=4.7.2
- s3transfer>=0.6.0
- sagemaker>=2.148.0
- schema>=0.7.5
- six>=1.16.0
- smdebug-rulesconfig>=1.0.1
- smmap==5.0.0
- sniffio>=1.3.0
- starlette>=0.26.1
- streamlit>=1.21.0
- sympy>=1.11.1
- tblib>=1.7.0
- tokenizers>=0.13.3
- toml>=0.10.2
- toolz>=0.12.0
- torch>=2.0.0
- tornado>=6.3
- tqdm>=4.65.0
- transformers>=4.28.1
- typing-extensions>=4.5.0
- tzdata>=2023.3
- tzlocal>=4.3
- urllib3>=1.26.15
- validators>=0.20.0
- zipp>=3.15.0
2 changes: 1 addition & 1 deletion src/sagemaker/serve/utils/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Placeholder Docstring"""
"""Exceptions used across different model builder invocations"""

from __future__ import absolute_import

Expand Down
85 changes: 85 additions & 0 deletions src/sagemaker/serve/utils/in_process_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
altair>=4.2.2
anyio>=3.6.2
awscli>=1.27.114
blinker>=1.6.2
botocore>=1.29.114
cachetools>=5.3.0
certifi==2022.12.7
charset-normalizer>=3.1.0
click>=8.1.3
cloudpickle>=2.2.1
colorama>=0.4.4
contextlib2>=21.6.0
decorator>=5.1.1
dill>=0.3.6
docutils>=0.16
entrypoints>=0.4
filelock>=3.11.0
gitdb>=4.0.10
gitpython>=3.1.31
gunicorn>=20.1.0
h11>=0.14.0
huggingface-hub>=0.13.4
idna>=3.4
importlib-metadata>=4.13.0
jinja2>=3.1.2
jmespath>=1.0.1
jsonschema>=4.17.3
markdown-it-py>=2.2.0
markupsafe>=2.1.2
mdurl>=0.1.2
mpmath>=1.3.0
multiprocess>=0.70.14
networkx>=3.1
packaging>=23.1
pandas>=1.5.3
pathos>=0.3.0
pillow>=9.5.0
platformdirs>=3.2.0
pox>=0.3.2
ppft>=1.7.6.6
protobuf>=3.20.3
protobuf3-to-dict>=0.1.5
pyarrow>=11.0.0
pyasn1>=0.4.8
pydantic>=1.10.7
pydeck>=0.8.1b0
pygments>=2.15.1
pympler>=1.0.1
pyrsistent>=0.19.3
python-dateutil>=2.8.2
pytz>=2023.3
pytz-deprecation-shim>=0.1.0.post0
pyyaml>=5.4.1
regex>=2023.3.23
requests>=2.28.2
rich>=13.3.4
rsa>=4.7.2
s3transfer>=0.6.0
sagemaker>=2.148.0
schema>=0.7.5
six>=1.16.0
smdebug-rulesconfig>=1.0.1
smmap==5.0.0
sniffio>=1.3.0
starlette>=0.26.1
streamlit>=1.21.0
sympy>=1.11.1
tblib>=1.7.0
tokenizers>=0.13.3
toml>=0.10.2
toolz>=0.12.0
torch>=2.0.0
tornado>=6.3
tqdm>=4.65.0
transformers>=4.28.1
typing-extensions>=4.5.0
tzdata>=2023.3
tzlocal>=4.3
urllib3>=1.26.15
validators>=0.20.0
zipp>=3.15.0
uvicorn>=0.30.1
fastapi>=0.111.0
nest-asyncio
transformers
Loading

0 comments on commit bceefd1

Please sign in to comment.