The AWS SageMaker tutorial using merlin-models works as expected for both the training and inference steps (after applying the fixes from PR NVIDIA-Merlin/Merlin#1040). However, when I try the same with the transformers4rec getting-started tutorial, I get the following error when running inference on a SageMaker Endpoint:
| 1733851742784 | I1210 17:29:02.641670 103 python_be.cc:2177] TRITONBACKEND_ModelInstanceExecute: model instance name 0_transformworkflowtriton_0 released 1 requests | AllTraffic/i-0a730c865fae02cab |
| 1733851742784 | Failed to transform operator <merlin.systems.dag.runtimes.triton.ops.workflow.TransformWorkflowTriton object at 0x7f70e322ce50> | AllTraffic/i-0a730c865fae02cab |
| 1733851742784 | Traceback (most recent call last): File "/usr/local/lib/python3.10/dist-packages/merlin/dag/executors.py", line 237, in _run_node_transform transformed_data = node.op.transform(selection, input_data) File "/usr/local/lib/python3.10/dist-packages/merlin/systems/dag/runtimes/triton/ops/workflow.py", line 92, in transform raise RuntimeError(inference_response.error().message()) | AllTraffic/i-0a730c865fae02cab |
| 1733851742784 | RuntimeError: Error: <class 'KeyError'> - "['weekday_sin-list', 'category-list', 'item_id-count', 'age_days-list', 'item_id-list', 'day-first'] not in index", Traceback: [' File "/opt/ml/model/0_transformworkflowtriton/1/model.py", line 117, in execute\n transformed = self.runner.run_workflow(input_tensors)\n', ' File "/usr/local/lib/python3.10/dist-packages/merlin/systems/workflow/base.py", line 103, in run_workflow\n transformed = LocalExecutor().transform(transformable, self.workflow.graph)\n', ' File "/usr/local/lib/python3.10/dist-packages/merlin/dag/executors.py", line 102, in transform\n transformed_data = self._execute_node(node, transformable, capture_dtypes, strict)\n', ' File "/usr/local/lib/python3.10/dist-packages/merlin/dag/executors.py", line 116, in _execute_node\n upstream_outputs = self._run_upstream_transforms(\n', ' File "/usr/local/lib/python3.10/dist-packages/merlin/dag/executors.py", line 130, in _run_upstream_transforms\n node_output = self._execute_node(\n', ' File "/usr/local/lib/python3.10/dist-packages/merlin/dag/executors.py", line 119, in _execute_node\n upstream_columns = self._append_addl_root_columns(node, transformable, upstream_outputs)\n', ' File "/usr/local/lib/python3.10/dist-packages/merlin/dag/executors.py", line 154, in _append_addl_root_columns\n upstream_outputs.append(transformable[list(root_columns)])\n', ' File "/usr/local/lib/python3.10/dist-packages/pandas/core/frame.py", line 3811, in __getitem__\n indexer = self.columns._get_indexer_strict(key, "columns")[1]\n', ' File "/usr/local/lib/python3.10/dist-packages/pandas/core/indexes/base.py", line 6113, in _get_indexer_strict\n self._raise_if_missing(keyarr, indexer, axis_name)\n', ' File "/usr/local/lib/python3.10/dist-packages/pandas/core/indexes/base.py", line 6176, in _raise_if_missing\n raise KeyError(f"{not_found} not in index")\n']
As you can see, the error seems to be related to the grouped variables in the 0_transformworkflowtriton model of the Triton ensemble. However, the model training and the ensemble initialization on the Triton server seem to be OK (see SM_endpoint_logs_full.txt).
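To make the comparison concrete, here is a minimal sketch that pulls the missing column names out of the KeyError text and checks them against the raw column names the request is expected to carry. The helper name `missing_columns` and the `raw_inputs` set are assumptions for illustration (the raw names are taken from the synthetic DataFrame built in train.py); this only inspects the log message and does not call any Merlin API.

```python
import ast
import re


def missing_columns(error_message: str) -> list:
    """Extract the names listed in a pandas "[...] not in index" KeyError message."""
    match = re.search(r"\[(.*?)\] not in index", error_message)
    if match is None:
        return []
    # The captured text is a comma-separated list of quoted names.
    return list(ast.literal_eval("[" + match.group(1) + "]"))


# KeyError text copied from the endpoint log above.
message = (
    "\"['weekday_sin-list', 'category-list', 'item_id-count', "
    "'age_days-list', 'item_id-list', 'day-first'] not in index\""
)

# Raw columns of the synthetic DataFrame in train.py (assumed to be what the request sends).
raw_inputs = {"session_id", "item_id", "category", "day", "age_days", "weekday_sin"}

missing = missing_columns(message)
# Names from the error that are also raw input columns (none are).
overlap = [name for name in missing if name in raw_inputs]
# Strip the aggregation suffix to see which raw column each name derives from.
bases = {name.rpartition("-")[0] for name in missing}

print(missing)
print(overlap)
print(bases <= raw_inputs)
```

None of the names in the error is a raw input column, but each one is a raw column name plus an aggregation suffix. That suggests the workflow is looking for post-Groupby names in the raw request DataFrame. It does not by itself say why.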
I think the cause of this error could be the Triton server initialization command (tritonserver --allow-sagemaker=true --allow-http=false $SAGEMAKER_ARGS) or the SageMaker Endpoint invocation (runtime_sm_client.invoke_endpoint(EndpointName=endpoint_name, ContentType=f"application/vnd.sagemaker-triton.binary+json;json-header-size={header_length}", Body=request_body)); details and code are attached below. When I run the Triton inference inside the AWS SageMaker Training job (on the same instance used for training), it works as expected. Any help with this issue would be highly appreciated.
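One way to rule out the invocation side is to look at what the request body actually contains before it is sent. The sketch below assumes the body follows the binary tensor data layout implied by the content type: a JSON header of `header_length` bytes followed by raw tensor bytes. The helpers `request_input_names` and `content_type` and the toy payload are hypothetical; in the real client code `request_body` and `header_length` would come from the Triton client library.

```python
import json
import struct


def request_input_names(request_body: bytes, header_length: int) -> list:
    """Decode the JSON header of a binary+json body and list the input tensor names."""
    header = json.loads(request_body[:header_length].decode("utf-8"))
    return [tensor["name"] for tensor in header.get("inputs", [])]


def content_type(header_length: int) -> str:
    """Build the content type string used in the invoke_endpoint call above."""
    return (
        "application/vnd.sagemaker-triton.binary+json;"
        f"json-header-size={header_length}"
    )


# Toy payload (assumption): one INT64 input named "item_id" with three values.
raw = struct.pack("<3q", 1, 2, 3)  # 3 little-endian 64-bit integers = 24 bytes
header = {
    "inputs": [
        {
            "name": "item_id",
            "shape": [3, 1],
            "datatype": "INT64",
            "parameters": {"binary_data_size": len(raw)},
        }
    ]
}
header_bytes = json.dumps(header).encode("utf-8")
request_body = header_bytes + raw
header_length = len(header_bytes)

print(request_input_names(request_body, header_length))
print(content_type(header_length))
```

If the names printed for the real request are the raw columns (session_id, item_id, and so on), the request itself is probably not the source of the grouped column names in the KeyError.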
Details
Following the Merlin SageMaker tutorial, these are my files:
Dockerfile
FROM nvcr.io/nvidia/merlin/merlin-pytorch:23.12
RUN pip3 install sagemaker-training
COPY --chown=1000:1000 --chmod=764 serve /usr/bin/serve
serve (Initializes the Triton server. Copied from the PR fix):
#!/bin/bash
# Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

SAGEMAKER_SINGLE_MODEL_REPO=/opt/ml/model/

# Use 'ready' for ping check in single-model endpoint mode, and use 'live' for ping check in multi-model endpoint model
# https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/rest_predict_v2.yaml#L10-L26
if [ -n "$SAGEMAKER_TRITON_OVERRIDE_PING_MODE" ]; then
    SAGEMAKER_TRITON_PING_MODE=${SAGEMAKER_TRITON_OVERRIDE_PING_MODE}
else
    SAGEMAKER_TRITON_PING_MODE="ready"
fi

# Note: in Triton on SageMaker, each model url is registered as a separate repository
# e.g., /opt/ml/models/<hash>/model. Specifying MME model repo path as /opt/ml/models causes Triton
# to treat it as an additional empty repository and changes
# the state of all models to be UNAVAILABLE in the model repository
# https://github.com/triton-inference-server/core/blob/main/src/model_repository_manager.cc#L914,L922
# On Triton, this path will be a dummy path as it's mandatory to specify a model repo when starting triton
SAGEMAKER_MULTI_MODEL_REPO=/tmp/sagemaker

SAGEMAKER_MODEL_REPO=${SAGEMAKER_SINGLE_MODEL_REPO}
is_mme_mode=false

if [ -n "$SAGEMAKER_MULTI_MODEL" ]; then
    if [ "$SAGEMAKER_MULTI_MODEL" == "true" ]; then
        mkdir -p ${SAGEMAKER_MULTI_MODEL_REPO}
        SAGEMAKER_MODEL_REPO=${SAGEMAKER_MULTI_MODEL_REPO}
        if [ -n "$SAGEMAKER_TRITON_OVERRIDE_PING_MODE" ]; then
            SAGEMAKER_TRITON_PING_MODE=${SAGEMAKER_TRITON_OVERRIDE_PING_MODE}
        else
            SAGEMAKER_TRITON_PING_MODE="live"
        fi
        is_mme_mode=true
        echo -e "Triton is running in SageMaker MME mode. Using Triton ping mode: \"${SAGEMAKER_TRITON_PING_MODE}\""
    fi
fi

SAGEMAKER_ARGS="--model-repository=${SAGEMAKER_MODEL_REPO}"

# Set model namespacing to true, but allow disabling if required
if [ -n "$SAGEMAKER_TRITON_DISABLE_MODEL_NAMESPACING" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --model-namespacing=${SAGEMAKER_TRITON_DISABLE_MODEL_NAMESPACING}"
else
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --model-namespacing=true"
fi
if [ -n "$SAGEMAKER_BIND_TO_PORT" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --sagemaker-port=${SAGEMAKER_BIND_TO_PORT}"
fi
if [ -n "$SAGEMAKER_SAFE_PORT_RANGE" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --sagemaker-safe-port-range=${SAGEMAKER_SAFE_PORT_RANGE}"
fi
if [ -n "$SAGEMAKER_TRITON_ALLOW_GRPC" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --allow-grpc=${SAGEMAKER_TRITON_ALLOW_GRPC}"
else
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --allow-grpc=false"
fi
if [ -n "$SAGEMAKER_TRITON_ALLOW_METRICS" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --allow-metrics=${SAGEMAKER_TRITON_ALLOW_METRICS}"
else
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --allow-metrics=false"
fi
if [ -n "$SAGEMAKER_TRITON_METRICS_PORT" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --metrics-port=${SAGEMAKER_TRITON_METRICS_PORT}"
fi
if [ -n "$SAGEMAKER_TRITON_GRPC_PORT" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --grpc-port=${SAGEMAKER_TRITON_GRPC_PORT}"
fi
if [ -n "$SAGEMAKER_TRITON_BUFFER_MANAGER_THREAD_COUNT" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --buffer-manager-thread-count=${SAGEMAKER_TRITON_BUFFER_MANAGER_THREAD_COUNT}"
fi
if [ -n "$SAGEMAKER_TRITON_THREAD_COUNT" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --sagemaker-thread-count=${SAGEMAKER_TRITON_THREAD_COUNT}"
fi
# Enable verbose logging by default. If env variable is specified, use value from env variable
if [ -n "$SAGEMAKER_TRITON_LOG_VERBOSE" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --log-verbose=${SAGEMAKER_TRITON_LOG_VERBOSE}"
else
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --log-verbose=true"
fi
if [ -n "$SAGEMAKER_TRITON_LOG_INFO" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --log-info=${SAGEMAKER_TRITON_LOG_INFO}"
fi
if [ -n "$SAGEMAKER_TRITON_LOG_WARNING" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --log-warning=${SAGEMAKER_TRITON_LOG_WARNING}"
fi
if [ -n "$SAGEMAKER_TRITON_LOG_ERROR" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --log-error=${SAGEMAKER_TRITON_LOG_ERROR}"
fi
if [ -n "$SAGEMAKER_TRITON_SHM_DEFAULT_BYTE_SIZE" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --backend-config=python,shm-default-byte-size=${SAGEMAKER_TRITON_SHM_DEFAULT_BYTE_SIZE}"
else
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --backend-config=python,shm-default-byte-size=16777216" # 16MB
fi
if [ -n "$SAGEMAKER_TRITON_SHM_GROWTH_BYTE_SIZE" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --backend-config=python,shm-growth-byte-size=${SAGEMAKER_TRITON_SHM_GROWTH_BYTE_SIZE}"
else
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --backend-config=python,shm-growth-byte-size=1048576" # 1MB
fi
if [ -n "$SAGEMAKER_TRITON_TENSORFLOW_VERSION" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --backend-config=tensorflow,version=${SAGEMAKER_TRITON_TENSORFLOW_VERSION}"
fi
if [ -n "$SAGEMAKER_TRITON_MODEL_LOAD_GPU_LIMIT" ]; then
    num_gpus=$(nvidia-smi -L | wc -l)
    for ((i=0; i<${num_gpus}; i++)); do
        SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --model-load-gpu-limit ${i}:${SAGEMAKER_TRITON_MODEL_LOAD_GPU_LIMIT}"
    done
fi
if [ -n "$SAGEMAKER_TRITON_ADDITIONAL_ARGS" ]; then
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} ${SAGEMAKER_TRITON_ADDITIONAL_ARGS}"
fi

tritonserver --allow-sagemaker=true --allow-http=false $SAGEMAKER_ARGS
train.py (here I just copied the transformers4rec tutorial):
import argparse
import json
import logging
import os
import sys
import tempfile
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import glob
import cudf
import numpy as np
import pandas as pd
import nvtabular as nvt
from nvtabular.ops import *
from merlin.schema.tags import Tags
from transformers4rec.utils.data_utils import save_time_based_splits
import torch
from transformers4rec import torch as tr
from transformers4rec.torch.ranking_metric import NDCGAt, AvgPrecisionAt, RecallAt
from transformers4rec.torch.utils.examples_utils import wipe_memory
from merlin.schema import Schema
from merlin.io import Dataset
from transformers4rec.config.trainer import T4RecTrainingArguments
from transformers4rec.torch import Trainer
from merlin.core.dispatch import make_df
from merlin.systems.dag import Ensemble
from merlin.systems.dag.ops.pytorch import PredictPyTorch
from merlin.systems.dag.ops.workflow import TransformWorkflow
import cloudpickle
from merlin.table import TensorTable, TorchColumn
from merlin.table.conversions import convert_col
import shutil
from nvtabular.workflow import Workflow

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))


def parse_args():
    """
    Parse arguments passed from the SageMaker API to the container.
    """
    parser = argparse.ArgumentParser()
    # Model directory: we will use the default set by SageMaker, /opt/ml/model
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    return parser.parse_known_args()


def data_preprocessing():
    INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "./data/")
    NUM_ROWS = os.environ.get("NUM_ROWS", 100000)
    long_tailed_item_distribution = np.clip(np.random.lognormal(3., 1., int(NUM_ROWS)).astype(np.int32), 1, 50000)

    # generate random item interaction features
    df = pd.DataFrame(np.random.randint(70000, 90000, int(NUM_ROWS)), columns=['session_id'])
    df['item_id'] = long_tailed_item_distribution

    # generate category mapping for each item-id
    df['category'] = pd.cut(df['item_id'], bins=334, labels=np.arange(1, 335)).astype(np.int32)
    df['age_days'] = np.random.uniform(0, 1, int(NUM_ROWS)).astype(np.float32)
    df['weekday_sin'] = np.random.uniform(0, 1, int(NUM_ROWS)).astype(np.float32)

    # generate day mapping for each session
    map_day = dict(zip(df.session_id.unique(), np.random.randint(1, 10, size=(df.session_id.nunique()))))
    df['day'] = df.session_id.map(map_day)

    SESSIONS_MAX_LENGTH = 20

    # Categorify categorical features
    categ_feats = ['item_id', 'category'] >> nvt.ops.Categorify()

    # Define Groupby Workflow
    groupby_feats = categ_feats + ['session_id', 'day', 'age_days', 'weekday_sin']

    # Group interaction features by session
    groupby_features = groupby_feats >> nvt.ops.Groupby(
        groupby_cols=["session_id"],
        aggs={
            "item_id": ["list", "count"],
            "category": ["list"],
            "day": ["first"],
            "age_days": ["list"],
            'weekday_sin': ["list"],
        },
        name_sep="-")

    # Select and truncate the sequential features
    sequence_features_truncated = (
        groupby_features['category-list']
        >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH)
    )

    sequence_features_truncated_item = (
        groupby_features['item_id-list']
        >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH)
        >> TagAsItemID()
    )
    sequence_features_truncated_cont = (
        groupby_features['age_days-list', 'weekday_sin-list']
        >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH)
        >> nvt.ops.AddMetadata(tags=[Tags.CONTINUOUS])
    )

    # Filter out sessions with length 1 (not valid for next-item prediction training and evaluation)
    MINIMUM_SESSION_LENGTH = 2
    selected_features = (
        groupby_features['item_id-count', 'day-first', 'session_id'] + sequence_features_truncated_item + sequence_features_truncated + sequence_features_truncated_cont
    )

    filtered_sessions = selected_features >> nvt.ops.Filter(f=lambda df: df["item_id-count"] >= MINIMUM_SESSION_LENGTH)

    seq_feats_list = filtered_sessions['item_id-list', 'category-list', 'age_days-list', 'weekday_sin-list'] >> nvt.ops.ValueCount()

    workflow = nvt.Workflow(filtered_sessions['session_id', 'day-first'] + seq_feats_list)

    dataset = nvt.Dataset(df)

    # Generate statistics for the features and export parquet files
    # this step will generate the schema file
    workflow.fit_transform(dataset).to_parquet(os.path.join(INPUT_DATA_DIR, "processed_nvt"))
    workflow.save(os.path.join(INPUT_DATA_DIR, "workflow_etl"))

    OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.join(INPUT_DATA_DIR, "sessions_by_day"))

    # Read in the processed parquet file
    sessions_gdf = cudf.read_parquet(os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"))
    save_time_based_splits(data=nvt.Dataset(sessions_gdf),
                           output_dir=OUTPUT_DIR,
                           partition_col='day-first',
                           timestamp_col='session_id',
                           )
    return


def model_training():
    INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "./data")
    OUTPUT_DIR = os.environ.get("OUTPUT_DIR", f"{INPUT_DATA_DIR}/sessions_by_day")

    train = Dataset(os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"))
    schema = train.schema

    # You can select a subset of features for training
    schema = schema.select_by_name(['item_id-list',
                                    'category-list',
                                    'weekday_sin-list',
                                    'age_days-list'])

    inputs = tr.TabularSequenceFeatures.from_schema(
        schema,
        max_sequence_length=20,
        continuous_projection=64,
        masking="mlm",
        d_output=100,
    )

    # Define XLNetConfig class and set default parameters for HF XLNet config
    transformer_config = tr.XLNetConfig.build(
        d_model=64, n_head=4, n_layer=2, total_seq_length=20
    )
    # Define the model block including: inputs, masking, projection and transformer block.
    body = tr.SequentialBlock(
        inputs, tr.MLPBlock([64]), tr.TransformerBlock(transformer_config, masking=inputs.masking)
    )

    # Define the evaluation top-N metrics and the cut-offs
    metrics = [NDCGAt(top_ks=[20, 40], labels_onehot=True),
               RecallAt(top_ks=[20, 40], labels_onehot=True)]

    # Define a head related to next item prediction task
    head = tr.Head(
        body,
        tr.NextItemPredictionTask(weight_tying=True,
                                  metrics=metrics),
        inputs=inputs,
    )

    # Get the end-to-end Model class
    model = tr.Model(head)

    per_device_train_batch_size = int(os.environ.get(
        "per_device_train_batch_size",
        '128'
    ))

    per_device_eval_batch_size = int(os.environ.get(
        "per_device_eval_batch_size",
        '32'
    ))

    # Set hyperparameters for training
    train_args = T4RecTrainingArguments(
        data_loader_engine='merlin',
        dataloader_drop_last=True,
        gradient_accumulation_steps=1,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        output_dir="./tmp",
        learning_rate=0.0005,
        lr_scheduler_type='cosine',
        learning_rate_num_cosine_cycles_by_epoch=1.5,
        num_train_epochs=5,
        max_sequence_length=20,
        report_to=[],
        logging_steps=50,
        no_cuda=False,
    )

    trainer = Trainer(
        model=model,
        args=train_args,
        schema=schema,
        compute_metrics=True,
    )

    start_window_index = int(os.environ.get(
        "start_window_index",
        '1'
    ))

    final_window_index = int(os.environ.get(
        "final_window_index",
        '8'
    ))

    start_time_window_index = start_window_index
    final_time_window_index = final_window_index

    # Iterating over days of one week
    for time_index in range(start_time_window_index, final_time_window_index):
        # Set data
        time_index_train = time_index
        time_index_eval = time_index + 1
        train_paths = glob.glob(os.path.join(OUTPUT_DIR, f"{time_index_train}/train.parquet"))
        eval_paths = glob.glob(os.path.join(OUTPUT_DIR, f"{time_index_eval}/valid.parquet"))
        print(train_paths)

        # Train on day related to time_index
        print('*' * 20)
        print("Launch training for day %s are:" % time_index)
        print('*' * 20 + '\n')
        trainer.train_dataset_or_path = train_paths
        trainer.reset_lr_scheduler()
        trainer.train()
        trainer.state.global_step += 1
        print('finished')

        # Evaluate on the following day
        trainer.eval_dataset_or_path = eval_paths
        train_metrics = trainer.evaluate(metric_key_prefix='eval')
        print('*' * 20)
        print("Eval results for day %s are:\t" % time_index_eval)
        print('\n' + '*' * 20 + '\n')
        for key in sorted(train_metrics.keys()):
            print(" %s = %s" % (key, str(train_metrics[key])))
        wipe_memory()

    eval_data_paths = glob.glob(os.path.join(OUTPUT_DIR, f"{time_index_eval}/valid.parquet"))

    # set new data from day 7
    eval_metrics = trainer.evaluate(eval_dataset=eval_data_paths, metric_key_prefix='eval')
    for key in sorted(eval_metrics.keys()):
        print(" %s = %s" % (key, str(eval_metrics[key])))

    model_path = os.environ.get("OUTPUT_DIR", f"{INPUT_DATA_DIR}/saved_model")
    model.save(model_path)


def model_ensemble(output_path):
    INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "./data/")
    OUTPUT_DIR = os.environ.get("OUTPUT_DIR", f"{INPUT_DATA_DIR}/sessions_by_day")
    model_path = os.environ.get("model_path", f"{INPUT_DATA_DIR}/saved_model")

    loaded_model = cloudpickle.load(
        open(os.path.join(model_path, "t4rec_model_class.pkl"), "rb")
    )
    model = loaded_model.cuda()
    model.eval()

    train_paths = os.path.join(OUTPUT_DIR, f"{1}/train.parquet")
    dataset = Dataset(train_paths)

    df = cudf.read_parquet(train_paths, columns=model.input_schema.column_names)
    table = TensorTable.from_df(df.loc[:100])
    for column in table.columns:
        table[column] = convert_col(table[column], TorchColumn)
    model_input_dict = table.to_dict()

    traced_model = torch.jit.trace(model, model_input_dict, strict=True)

    input_schema = model.input_schema
    output_schema = model.output_schema

    workflow = Workflow.load(os.path.join(INPUT_DATA_DIR, "workflow_etl"))

    torch_op = workflow.input_schema.column_names >> TransformWorkflow(workflow) >> PredictPyTorch(
        traced_model, input_schema, output_schema
    )

    ensemble = Ensemble(torch_op, workflow.input_schema)
    ens_config, node_configs = ensemble.export(output_path)
    return


def train(output_path):
    data_preprocessing()
    model_training()
    model_ensemble(output_path)
    return


if __name__ == "__main__":
    args, _ = parse_args()
    train(args.model_dir)
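
The column names in the KeyError can be traced back to the Groupby step in train.py. The sketch below derives the names produced by the `aggs` dictionary and `name_sep="-"` used there and compares them with the names reported as missing. The helper `groupby_output_names` is hypothetical and only mirrors the naming pattern visible in the script (for example `item_id-list`); it is not NVTabular's implementation.

```python
def groupby_output_names(aggs, groupby_cols, name_sep="-"):
    """Return the groupby keys plus one '<column><sep><agg>' name per aggregation."""
    names = list(groupby_cols)
    for column, ops in aggs.items():
        names.extend(f"{column}{name_sep}{op}" for op in ops)
    return names


# Same aggregations as the Groupby op in data_preprocessing().
aggs = {
    "item_id": ["list", "count"],
    "category": ["list"],
    "day": ["first"],
    "age_days": ["list"],
    "weekday_sin": ["list"],
}

produced = groupby_output_names(aggs, groupby_cols=["session_id"])

# Names reported as "not in index" in the endpoint log.
reported_missing = {
    "weekday_sin-list", "category-list", "item_id-count",
    "age_days-list", "item_id-list", "day-first",
}

# Groupby outputs that were NOT reported as missing.
print(sorted(set(produced) - reported_missing))
# Reported names that the Groupby step does not produce.
print(sorted(reported_missing - set(produced)))
```

Under these assumptions the only Groupby output not reported as missing is `session_id`, which is also the only one whose name exists unchanged in the raw input. This is consistent with the serving workflow trying to select Groupby outputs directly from the incoming request columns.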
mvidela31 changed the title from "[QST] Unable to replicate the getting-started tutorial on AWS SageMaker" to "[QST] Unable to reproduce the getting-started tutorial on AWS SageMaker" on Dec 10, 2024
❓ Questions & Help
Hi everyone, I was trying to reproduce the Getting Started: Session-based Recommendation with Synthetic Data example on AWS SageMaker, following the official Training and Serving Merlin on AWS SageMaker tutorial (which uses a merlin-models model) but using a transformers4rec model instead.