Skip to content

Commit

Permalink
fix sagemaker role name
Browse files Browse the repository at this point in the history
  • Loading branch information
gilad-shaham committed Mar 8, 2024
1 parent 80c5e1d commit 82937c3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/functions/data-preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def data_prepare(context):
sm_client = boto3.client("sagemaker")
boto_session = boto3.Session(region_name=region)
sagemaker_session = sagemaker.session.Session(boto_session=boto_session, sagemaker_client=sm_client)
role = os.environ["SAGEMAKER-ROLE"]
role = os.environ["SAGEMAKER_ROLE"]
bucket_prefix = "payment-classification"
s3_bucket = sagemaker_session.default_bucket()

Expand Down Expand Up @@ -191,7 +191,7 @@ def get_feature_store_values():
sagemaker_session = sagemaker.session.Session(
boto_session=boto_session, sagemaker_client=sm_client
)
role = context.get_secret("SAGEMAKER-ROLE")
role = context.get_secret("SAGEMAKER_ROLE")
bucket_prefix = "payment-classification"
s3_bucket = sagemaker_session.default_bucket()

Expand All @@ -216,4 +216,4 @@ def _set_envars(context):
os.environ["AWS_ACCESS_KEY_ID"] = context.get_secret("AWS_ACCESS_KEY_ID")
os.environ["AWS_SECRET_ACCESS_KEY"] = context.get_secret("AWS_SECRET_ACCESS_KEY")
os.environ["AWS_DEFAULT_REGION"] = context.get_secret("AWS_DEFAULT_REGION")
os.environ["SAGEMAKER-ROLE"] = context.get_secret("SAGEMAKER-ROLE")
os.environ["SAGEMAKER_ROLE"] = context.get_secret("SAGEMAKER_ROLE")
4 changes: 2 additions & 2 deletions src/functions/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def train(context):
sagemaker_session = sagemaker.session.Session(
boto_session=boto_session, sagemaker_client=sm_client
)
role = context.get_secret("SAGEMAKER-ROLE")
role = context.get_secret("SAGEMAKER_ROLE")
bucket_prefix = "payment-classification"
s3_bucket = sagemaker_session.default_bucket()

Expand Down Expand Up @@ -74,4 +74,4 @@ def _set_envars(context):
os.environ["AWS_ACCESS_KEY_ID"] = context.get_secret("AWS_ACCESS_KEY_ID")
os.environ["AWS_SECRET_ACCESS_KEY"] = context.get_secret("AWS_SECRET_ACCESS_KEY")
os.environ["AWS_DEFAULT_REGION"] = context.get_secret("AWS_DEFAULT_REGION")
os.environ["SAGEMAKER-ROLE"] = context.get_secret("SAGEMAKER-ROLE")
os.environ["SAGEMAKER_ROLE"] = context.get_secret("SAGEMAKER_ROLE")

0 comments on commit 82937c3

Please sign in to comment.