-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update Example folder for Internal Demo (#469)
<!-- Thanks for sending a pull request! Here are some tips for you: 1. Run unit tests and ensure that they are passing 2. If your change introduces any API changes, make sure to update the e2e tests 3. Make sure documentation is updated for your PR! --> **What this PR does / why we need it**: <!-- Explain here the context and why you're making the change. What is the problem you're trying to solve. ---> - I have fixed some example that not working - Add Custom Model using fastapi python **Which issue(s) this PR fixes**: <!-- *Automatically closes linked issue when PR is merged. Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`. --> Fixes # **Does this PR introduce a user-facing change?**: <!-- If no, just write "NONE" in the release-note block below. If yes, a release note is required. Enter your extended release note in the block below. If the PR requires additional action from users switching to the new release, include the string "action required". For more information about release notes, see kubernetes' guide here: http://git.k8s.io/community/contributors/guide/release-notes.md --> ```release-note ``` **Checklist** - [ ] Added unit test, integration, and/or e2e tests - [ ] Tested locally - [ ] Updated documentation - [ ] Update Swagger spec if the PR introduce API changes - [ ] Regenerated Golang and Python client if the PR introduce API changes
- Loading branch information
1 parent
595b38a
commit 0a75202
Showing
11 changed files
with
458 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,301 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "88acfbf4", | ||
"metadata": {}, | ||
"source": [ | ||
"# Custom Model Sample" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1c7f43f1", | ||
"metadata": {}, | ||
"source": [ | ||
"## Requirements\n", | ||
"\n", | ||
"- Authenticated to gcloud (```gcloud auth application-default login```)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f6218f2a", | ||
"metadata": {}, | ||
"source": [ | ||
"This notebook demonstrate how to create and deploy custom model which using IRIS classifier based on xgboost model into Merlin." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "dcd7ae51", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import merlin\n", | ||
"import warnings\n", | ||
"import os\n", | ||
"import xgboost as xgb\n", | ||
"from merlin.model import ModelType\n", | ||
"from sklearn.datasets import load_iris\n", | ||
"warnings.filterwarnings('ignore')\n", | ||
"print(merlin.__version__)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "38fd8feb", | ||
"metadata": {}, | ||
"source": [ | ||
"## 1. Initialize Merlin Resources\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "46d2cc6b", | ||
"metadata": {}, | ||
"source": [ | ||
"### 1.1 Set Merlin Server" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e4e165b2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Set MLP Server\n", | ||
"MERLIN_SERVER_URL = os.environ.get(\"MERLIN_SERVER_URL\", \"localhost:8080/api/merlin\")\n", | ||
"merlin.set_url(MERLIN_SERVER_URL)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c826d621", | ||
"metadata": {}, | ||
"source": [ | ||
"### 1.2 Set Active Project\n", | ||
"\n", | ||
"`project` represent a project in real life. You may have multiple model within a project.\n", | ||
"\n", | ||
"`merlin.set_project(<project_name>)` will set the active project into the name matched by argument. You can only set it to an existing project. If you would like to create a new project, please do so from the MLP console at http://localhost:8080/projects/create." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "813438d9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"merlin.set_project(\"sample\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "546a926d", | ||
"metadata": {}, | ||
"source": [ | ||
"### 1.3 Set Active Model\n", | ||
"\n", | ||
"`model` represents an abstract ML model. Conceptually, `model` in Merlin is similar to a class in programming language. To instantiate a `model` you'll have to create a `model_version`.\n", | ||
"\n", | ||
"Each `model` has a type, currently model type supported by Merlin are: sklearn, xgboost, tensorflow, pytorch, and user defined model (i.e. pyfunc model).\n", | ||
"\n", | ||
"`model_version` represents a snapshot of particular `model` iteration. You'll be able to attach information such as metrics and tag to a given `model_version` as well as deploy it as a model service.\n", | ||
"\n", | ||
"`merlin.set_model(<model_name>, <model_type>)` will set the active model to the name given by parameter, if the model with given name is not found, a new model will be created." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d2325686", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"merlin.set_model(\"custom-model\", ModelType.CUSTOM)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3c50c1d0", | ||
"metadata": {}, | ||
"source": [ | ||
"## 2. Train and Deploy Images" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8a36d25c", | ||
"metadata": {}, | ||
"source": [ | ||
"### 2.1 Train Model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "33430c99", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model_dir = \"xgboost-model\"\n", | ||
"BST_FILE = \"model.bst\"\n", | ||
"\n", | ||
"iris = load_iris()\n", | ||
"y = iris['target']\n", | ||
"X = iris['data']\n", | ||
"dtrain = xgb.DMatrix(X, label=y)\n", | ||
"param = {'max_depth': 6,\n", | ||
" 'eta': 0.1,\n", | ||
" 'silent': 1,\n", | ||
" 'nthread': 4,\n", | ||
" 'num_class': 10,\n", | ||
" 'objective': 'multi:softmax'\n", | ||
" }\n", | ||
"xgb_model = xgb.train(params=param, dtrain=dtrain)\n", | ||
"model_file = os.path.join((model_dir), BST_FILE)\n", | ||
"xgb_model.save_model(model_file)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9bb2ae92", | ||
"metadata": {}, | ||
"source": [ | ||
"### 2.2 Create Model Version and Upload Model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "5f27678f", | ||
"metadata": {}, | ||
"source": [ | ||
"`merlin.new_model_version()` is a convenient method to create a model version and start its development process. It is equal to following codes:\n", | ||
"\n", | ||
"```\n", | ||
"v = model.new_model_version()\n", | ||
"v.start()\n", | ||
"v.log_custom_model(image=\"ghcr.io/gojek/custom-model:v0.3\",model_dir=model_dir)\n", | ||
"v.finish()\n", | ||
"```\n", | ||
"\n", | ||
"\n", | ||
"This image `afif2100/caraml-dev-merlin-sample-test-custom-model:latest` is built by using this [Dockerfile](./server/dockerfile). The image contains python fast-api web service executable where the code you can find [here](./server)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "63d9b1fb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create new version of the model\n", | ||
"with merlin.new_model_version() as v:\n", | ||
" # Upload the serialized model to Merlin\n", | ||
" merlin.log_custom_model(image=\"afif2100/caraml-dev-merlin-sample-test-custom-model:latest\", model_dir=model_dir)\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "5045593a", | ||
"metadata": {}, | ||
"source": [ | ||
"### 2.2 Deploy Model\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0107a735", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from merlin.protocol import Protocol\n", | ||
"\n", | ||
"endpoint = merlin.deploy(v, protocol = Protocol.HTTP_JSON)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "779c5a0c", | ||
"metadata": {}, | ||
"source": [ | ||
"### 2.3 Send Test Request" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c7ae3f9b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Test deployment\n", | ||
"import requests\n", | ||
"\n", | ||
"# Get endpoint\n", | ||
"data = {\"instances\": [[1,2,3,4], [2,1,2,4]]}\n", | ||
"\n", | ||
"# Send request\n", | ||
"response = requests.post(endpoint.url, json=data)\n", | ||
"print(response.json())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f167e8ba", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%%bash -s \"$endpoint.url\"\n", | ||
"curl -v -X POST $1 -H \"Content-Type: application/json\" -d '{\"instances\": [[1,2,3,4], [2,1,2,4]]}'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c14ce243", | ||
"metadata": {}, | ||
"source": [ | ||
"### 2.4 Delete Deployment" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "076dd1d4", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"merlin.undeploy(v)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "merlin-sdk", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.18" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
scikit-learn==1.1.2 | ||
xgboost==1.6.2 | ||
merlin-sdk |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import os | ||
from fastapi import FastAPI | ||
import xgboost as xgb | ||
|
||
app = FastAPI() | ||
|
||
# set global env | ||
MERLIN_MODEL_NAME = os.environ.get("MERLIN_MODEL_NAME", "my-model") | ||
MERLIN_PREDICTOR_PORT = int(os.environ.get("MERLIN_PREDICTOR_PORT", 8080)) | ||
MERLIN_ARTIFACT_LOCATION = os.environ.get("MERLIN_ARTIFACT_LOCATION", "model") | ||
|
||
# Set API endpoints | ||
API_HEALTH_ENDPOINT = "/" | ||
API_ENDPOINT = f"/v1/models/{MERLIN_MODEL_NAME}" | ||
API_ENDPOINT_PREDICT = f"{API_ENDPOINT}:predict" | ||
|
||
# Print Endpoint Info | ||
print("Starting API server") | ||
print(f"Starting API server: {API_ENDPOINT}") | ||
print(f"Starting API predict server: {API_ENDPOINT_PREDICT}") | ||
print(f"Artifact Location : {MERLIN_ARTIFACT_LOCATION}") | ||
|
||
# Create Prediction Class | ||
class XgbModel: | ||
def __init__(self): | ||
self.loaded = False | ||
self.model = xgb.Booster({'nthread': 4}) | ||
self.load_model(MERLIN_ARTIFACT_LOCATION) | ||
|
||
def load_model(self, model_path): | ||
model_file = os.path.join((model_path), 'model.bst') | ||
self.model.load_model(model_file) | ||
self.loaded = True | ||
|
||
def predict(self, request): | ||
data = request['instances'] | ||
dmatrix = xgb.DMatrix(data) | ||
predictions = self.model.predict(dmatrix) | ||
return {"response": predictions.tolist(), "status": "ok"} | ||
|
||
# Init Class | ||
prediction_model = XgbModel() | ||
|
||
|
||
# API Endpoints | ||
@app.get(API_HEALTH_ENDPOINT) | ||
async def root(): | ||
return {"message": "API Ready"} | ||
|
||
@app.get(API_ENDPOINT) | ||
async def predict_status(): | ||
if prediction_model.loaded: | ||
return {"message": "Model is loaded"} | ||
else: | ||
return {"message": "Model is not loaded"}, 503 | ||
|
||
@app.post(API_ENDPOINT_PREDICT) | ||
async def predict(request:dict): | ||
response = prediction_model.predict(request) | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#!/bin/sh | ||
|
||
# Get the value of the MERLIN_PREDICTOR_PORT environment variable | ||
# Get the value of the MERLIN_GUNICORN_WORKERS environment variable | ||
port="$MERLIN_PREDICTOR_PORT" | ||
workers="$WORKERS" | ||
|
||
# If the port environment variable is not set, use the default port 8080 | ||
# If the workers environment variable is not set, use the default worker number 1 | ||
|
||
if [ -z "$port" ]; then | ||
port="8080" | ||
fi | ||
|
||
if [ -z "$workers" ]; then | ||
workers="1" | ||
fi | ||
|
||
# Execute the Gunicorn command with the specified port and number of workers | ||
exec uvicorn app:app --host=0.0.0.0 --port=$port --workers=$workers --no-access-log |
Oops, something went wrong.