Fix: ModelBuilder deployment & optimization of JumpStart llama-3.1 models (#4937)

* Emit a warning when CPU cores are requested with a sharded model deployment.

* Reformat sharded model validations.

* Fix pop-on-None error in the JumpStart draft model flow.

* Set the LMI config on JumpStart model optimize.

* Re-format the LMI config switch.

* Add an e2e UT for LMI + .optimize().

* Add an e2e UT for LMI + .optimize() with no override.

* Add deep UTs to catch regressions and exercise the E2E flow more practically.

* Work around a flake8 bug.

* Flake8 workaround.

* Fix a flake8 syntax error under Python 3.8.

---------

Co-authored-by: Joseph Zhang <[email protected]>
Co-authored-by: Gary Wang 😤 <[email protected]>
3 people authored Nov 22, 2024
1 parent 7c14046 commit 801db44
Showing 8 changed files with 503 additions and 15 deletions.
14 changes: 14 additions & 0 deletions src/sagemaker/jumpstart/model.py
@@ -817,6 +817,20 @@ def deploy(
                 f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models."
             )
 
+        # No resources given to deploy() but present 'resources' key in deploy_kwargs means default
+        # JumpStart resource requirements are being used
+        if hasattr(self, "_is_sharded_model") and not resources and deploy_kwargs.resources:
+            if (
+                self._is_sharded_model
+                and deploy_kwargs.resources.num_cpus
+                and deploy_kwargs.resources.num_cpus > 0
+            ):
+                JUMPSTART_LOGGER.warning(
+                    "NumOfCpuCoresRequired should be 0 for the best experience with SageMaker Fast "
+                    "Model Loading. Overriding the requested `num_cpus` to 0."
+                )
+                deploy_kwargs.resources.num_cpus = 0
+
         self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources(
             self.additional_model_data_sources,
             deploy_kwargs.model_access_configs,
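A minimal sketch of the path this guard protects, assuming a JumpStart model that has gone through the sharding optimization flow (which sets the internal `_is_sharded_model` flag); the model ID is illustrative:

from sagemaker.jumpstart.model import JumpStartModel

# With no `resources` passed to deploy(), the default JumpStart resource
# requirements populate deploy_kwargs.resources. If those defaults request
# CPU cores for a sharded model, the guard above logs a warning and
# overrides `num_cpus` to 0, since SageMaker Fast Model Loading expects
# NumOfCpuCoresRequired to be 0.
model = JumpStartModel(model_id="meta-textgeneration-llama-3-1-8b-instruct")
predictor = model.deploy(accept_eula=True)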
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/utils.py
@@ -1595,9 +1595,10 @@ def _add_model_access_configs_to_model_data_sources(
             )
             acked_model_data_sources.append(mutable_model_data_source)
         else:
-            mutable_model_data_source.pop(
-                "HostingEulaKey"
-            )  # pop when model access config is not applicable
+            if "HostingEulaKey" in mutable_model_data_source:
+                mutable_model_data_source.pop(
+                    "HostingEulaKey"
+                )  # pop when model access config is not applicable
             acked_model_data_sources.append(mutable_model_data_source)
     return acked_model_data_sources

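The root cause is standard dict behavior: `dict.pop(key)` raises `KeyError` when the key is absent, which is exactly the state of a data source whose EULA was already accepted upstream. A standalone sketch with illustrative contents:

# A model data source without a "HostingEulaKey", as produced when the
# EULA was already accepted upstream.
source = {
    "ChannelName": "draft_model",
    "S3DataSource": {"ModelAccessConfig": {"AcceptEula": True}},
}

# Before the fix this line raised KeyError("HostingEulaKey"):
# source.pop("HostingEulaKey")

# The guarded form is a no-op when the key is missing:
if "HostingEulaKey" in source:
    source.pop("HostingEulaKey")

# source.pop("HostingEulaKey", None) would behave equivalently, returning
# None instead of raising when the key is absent.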
31 changes: 19 additions & 12 deletions src/sagemaker/model.py
@@ -1600,18 +1600,25 @@ def deploy(
         if self._base_name is not None:
             self._base_name = "-".join((self._base_name, compiled_model_suffix))
 
-        if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
-            logging.warning(
-                "Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
-                "Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
-            )
-            endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
+        if self._is_sharded_model:
+            if endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
+                logging.warning(
+                    "Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
+                    "Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
+                )
+                endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
 
-        if self._is_sharded_model and self._enable_network_isolation:
-            raise ValueError(
-                "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
-                "Loading of model requires network access."
-            )
+            if self._enable_network_isolation:
+                raise ValueError(
+                    "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
+                    "Loading of model requires network access."
+                )
+
+            if resources and resources.num_cpus and resources.num_cpus > 0:
+                logger.warning(
+                    "NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker "
+                    "Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`."
+                )
 
         # Support multiple models on same endpoint
         if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
@@ -1655,7 +1662,7 @@ def deploy(
             vpc_config=self.vpc_config,
             enable_network_isolation=self._enable_network_isolation,
             role=self.role,
-            live_logging=endpoint_logging,
+            live_logging=False,  # TODO: enable when IC supports this
             wait=wait,
         )

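A sketch of how the regrouped validations behave for a sharded model, with the internal `_is_sharded_model` flag toggled by hand purely for illustration (it is normally set by the optimization flow), and placeholder image, data, and role values:

from sagemaker.model import Model

model = Model(
    image_uri="<container-image-uri>",
    model_data="s3://<bucket>/<model-prefix>/",
    role="<execution-role-arn>",
    enable_network_isolation=True,
)
model._is_sharded_model = True  # normally set by the optimization flow

try:
    # A non-IC endpoint request is first forced to INFERENCE_COMPONENT_BASED
    # with a warning; the isolation check then raises, because Fast Model
    # Loading must download weights over the network.
    model.deploy(initial_instance_count=1, instance_type="ml.g5.12xlarge")
except ValueError as err:
    print(err)  # "EnableNetworkIsolation cannot be set to True ..."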
4 changes: 4 additions & 0 deletions src/sagemaker/serve/builder/model_builder.py
@@ -1302,6 +1302,10 @@ def _model_builder_optimize_wrapper(
         job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
         if self._is_jumpstart_model_id():
             self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
+            if self.pysdk_model:
+                self.pysdk_model.set_deployment_config(
+                    instance_type=instance_type, config_name="lmi"
+                )
             input_args = self._optimize_for_jumpstart(
                 output_path=output_path,
                 instance_type=instance_type,
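End to end, the change means a JumpStart ModelBuilder flow like the following (condensed from the new tests below; the role ARN is a placeholder) now pins the "lmi" deployment config before the optimization-job inputs are assembled:

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder

model_builder = ModelBuilder(
    model="meta-textgeneration-llama-3-1-8b-instruct",
    schema_builder=SchemaBuilder("test", "test"),
    role_arn="<execution-role-arn>",
)

# .optimize() on a JumpStart model now calls
# pysdk_model.set_deployment_config(instance_type=..., config_name="lmi")
# right after build(), so the LMI container and its environment serve as
# the baseline for the optimization job.
optimized_model = model_builder.optimize(
    instance_type="ml.g5.xlarge",
    quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "fp8"}},
    accept_eula=True,
)
optimized_model.deploy()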
238 changes: 238 additions & 0 deletions tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py
@@ -0,0 +1,238 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
from unittest.mock import MagicMock, patch, ANY

from sagemaker.session import Session
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.resource_requirements import ResourceRequirements

ROLE_NAME = "SageMakerRole"


def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_expected(
    sagemaker_session,
):
    with patch.object(
        Session, "create_model", return_value="mock_model"
    ) as mock_create_model, patch.object(
        Session, "endpoint_from_production_variants"
    ) as mock_endpoint_from_production_variants:
        iam_client = sagemaker_session.boto_session.client("iam")
        role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]

        schema_builder = SchemaBuilder("test", "test")
        model_builder = ModelBuilder(
            model="meta-textgeneration-llama-3-1-8b-instruct",
            schema_builder=schema_builder,
            sagemaker_session=sagemaker_session,
            role_arn=role_arn,
        )

        optimized_model = model_builder.optimize(
            instance_type="ml.g5.xlarge",  # set to small instance in case a network call is made
            speculative_decoding_config={
                "ModelProvider": "JumpStart",
                "ModelID": "meta-textgeneration-llama-3-2-1b",
                "AcceptEula": True,
            },
            accept_eula=True,
        )

        optimized_model.deploy()

        mock_create_model.assert_called_once_with(
            name=ANY,
            role=ANY,
            container_defs={
                "Image": ANY,
                "Environment": {
                    "SAGEMAKER_PROGRAM": "inference.py",
                    "ENDPOINT_SERVER_TIMEOUT": "3600",
                    "MODEL_CACHE_ROOT": "/opt/ml/model",
                    "SAGEMAKER_ENV": "1",
                    "HF_MODEL_ID": "/opt/ml/model",
                    "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
                    "OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model/",
                },
                "AdditionalModelDataSources": [
                    {
                        "ChannelName": "draft_model",
                        "S3DataSource": {
                            "S3Uri": ANY,
                            "S3DataType": "S3Prefix",
                            "CompressionType": "None",
                            "ModelAccessConfig": {"AcceptEula": True},
                        },
                    }
                ],
                "ModelDataSource": {
                    "S3DataSource": {
                        "S3Uri": ANY,
                        "S3DataType": "S3Prefix",
                        "CompressionType": "None",
                        "ModelAccessConfig": {"AcceptEula": True},
                    }
                },
            },
            vpc_config=None,
            enable_network_isolation=True,
            tags=ANY,
        )
        mock_endpoint_from_production_variants.assert_called_once()


def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_expected(
    sagemaker_session,
):
    with patch.object(
        Session,
        "wait_for_optimization_job",
        return_value={"OptimizationJobName": "mock_optimization_job"},
    ), patch.object(
        Session, "create_model", return_value="mock_model"
    ) as mock_create_model, patch.object(
        Session, "endpoint_from_production_variants", return_value="mock_endpoint_name"
    ) as mock_endpoint_from_production_variants, patch.object(
        Session, "create_inference_component"
    ) as mock_create_inference_component:
        iam_client = sagemaker_session.boto_session.client("iam")
        role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]

        sagemaker_session.sagemaker_client.create_optimization_job = MagicMock()

        schema_builder = SchemaBuilder("test", "test")
        model_builder = ModelBuilder(
            model="meta-textgeneration-llama-3-1-8b-instruct",
            schema_builder=schema_builder,
            sagemaker_session=sagemaker_session,
            role_arn=role_arn,
        )

        optimized_model = model_builder.optimize(
            instance_type="ml.g5.xlarge",  # set to small instance in case a network call is made
            sharding_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "8"}},
            accept_eula=True,
        )

        optimized_model.deploy(
            resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8})
        )

        mock_create_model.assert_called_once_with(
            name=ANY,
            role=ANY,
            container_defs={
                "Image": ANY,
                "Environment": {
                    "SAGEMAKER_PROGRAM": "inference.py",
                    "ENDPOINT_SERVER_TIMEOUT": "3600",
                    "MODEL_CACHE_ROOT": "/opt/ml/model",
                    "SAGEMAKER_ENV": "1",
                    "HF_MODEL_ID": "/opt/ml/model",
                    "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
                    "OPTION_TENSOR_PARALLEL_DEGREE": "8",
                },
                "ModelDataSource": {
                    "S3DataSource": {
                        "S3Uri": ANY,
                        "S3DataType": "S3Prefix",
                        "CompressionType": "None",
                        "ModelAccessConfig": {"AcceptEula": True},
                    }
                },
            },
            vpc_config=None,
            enable_network_isolation=False,  # should be set to false
            tags=ANY,
        )
        mock_endpoint_from_production_variants.assert_called_once_with(
            name=ANY,
            production_variants=ANY,
            tags=ANY,
            kms_key=ANY,
            vpc_config=ANY,
            enable_network_isolation=False,
            role=ANY,
            live_logging=False,  # this should be set to false for IC
            wait=True,
        )
        mock_create_inference_component.assert_called_once()


def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are_expected(
    sagemaker_session,
):
    with patch.object(
        Session,
        "wait_for_optimization_job",
        return_value={"OptimizationJobName": "mock_optimization_job"},
    ), patch.object(
        Session, "create_model", return_value="mock_model"
    ) as mock_create_model, patch.object(
        Session, "endpoint_from_production_variants", return_value="mock_endpoint_name"
    ) as mock_endpoint_from_production_variants:
        iam_client = sagemaker_session.boto_session.client("iam")
        role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]

        sagemaker_session.sagemaker_client.create_optimization_job = MagicMock()

        schema_builder = SchemaBuilder("test", "test")
        model_builder = ModelBuilder(
            model="meta-textgeneration-llama-3-1-8b-instruct",
            schema_builder=schema_builder,
            sagemaker_session=sagemaker_session,
            role_arn=role_arn,
        )

        optimized_model = model_builder.optimize(
            instance_type="ml.g5.xlarge",  # set to small instance in case a network call is made
            quantization_config={
                "OverrideEnvironment": {
                    "OPTION_QUANTIZE": "fp8",
                },
            },
            accept_eula=True,
        )

        optimized_model.deploy()

        mock_create_model.assert_called_once_with(
            name=ANY,
            role=ANY,
            container_defs={
                "Image": ANY,
                "Environment": {
                    "SAGEMAKER_PROGRAM": "inference.py",
                    "ENDPOINT_SERVER_TIMEOUT": "3600",
                    "MODEL_CACHE_ROOT": "/opt/ml/model",
                    "SAGEMAKER_ENV": "1",
                    "HF_MODEL_ID": "/opt/ml/model",
                    "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
                    "OPTION_QUANTIZE": "fp8",
                },
                "ModelDataSource": {
                    "S3DataSource": {
                        "S3Uri": ANY,
                        "S3DataType": "S3Prefix",
                        "CompressionType": "None",
                        "ModelAccessConfig": {"AcceptEula": True},
                    }
                },
            },
            vpc_config=None,
            enable_network_isolation=True,  # no sharding config here, so network isolation stays enabled
            tags=ANY,
        )
        mock_endpoint_from_production_variants.assert_called_once()
22 changes: 22 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_utils.py
@@ -2318,6 +2318,28 @@ def test_multiple_gated_additional_model_data_source_should_accept_both(self):
             + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
         )
+
+    def test_gated_additional_model_data_source_already_accepted_with_no_hosting_eula_key_should_pass_through(
+        self,
+    ):
+        mock_gated_deploy_config_additional_model_data_pre_accepted = [
+            {
+                "ChannelName": "draft_model",
+                "S3DataSource": {
+                    "CompressionType": "None",
+                    "S3DataType": "S3Prefix",
+                    "S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/",
+                    "ModelAccessConfig": {"AcceptEula": True},
+                },
+            }
+        ]
+
+        utils._add_model_access_configs_to_model_data_sources(
+            model_data_sources=mock_gated_deploy_config_additional_model_data_pre_accepted,
+            model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=False)},
+            model_id=self.MOCK_GATED_MODEL_ID,
+            region=JUMPSTART_DEFAULT_REGION_NAME,
+        )
 
     # Mixed Positive Cases
 
     def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other(