Skip to content

Commit

Permalink
updating train to include sagemaker feature store code
Browse files Browse the repository at this point in the history
  • Loading branch information
aviaIguazio committed Feb 13, 2024
1 parent b247cdf commit 308b816
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 34 deletions.
64 changes: 32 additions & 32 deletions financial-payment-pipeline.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "c447c260-b243-4f62-8a48-9dd07091282d",
"metadata": {
"editable": true,
Expand All @@ -26,7 +26,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "e34f8c80-6584-4e80-981c-0f17e1584ebf",
"metadata": {
"tags": []
Expand All @@ -36,7 +36,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"> 2024-02-12 13:44:28,481 [info] Project loaded successfully: {'project_name': 'sagemaker'}\n"
"> 2024-02-13 10:14:19,611 [info] Project loaded successfully: {'project_name': 'sagemaker'}\n"
]
}
],
Expand All @@ -61,7 +61,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "9fb9bf8f-dfc5-4b9f-8d63-d423ae326630",
"metadata": {
"tags": []
Expand All @@ -70,7 +70,7 @@
{
"data": {
"text/html": [
"<div>Pipeline running (id=eb48cc6e-d6ae-4a2d-947f-600a6e4cd469), <a href=\"https://dashboard.default-tenant.app.cust-cs-il-353.iguazio-cd2.com/mlprojects/sagemaker-admin/jobs/monitor-workflows/workflow/eb48cc6e-d6ae-4a2d-947f-600a6e4cd469\" target=\"_blank\"><b>click here</b></a> to view the details in MLRun UI</div>"
"<div>Pipeline running (id=24e028c0-cc0d-45c3-bdc7-ae29c02d1ee1), <a href=\"https://dashboard.default-tenant.app.cust-cs-il-353.iguazio-cd2.com/mlprojects/sagemaker-admin/jobs/monitor-workflows/workflow/24e028c0-cc0d-45c3-bdc7-ae29c02d1ee1\" target=\"_blank\"><b>click here</b></a> to view the details in MLRun UI</div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand All @@ -93,44 +93,44 @@
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 112)\">\n",
"<title>kfp</title>\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-112 225.82,-112 225.82,4 -4,4\"/>\n",
"<!-- fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;2562321235 -->\n",
"<!-- fraud&#45;detection&#45;pipeline&#45;47fhj&#45;3902555814 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;2562321235</title>\n",
"<title>fraud&#45;detection&#45;pipeline&#45;47fhj&#45;3902555814</title>\n",
"<ellipse fill=\"green\" stroke=\"black\" cx=\"115.25\" cy=\"-90\" rx=\"33.1\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"115.25\" y=\"-84.58\" font-family=\"Times,serif\" font-size=\"14.00\">train</text>\n",
"</g>\n",
"<!-- fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;650189257 -->\n",
"<!-- fraud&#45;detection&#45;pipeline&#45;47fhj&#45;973894202 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;650189257</title>\n",
"<title>fraud&#45;detection&#45;pipeline&#45;47fhj&#45;973894202</title>\n",
"<polygon fill=\"green\" stroke=\"black\" points=\"110.5,-36 4,-36 0,-32 0,0 106.5,0 110.5,-4 110.5,-36\"/>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"106.5,-32 0,-32\"/>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"106.5,-32 106.5,0\"/>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"106.5,-32 110.5,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"55.25\" y=\"-12.57\" font-family=\"Times,serif\" font-size=\"14.00\">deploy&#45;serving</text>\n",
"</g>\n",
"<!-- fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;2562321235&#45;&gt;fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;650189257 -->\n",
"<!-- fraud&#45;detection&#45;pipeline&#45;47fhj&#45;3902555814&#45;&gt;fraud&#45;detection&#45;pipeline&#45;47fhj&#45;973894202 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;2562321235&#45;&gt;fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;650189257</title>\n",
"<title>fraud&#45;detection&#45;pipeline&#45;47fhj&#45;3902555814&#45;&gt;fraud&#45;detection&#45;pipeline&#45;47fhj&#45;973894202</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M101.64,-73.12C94.36,-64.63 85.23,-53.98 77,-44.38\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"79.24,-42.61 70.07,-37.29 73.92,-47.16 79.24,-42.61\"/>\n",
"</g>\n",
"<!-- fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;2902630996 -->\n",
"<!-- fraud&#45;detection&#45;pipeline&#45;47fhj&#45;4183868495 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;2902630996</title>\n",
"<title>fraud&#45;detection&#45;pipeline&#45;47fhj&#45;4183868495</title>\n",
"<ellipse fill=\"green\" stroke=\"black\" cx=\"175.25\" cy=\"-18\" rx=\"46.57\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"175.25\" y=\"-12.57\" font-family=\"Times,serif\" font-size=\"14.00\">evaluate</text>\n",
"</g>\n",
"<!-- fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;2562321235&#45;&gt;fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;2902630996 -->\n",
"<!-- fraud&#45;detection&#45;pipeline&#45;47fhj&#45;3902555814&#45;&gt;fraud&#45;detection&#45;pipeline&#45;47fhj&#45;4183868495 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;2562321235&#45;&gt;fraud&#45;detection&#45;pipeline&#45;cpzwr&#45;2902630996</title>\n",
"<title>fraud&#45;detection&#45;pipeline&#45;47fhj&#45;3902555814&#45;&gt;fraud&#45;detection&#45;pipeline&#45;47fhj&#45;4183868495</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M128.86,-73.12C136.35,-64.38 145.8,-53.35 154.22,-43.54\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"157.42,-46.18 161.27,-36.31 152.11,-41.62 157.42,-46.18\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7fee50abd5b0>"
"<graphviz.graphs.Digraph at 0x7fa4bd1f8d30>"
]
},
"metadata": {},
Expand All @@ -139,7 +139,7 @@
{
"data": {
"text/html": [
"<h2>Run Results</h2><h3>[info] Workflow eb48cc6e-d6ae-4a2d-947f-600a6e4cd469 finished, state=Succeeded</h3><br>click the hyper links below to see detailed results<br><table border=\"1\" class=\"dataframe\">\n",
"<h2>Run Results</h2><h3>[info] Workflow 24e028c0-cc0d-45c3-bdc7-ae29c02d1ee1 finished, state=Succeeded</h3><br>click the hyper links below to see detailed results<br><table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th>uid</th>\n",
Expand All @@ -152,16 +152,16 @@
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td><div title=\"e60c7c8f7bc64a5b8927c99d6a004109\"><a href=\"https://dashboard.default-tenant.app.cust-cs-il-353.iguazio-cd2.com/mlprojects/sagemaker-admin/jobs/monitor/e60c7c8f7bc64a5b8927c99d6a004109/overview\" target=\"_blank\" >...6a004109</a></div></td>\n",
" <td>Feb 12 13:51:31</td>\n",
" <td><div title=\"73ae4a820346453f83a5641e2d51f117\"><a href=\"https://dashboard.default-tenant.app.cust-cs-il-353.iguazio-cd2.com/mlprojects/sagemaker-admin/jobs/monitor/73ae4a820346453f83a5641e2d51f117/overview\" target=\"_blank\" >...2d51f117</a></div></td>\n",
" <td>Feb 13 10:21:33</td>\n",
" <td>completed</td>\n",
" <td>evaluate</td>\n",
" <td><div class=\"dictlist\">model_path=store://artifacts/sagemaker-admin/train_model_path@eb48cc6e-d6ae-4a2d-947f-600a6e4cd469</div><div class=\"dictlist\">model_name=xgboost-model</div><div class=\"dictlist\">label_column=transaction_category</div></td>\n",
" <td><div class=\"dictlist\">model_path=store://artifacts/sagemaker-admin/train_model_path@24e028c0-cc0d-45c3-bdc7-ae29c02d1ee1</div><div class=\"dictlist\">model_name=xgboost-model</div><div class=\"dictlist\">label_column=transaction_category</div></td>\n",
" <td></td>\n",
" </tr>\n",
" <tr>\n",
" <td><div title=\"23e9fbf5907845e2be991bafd67ee105\"><a href=\"https://dashboard.default-tenant.app.cust-cs-il-353.iguazio-cd2.com/mlprojects/sagemaker-admin/jobs/monitor/23e9fbf5907845e2be991bafd67ee105/overview\" target=\"_blank\" >...d67ee105</a></div></td>\n",
" <td>Feb 12 13:44:38</td>\n",
" <td><div title=\"abb418c2fed745028afff6d0eb4ea9d0\"><a href=\"https://dashboard.default-tenant.app.cust-cs-il-353.iguazio-cd2.com/mlprojects/sagemaker-admin/jobs/monitor/abb418c2fed745028afff6d0eb4ea9d0/overview\" target=\"_blank\" >...eb4ea9d0</a></div></td>\n",
" <td>Feb 13 10:14:27</td>\n",
" <td>completed</td>\n",
" <td>train</td>\n",
" <td></td>\n",
Expand All @@ -180,10 +180,10 @@
{
"data": {
"text/plain": [
"eb48cc6e-d6ae-4a2d-947f-600a6e4cd469"
"24e028c0-cc0d-45c3-bdc7-ae29c02d1ee1"
]
},
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -205,7 +205,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "1421e3e9-dad2-4983-88d7-c9d48cb49fb2",
"metadata": {},
"outputs": [],
Expand All @@ -216,7 +216,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"id": "fff74774-7422-4c8f-af9f-e39ee2505f08",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -255,27 +255,27 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"id": "5716bca9-ac11-44cf-b9da-b895dba9055f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> 2024-01-30 09:00:00,436 [info] invoking function: {'method': 'POST', 'path': 'http://nuclio-sagemaker-yoni-serving.default-tenant.svc.cluster.local:8080/predict'}\n"
"> 2024-02-13 10:23:44,657 [info] Invoking function: {'method': 'POST', 'path': 'http://sagemaker-admin-serving-sagemaker-admin.default-tenant.app.cust-cs-il-353.iguazio-cd2.com//predict'}\n"
]
},
{
"data": {
"text/plain": [
"{'id': 'cce12b91-6890-4de4-a584-0b23aa27aaac',\n",
"{'id': 'f0d430c4-79e8-426a-bd15-9acc0941bf84',\n",
" 'model_name': 'xgboost-model',\n",
" 'predictions': [1],\n",
" 'confidences': [0.43330907821655273]}"
" 'predictions': [17],\n",
" 'confidences': [0.3079691231250763]}"
]
},
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand Down
166 changes: 164 additions & 2 deletions src/functions/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,176 @@
import mlrun.feature_store as fs
import numpy as np
import sagemaker
import pandas as pd
from sagemaker.feature_store.feature_group import FeatureGroup
import time




def train(context):
# Set AWS environment variables:
_set_envars(context)

# Get data from feature-store:
data = _get_feature_store_data(context)

region = sagemaker.Session().boto_region_name
sm_client = boto3.client("sagemaker")
boto_session = boto3.Session(region_name=region)
sagemaker_session = sagemaker.session.Session(boto_session=boto_session, sagemaker_client=sm_client)
role = os.environ["SAGEMAKER-ROLE"]
bucket_prefix = "payment-classification"
s3_bucket = sagemaker_session.default_bucket()

factorize_key = {
"Uncategorized": 0,
"Entertainment": 1,
"Education": 2,
"Shopping": 3,
"Personal Care": 4,
"Health and Fitness": 5,
"Food and Dining": 6,
"Gifts and Donations": 7,
"Investments": 8,
"Bills and Utilities": 9,
"Auto and Transport": 10,
"Travel": 11,
"Fees and Charges": 12,
"Business Services": 13,
"Personal Services": 14,
"Taxes": 15,
"Gambling": 16,
"Home": 17,
"Pension and insurances": 18,
}

factorize_key = {key: str(value) for key, value in factorize_key.items()}

s3 = boto3.client("s3")
s3.download_file(
f"sagemaker-example-files-prod-{region}",
"datasets/tabular/synthetic_financial/financial_transactions_mini.csv",
"financial_transactions_mini.csv",
)

data = pd.read_csv(
"financial_transactions_mini.csv",
parse_dates=["timestamp"],
infer_datetime_format=True,
dtype={"transaction_category": "string"},
)

data["year"] = data["timestamp"].dt.year
data["month"] = data["timestamp"].dt.month
data["day"] = data["timestamp"].dt.day
data["hour"] = data["timestamp"].dt.hour
data["minute"] = data["timestamp"].dt.minute
data["second"] = data["timestamp"].dt.second

del data["timestamp"]

data["transaction_category"] = data["transaction_category"].replace(factorize_key)

feature_group_name = "feature-group-payment-classification"
record_identifier_feature_name = "identifier"

feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session)

featurestore_runtime = boto_session.client(
service_name="sagemaker-featurestore-runtime", region_name=region
)

feature_store_session = sagemaker.Session(
boto_session=boto_session,
sagemaker_client=sm_client,
sagemaker_featurestore_runtime_client=featurestore_runtime,
)

columns = ["mean_amount", "count", "identifier", "EventTime"]
feature_store_data = pd.DataFrame(columns=columns, dtype=object)

feature_store_data["identifier"] = range(19)
feature_store_data["mean_amount"] = 0.0
feature_store_data["count"] = 1
feature_store_data["EventTime"] = time.time()

feature_group.load_feature_definitions(data_frame=feature_store_data)

status = feature_group.describe().get("FeatureGroupStatus")

if status!='Created':
feature_group.create(
s3_uri=f"s3://{s3_bucket}/{bucket_prefix}",
record_identifier_name=record_identifier_feature_name,
event_time_feature_name="EventTime",
role_arn=role,
enable_online_store=True,
)

status = feature_group.describe().get("FeatureGroupStatus")
while status == "Creating":
print("Waiting for Feature Group to be Created")
time.sleep(5)
status = feature_group.describe().get("FeatureGroupStatus")
print(f"FeatureGroup {feature_group.name} successfully created.")

feature_group.ingest(data_frame=feature_store_data, max_workers=3, wait=True)

def get_feature_store_values():
response = featurestore_runtime.batch_get_record(
Identifiers=[
{
"FeatureGroupName": feature_group_name,
"RecordIdentifiersValueAsString": [str(i) for i in range(19)],
}
]
)

columns = ["mean_amount", "count", "identifier", "EventTime"]

feature_store_resp = pd.DataFrame(
data=[
[resp["Record"][i]["ValueAsString"] for i in range(len(columns))]
for resp in response["Records"]
],
columns=columns,
)
feature_store_resp["identifier"] = feature_store_resp["identifier"].astype(int)
feature_store_resp["count"] = feature_store_resp["count"].astype(int)
feature_store_resp["mean_amount"] = feature_store_resp["mean_amount"].astype(float)
feature_store_resp["EventTime"] = feature_store_resp["EventTime"].astype(float)
feature_store_resp = feature_store_resp.sort_values(by="identifier")

return feature_store_resp

feature_store_resp = get_feature_store_values()

feature_store_data = pd.DataFrame()
feature_store_data["mean_amount"] = data.groupby(["transaction_category"]).mean()["amount"]
feature_store_data["count"] = data.groupby(["transaction_category"]).count()["amount"]
feature_store_data["identifier"] = feature_store_data.index
feature_store_data["EventTime"] = time.time()

feature_store_data["mean_amount"] = (
pd.concat([feature_store_resp, feature_store_data])
.groupby("identifier")
.apply(lambda x: np.average(x["mean_amount"], weights=x["count"]))
)
feature_store_data["count"] = (
pd.concat([feature_store_resp, feature_store_data]).groupby("identifier").sum()["count"]
)

feature_group.ingest(data_frame=feature_store_data, max_workers=3, wait=True)

feature_store_data = get_feature_store_values()

additional_features = pd.pivot_table(
feature_store_data, values=["mean_amount"], index=["identifier"]
).T.add_suffix("_dist")
additional_features_columns = list(additional_features.columns)
data = pd.concat([data, pd.DataFrame(columns=additional_features_columns, dtype=object)])
data[additional_features_columns] = additional_features.values[0]
for col in additional_features_columns:
data[col] = abs(data[col] - data["amount"])

# Randomly sort the data then split out first 70%, second 20%, and last 10%
train_data, validation_data, test_data = np.split(
Expand Down

0 comments on commit 308b816

Please sign in to comment.