From ebce098f83b114994160acff51fad7e6a4a64c07 Mon Sep 17 00:00:00 2001 From: Tio Pramayudi Date: Wed, 24 Jan 2024 17:31:35 +0700 Subject: [PATCH 1/5] Fix version patch and post --- api/api/versions_api.go | 1 + api/api/versions_api_test.go | 316 +++++++++++++++++++++++++++ api/models/version.go | 10 +- python/sdk/client/api_client.py | 2 + python/sdk/merlin/model.py | 4 + python/sdk/merlin/model_schema.py | 9 +- python/sdk/test/integration_test.py | 39 +++- python/sdk/test/model_schema_test.py | 9 +- 8 files changed, 380 insertions(+), 10 deletions(-) diff --git a/api/api/versions_api.go b/api/api/versions_api.go index 08e507930..0777312d0 100644 --- a/api/api/versions_api.go +++ b/api/api/versions_api.go @@ -137,6 +137,7 @@ func (c *VersionsController) CreateVersion(r *http.Request, vars map[string]stri ArtifactURI: run.Info.ArtifactURI, Labels: versionPost.Labels, PythonVersion: versionPost.PythonVersion, + ModelSchema: versionPost.ModelSchema, } version, _ = c.VersionsService.Save(ctx, version, c.FeatureToggleConfig.MonitoringConfig) diff --git a/api/api/versions_api_test.go b/api/api/versions_api_test.go index e8938f088..4e36f016f 100644 --- a/api/api/versions_api_test.go +++ b/api/api/versions_api_test.go @@ -708,6 +708,170 @@ func TestPatchVersion(t *testing.T) { data: Error{Message: "Error patching model version: Error creating secret: db is down"}, }, }, + { + desc: "Should success update model schema", + vars: map[string]string{ + "model_id": "1", + "version_id": "1", + }, + requestBody: &models.VersionPatch{ + Properties: &models.KV{ + "name": "model-1", + "created_by": "anonymous", + }, + ModelSchema: &models.ModelSchema{ + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + }, + ModelID: models.ID(1), + }, + }, + versionService: func() *mocks.VersionsService { + svc := &mocks.VersionsService{} + svc.On("FindByID", mock.Anything, models.ID(1), models.ID(1), mock.Anything).Return( + &models.Version{ + ID: models.ID(1), + ModelID: models.ID(1), + Model: &models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "http://mlflow.com", + }, + MlflowURL: "http://mlflow.com", + }, nil) + svc.On("Save", mock.Anything, &models.Version{ + ID: models.ID(1), + ModelID: models.ID(1), + Model: &models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "http://mlflow.com", + }, + MlflowURL: "http://mlflow.com", + Properties: models.KV{ + "name": "model-1", + "created_by": "anonymous", + }, + ModelSchema: &models.ModelSchema{ + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + }, + ModelID: models.ID(1), + }, + }, mock.Anything).Return(&models.Version{ + ID: models.ID(1), + ModelID: models.ID(1), + Model: &models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "http://mlflow.com", + }, + MlflowURL: "http://mlflow.com", + Properties: models.KV{ + "name": "model-1", + "created_by": "anonymous", + }, + ModelSchema: &models.ModelSchema{ + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + }, + ModelID: models.ID(1), + }, + }, nil) + return svc + }, + expected: &Response{ + code: http.StatusOK, + data: &models.Version{ + ID: models.ID(1), + ModelID: models.ID(1), + Model: &models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "http://mlflow.com", + }, + MlflowURL: "http://mlflow.com", + Properties: models.KV{ + "name": "model-1", + "created_by": "anonymous", + }, + ModelSchema: &models.ModelSchema{ + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + }, + ModelID: models.ID(1), + }, + }, + }, + }, } for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { @@ -1155,6 +1319,158 @@ func TestCreateVersion(t *testing.T) { }, }, }, + { + desc: "Should successfully create version with model schema", + vars: map[string]string{ + "model_id": "1", + }, + body: models.VersionPost{ + ModelSchema: &models.ModelSchema{ + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + }, + ModelID: models.ID(1), + }, + }, + modelsService: func() *mocks.ModelsService { + svc := &mocks.ModelsService{} + svc.On("FindByID", mock.Anything, models.ID(1)).Return(&models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{ + MLFlowTrackingURL: "http://www.notinuse.com", + }, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "http://mlflow.com", + Endpoints: nil, + }, nil) + return svc + }, + mlflowClient: func() *mlfmocks.Client { + svc := &mlfmocks.Client{} + svc.On("CreateRun", "1").Return(&mlflow.Run{ + Info: mlflow.Info{ + RunID: "1", + ArtifactURI: "artifact/url/run", + }, + }, nil) + return svc + }, + versionService: func() *mocks.VersionsService { + svc := &mocks.VersionsService{} + svc.On("Save", mock.Anything, &models.Version{ + ModelID: models.ID(1), + RunID: "1", + ArtifactURI: "artifact/url/run", + PythonVersion: DEFAULT_PYTHON_VERSION, + ModelSchema: &models.ModelSchema{ + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + }, + ModelID: models.ID(1), + }, + }, mock.Anything).Return(&models.Version{ + ID: models.ID(1), + ModelID: models.ID(1), + Model: &models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "sklearn", + MlflowURL: "http://mlflow.com", + }, + MlflowURL: "http://mlflow.com", + PythonVersion: DEFAULT_PYTHON_VERSION, + ModelSchema: &models.ModelSchema{ + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + }, + ModelID: models.ID(1), + }, + }, nil) + return svc + }, + expected: &Response{ + code: http.StatusCreated, + data: &models.Version{ + ID: models.ID(1), + ModelID: models.ID(1), + Model: &models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "sklearn", + MlflowURL: "http://mlflow.com", + }, + MlflowURL: "http://mlflow.com", + PythonVersion: DEFAULT_PYTHON_VERSION, + ModelSchema: &models.ModelSchema{ + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + }, + ModelID: models.ID(1), + }, + }, + }, + }, } for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { diff --git a/api/models/version.go b/api/models/version.go index 4336c74ed..1bba9ff33 100644 --- a/api/models/version.go +++ b/api/models/version.go @@ -40,13 +40,15 @@ type Version struct { } type VersionPost struct { - Labels KV `json:"labels" gorm:"labels"` - PythonVersion string `json:"python_version" gorm:"python_version"` + Labels KV `json:"labels" gorm:"labels"` + PythonVersion string `json:"python_version" gorm:"python_version"` + ModelSchema *ModelSchema `json:"model_schema"` } type VersionPatch struct { Properties *KV `json:"properties,omitempty"` CustomPredictor *CustomPredictor `json:"custom_predictor,omitempty"` + ModelSchema *ModelSchema `json:"model_schema"` } type CustomPredictor struct { @@ -100,6 +102,10 @@ func (v *Version) Patch(patch *VersionPatch) error { } v.CustomPredictor = patch.CustomPredictor } + if patch.ModelSchema != nil { + v.ModelSchema = patch.ModelSchema + } + return nil } diff --git a/python/sdk/client/api_client.py b/python/sdk/client/api_client.py index 5bf60ec80..27e801662 100644 --- a/python/sdk/client/api_client.py +++ b/python/sdk/client/api_client.py @@ -296,6 +296,7 @@ def response_deserialize( # if not found, look for '1XX', '2XX', etc. response_type = response_types_map.get(str(response_data.status)[0] + "XX", None) + print(f"response status ----- {response_data.status}") if not 200 <= response_data.status <= 299: if response_data.status == 400: raise BadRequestException(http_resp=response_data) @@ -328,6 +329,7 @@ def response_deserialize( match = re.search(r"charset=([a-zA-Z\-\d]+)[\s;]?", content_type) encoding = match.group(1) if match else "utf-8" response_text = response_data.data.decode(encoding) + print(f"response_text ------ {response_text}") return_data = self.deserialize(response_text, response_type) return ApiResponse( diff --git a/python/sdk/merlin/model.py b/python/sdk/merlin/model.py index 7a9924743..a3c6d3c09 100644 --- a/python/sdk/merlin/model.py +++ b/python/sdk/merlin/model.py @@ -783,6 +783,10 @@ def url(self) -> str: model_id = self.model.id base_url = guess_mlp_ui_url(self.model.project.url) return f"{base_url}/projects/{project_id}/models/{model_id}/versions" + + @property + def model_schema(self) -> Optional[ModelSchema]: + return self._model_schema def start(self): """ diff --git a/python/sdk/merlin/model_schema.py b/python/sdk/merlin/model_schema.py index 376d18c2b..d5a9fbf86 100644 --- a/python/sdk/merlin/model_schema.py +++ b/python/sdk/merlin/model_schema.py @@ -81,18 +81,21 @@ def _to_client_prediction_output_spec(self) -> client.ModelPredictionOutput: actual_label_column=prediction_output.actual_label_column, positive_class_label=prediction_output.positive_class_label, negative_class_label=prediction_output.negative_class_label, - score_threshold=prediction_output.score_threshold + score_threshold=prediction_output.score_threshold, + output_class=client.ModelPredictionOutputClass(BinaryClassificationOutput.__name__) )) elif isinstance(prediction_output, RegressionOutput): return client.ModelPredictionOutput(client.RegressionOutput( actual_score_column=prediction_output.actual_score_column, - prediction_score_column=prediction_output.prediction_score_column + prediction_score_column=prediction_output.prediction_score_column, + output_class=client.ModelPredictionOutputClass(RegressionOutput.__name__) )) elif isinstance(prediction_output, RankingOutput): return client.ModelPredictionOutput(client.RankingOutput( relevance_score_column=prediction_output.relevance_score_column, prediction_group_id_column=prediction_output.prediction_group_id_column, - rank_score_column=prediction_output.rank_score_column + rank_score_column=prediction_output.rank_score_column, + output_class=client.ModelPredictionOutputClass(RankingOutput.__name__) )) raise ValueError("model prediction output is not recognized") diff --git a/python/sdk/test/integration_test.py b/python/sdk/test/integration_test.py index 44206c2cd..bac19c8f3 100644 --- a/python/sdk/test/integration_test.py +++ b/python/sdk/test/integration_test.py @@ -128,7 +128,37 @@ def test_xgboost( undeploy_all_version() - with merlin.new_model_version(model_schema=ModelSchema(spec=InferenceSchema( + with merlin.new_model_version() as v: + # Upload the serialized model to MLP + merlin.log_model(model_dir=model_dir) + + endpoint = merlin.deploy(v, deployment_mode=deployment_mode) + resp = requests.post(f"{endpoint.url}", json=request_json) + + assert resp.status_code == 200 + assert resp.json() is not None + assert len(resp.json()["predictions"]) == len(request_json["instances"]) + + merlin.undeploy(v) + +@pytest.mark.integration +@pytest.mark.dependency() +@pytest.mark.parametrize( + "deployment_mode", [DeploymentMode.RAW_DEPLOYMENT, DeploymentMode.SERVERLESS] +) +def test_model_schema( + integration_test_url, project_name, deployment_mode, use_google_oauth, requests +): + merlin.set_url(integration_test_url, use_google_oauth=use_google_oauth) + merlin.set_project(project_name) + merlin.set_model( + f"model-schema-{deployment_mode_suffix(deployment_mode)}", ModelType.XGBOOST + ) + + model_dir = "test/xgboost-model" + + undeploy_all_version() + model_schema = ModelSchema(spec=InferenceSchema( feature_types={ "featureA": ValueType.FLOAT64, "featureB": ValueType.INT64, @@ -143,10 +173,14 @@ def test_xgboost( negative_class_label="non_complete", score_threshold=0.7 ) - ))) as v: + )) + + with merlin.new_model_version(model_schema=model_schema) as v: # Upload the serialized model to MLP merlin.log_model(model_dir=model_dir) + assert v.model_schema == model_schema + endpoint = merlin.deploy(v, deployment_mode=deployment_mode) resp = requests.post(f"{endpoint.url}", json=request_json) @@ -157,6 +191,7 @@ def test_xgboost( merlin.undeploy(v) + @pytest.mark.integration def test_mlflow_tracking( integration_test_url, project_name, use_google_oauth, requests diff --git a/python/sdk/test/model_schema_test.py b/python/sdk/test/model_schema_test.py index 47345d484..500745e62 100644 --- a/python/sdk/test/model_schema_test.py +++ b/python/sdk/test/model_schema_test.py @@ -26,7 +26,8 @@ actual_label_column="actual_label", positive_class_label="positive", negative_class_label="negative", - score_threshold=0.5 + score_threshold=0.5, + output_class=client.ModelPredictionOutputClass.BINARYCLASSIFICATIONOUTPUT ) ) ) @@ -70,7 +71,8 @@ model_prediction_output=client.ModelPredictionOutput( client.RegressionOutput( prediction_score_column="prediction_score", - actual_score_column="actual_score" + actual_score_column="actual_score", + output_class=client.ModelPredictionOutputClass.REGRESSIONOUTPUT ) ) ) @@ -112,7 +114,8 @@ client.RankingOutput( rank_score_column="score", prediction_group_id_column="session_id", - relevance_score_column="relevance_score" + relevance_score_column="relevance_score", + output_class=client.ModelPredictionOutputClass.RANKINGOUTPUT ) ) ) From 71dbb0976a6b58738295591c4bcce7b4f409f64e Mon Sep 17 00:00:00 2001 From: Tio Pramayudi Date: Thu, 25 Jan 2024 14:46:24 +0700 Subject: [PATCH 2/5] Add more validation for version create and patch also exposes model_id as payload field --- api/api/model_schema_api.go | 5 + api/api/model_schema_api_test.go | 124 +++++++++++++++++++++ api/api/versions_api.go | 13 ++- api/api/versions_api_test.go | 183 ++++++++++++++++++++++++++++++- api/models/model_schema.go | 2 +- api/models/version.go | 22 +++- python/sdk/client/api_client.py | 2 - 7 files changed, 338 insertions(+), 13 deletions(-) diff --git a/api/api/model_schema_api.go b/api/api/model_schema_api.go index 53dcf5081..800295414 100644 --- a/api/api/model_schema_api.go +++ b/api/api/model_schema_api.go @@ -49,6 +49,11 @@ func (m *ModelSchemaController) CreateOrUpdateSchema(r *http.Request, vars map[s if !ok { return BadRequest("Unable to parse request body") } + + if modelSchema.ModelID > 0 && modelSchema.ModelID != modelID { + return BadRequest("Mismatch model id between request path and body") + } + modelSchema.ModelID = modelID schema, err := m.ModelSchemaService.Save(ctx, modelSchema) if err != nil { diff --git a/api/api/model_schema_api_test.go b/api/api/model_schema_api_test.go index e5f883107..9bda33e57 100644 --- a/api/api/model_schema_api_test.go +++ b/api/api/model_schema_api_test.go @@ -370,6 +370,98 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) { }, }, }, + { + desc: "success create ranking schema, with specifying model id", + vars: map[string]string{ + "model_id": "1", + }, + body: []byte(`{ + "spec": { + "prediction_id_column":"prediction_id", + "tag_columns": ["tags"], + "feature_types": { + "featureA": "float64", + "featureB": "int64", + "featureC": "boolean" + }, + "model_prediction_output": { + "prediction_group_id_column": "session_id", + "rank_score_column": "score", + "relevance_score": "relevance_score", + "output_class": "RankingOutput" + } + }, + "model_id": 1 + }`), + modelSchemaService: func() *mocks.ModelSchemaService { + mockSvc := &mocks.ModelSchemaService{} + mockSvc.On("Save", mock.Anything, &models.ModelSchema{ + ModelID: models.ID(1), + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + TagColumns: []string{"tags"}, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + }, + }).Return(&models.ModelSchema{ + ID: models.ID(1), + ModelID: models.ID(1), + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + TagColumns: []string{"tags"}, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + }, + }, nil) + return mockSvc + }, + expected: &Response{ + code: http.StatusOK, + data: &models.ModelSchema{ + ID: models.ID(1), + ModelID: models.ID(1), + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + TagColumns: []string{"tags"}, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + }, + }, + }, + }, { desc: "success create binary classification schema", vars: map[string]string{ @@ -607,6 +699,38 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) { data: Error{Message: "Error save model schema: peer connection is reset"}, }, }, + { + desc: "model id mismatch", + vars: map[string]string{ + "model_id": "1", + }, + body: []byte(`{ + "spec": { + "prediction_id_column":"prediction_id", + "tag_columns": ["tags"], + "feature_types": { + "featureA": "float64", + "featureB": "int64", + "featureC": "boolean" + }, + "model_prediction_output": { + "prediction_group_id_column": "session_id", + "rank_score_column": "score", + "relevance_score": "relevance_score", + "output_class": "RankingOutput" + } + }, + "model_id": 2 + }`), + modelSchemaService: func() *mocks.ModelSchemaService { + mockSvc := &mocks.ModelSchemaService{} + return mockSvc + }, + expected: &Response{ + code: http.StatusBadRequest, + data: Error{Message: "Mismatch model id between request path and body"}, + }, + }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { diff --git a/api/api/versions_api.go b/api/api/versions_api.go index 0777312d0..ebca9ef1d 100644 --- a/api/api/versions_api.go +++ b/api/api/versions_api.go @@ -69,7 +69,8 @@ func (c *VersionsController) PatchVersion(r *http.Request, vars map[string]strin return InternalServerError("Unable to parse request body") } - if err := v.Patch(versionPatch); err != nil { + v.Patch(versionPatch) + if err := v.Validate(); err != nil { return BadRequest(fmt.Sprintf("Error validating version: %v", err)) } @@ -105,7 +106,6 @@ func (c *VersionsController) ListVersions(r *http.Request, vars map[string]strin func (c *VersionsController) CreateVersion(r *http.Request, vars map[string]string, body interface{}) *Response { ctx := r.Context() - versionPost, ok := body.(*models.VersionPost) if !ok { return BadRequest("Unable to parse request body") @@ -140,7 +140,14 @@ func (c *VersionsController) CreateVersion(r *http.Request, vars map[string]stri ModelSchema: versionPost.ModelSchema, } - version, _ = c.VersionsService.Save(ctx, version, c.FeatureToggleConfig.MonitoringConfig) + if err := version.Validate(); err != nil { + return BadRequest(fmt.Sprintf("Error validating version: %v", err)) + } + + version, err = c.VersionsService.Save(ctx, version, c.FeatureToggleConfig.MonitoringConfig) + if err != nil { + return InternalServerError(fmt.Sprintf("Failed to save version: %v", err)) + } return Created(version) } diff --git a/api/api/versions_api_test.go b/api/api/versions_api_test.go index 4e36f016f..29f49b305 100644 --- a/api/api/versions_api_test.go +++ b/api/api/versions_api_test.go @@ -624,7 +624,7 @@ func TestPatchVersion(t *testing.T) { }, }, { - desc: "Should return 500 if request body is not valud", + desc: "Should return 500 if request body is not valid", vars: map[string]string{ "model_id": "1", "version_id": "1", @@ -872,6 +872,63 @@ func TestPatchVersion(t *testing.T) { }, }, }, + { + desc: "Should fail update model schema when there is mismatch of model id", + vars: map[string]string{ + "model_id": "1", + "version_id": "1", + }, + requestBody: &models.VersionPatch{ + Properties: &models.KV{ + "name": "model-1", + "created_by": "anonymous", + }, + ModelSchema: &models.ModelSchema{ + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + }, + ModelID: models.ID(5), + }, + }, + versionService: func() *mocks.VersionsService { + svc := &mocks.VersionsService{} + svc.On("FindByID", mock.Anything, models.ID(1), models.ID(1), mock.Anything).Return( + &models.Version{ + ID: models.ID(1), + ModelID: models.ID(1), + Model: &models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "http://mlflow.com", + }, + MlflowURL: "http://mlflow.com", + }, nil) + return svc + }, + expected: &Response{ + code: http.StatusBadRequest, + data: Error{ + Message: "Error validating version: mismatch model id between version and model schema", + }, + }, + }, } for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { @@ -1004,6 +1061,67 @@ func TestCreateVersion(t *testing.T) { }, }, }, + { + desc: "Should fail create version when save model returning error", + vars: map[string]string{ + "model_id": "1", + }, + body: models.VersionPost{ + Labels: models.KV{ + "service.type": "GO-FOOD", + "1-targeting_date": "2021-02-01", + "TheQuickBrownFoxJumpsOverTheLazyDogTheQuickBrownFoxJumpsOverThe": "TheQuickBrownFoxJumpsOverTheLazyDogTheQuickBrownFoxJumpsOverThe", + }, + PythonVersion: "3.10.*", + }, + modelsService: func() *mocks.ModelsService { + svc := &mocks.ModelsService{} + svc.On("FindByID", mock.Anything, models.ID(1)).Return(&models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{ + MLFlowTrackingURL: "http://www.notinuse.com", + }, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "http://mlflow.com", + Endpoints: nil, + }, nil) + return svc + }, + mlflowClient: func() *mlfmocks.Client { + svc := &mlfmocks.Client{} + svc.On("CreateRun", "1").Return(&mlflow.Run{ + Info: mlflow.Info{ + RunID: "1", + ArtifactURI: "artifact/url/run", + }, + }, nil) + return svc + }, + versionService: func() *mocks.VersionsService { + svc := &mocks.VersionsService{} + svc.On("Save", mock.Anything, &models.Version{ + ModelID: models.ID(1), + RunID: "1", + ArtifactURI: "artifact/url/run", + Labels: models.KV{ + "service.type": "GO-FOOD", + "1-targeting_date": "2021-02-01", + "TheQuickBrownFoxJumpsOverTheLazyDogTheQuickBrownFoxJumpsOverThe": "TheQuickBrownFoxJumpsOverTheLazyDogTheQuickBrownFoxJumpsOverThe", + }, + PythonVersion: "3.10.*", + }, mock.Anything).Return(nil, fmt.Errorf("pq constraint violation")) + return svc + }, + expected: &Response{ + code: http.StatusInternalServerError, + data: Error{ + Message: "Failed to save version: pq constraint violation", + }, + }, + }, { desc: "Should fail label key validation: has emoji inside", vars: map[string]string{ @@ -1471,6 +1589,69 @@ func TestCreateVersion(t *testing.T) { }, }, }, + { + desc: "Should fail create version with model schema, when there is mismatch model if between version and model schema", + vars: map[string]string{ + "model_id": "1", + }, + body: models.VersionPost{ + ModelSchema: &models.ModelSchema{ + Spec: &models.SchemaSpec{ + PredictionIDColumn: "prediction_id", + ModelPredictionOutput: &models.ModelPredictionOutput{ + RankingOutput: &models.RankingOutput{ + PredictionGroudIDColumn: "session_id", + RankScoreColumn: "score", + RelevanceScoreColumn: "relevance_score", + OutputClass: models.Ranking, + }, + }, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Int64, + "featureC": models.Boolean, + }, + }, + ModelID: models.ID(5), + }, + }, + modelsService: func() *mocks.ModelsService { + svc := &mocks.ModelsService{} + svc.On("FindByID", mock.Anything, models.ID(1)).Return(&models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{ + MLFlowTrackingURL: "http://www.notinuse.com", + }, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "http://mlflow.com", + Endpoints: nil, + }, nil) + return svc + }, + mlflowClient: func() *mlfmocks.Client { + svc := &mlfmocks.Client{} + svc.On("CreateRun", "1").Return(&mlflow.Run{ + Info: mlflow.Info{ + RunID: "1", + ArtifactURI: "artifact/url/run", + }, + }, nil) + return svc + }, + versionService: func() *mocks.VersionsService { + svc := &mocks.VersionsService{} + return svc + }, + expected: &Response{ + code: http.StatusBadRequest, + data: Error{ + Message: "Error validating version: mismatch model id between version and model schema", + }, + }, + }, } for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { diff --git a/api/models/model_schema.go b/api/models/model_schema.go index 944cdc005..dd82c432e 100644 --- a/api/models/model_schema.go +++ b/api/models/model_schema.go @@ -30,7 +30,7 @@ const ( type ModelSchema struct { ID ID `json:"id"` Spec *SchemaSpec `json:"spec,omitempty"` - ModelID ID `json:"-"` + ModelID ID `json:"model_id"` } type SchemaSpec struct { diff --git a/api/models/version.go b/api/models/version.go index 1bba9ff33..9a212590a 100644 --- a/api/models/version.go +++ b/api/models/version.go @@ -18,6 +18,7 @@ import ( "database/sql/driver" "encoding/json" "errors" + "fmt" "gorm.io/gorm" ) @@ -92,21 +93,30 @@ func (kv *KV) Scan(value interface{}) error { return json.Unmarshal(b, &kv) } -func (v *Version) Patch(patch *VersionPatch) error { +func (v *Version) Validate() error { + if v.CustomPredictor != nil && v.Model.Type == ModelTypeCustom { + if err := v.CustomPredictor.IsValid(); err != nil { + return err + } + } + if v.ModelSchema != nil { + if v.ModelSchema.ModelID > 0 && v.ModelSchema.ModelID != v.ModelID { + return fmt.Errorf("mismatch model id between version and model schema") + } + } + return nil +} + +func (v *Version) Patch(patch *VersionPatch) { if patch.Properties != nil { v.Properties = *patch.Properties } if patch.CustomPredictor != nil && v.Model.Type == ModelTypeCustom { - if err := patch.CustomPredictor.IsValid(); err != nil { - return err - } v.CustomPredictor = patch.CustomPredictor } if patch.ModelSchema != nil { v.ModelSchema = patch.ModelSchema } - - return nil } func (v *Version) BeforeCreate(db *gorm.DB) error { diff --git a/python/sdk/client/api_client.py b/python/sdk/client/api_client.py index 27e801662..5bf60ec80 100644 --- a/python/sdk/client/api_client.py +++ b/python/sdk/client/api_client.py @@ -296,7 +296,6 @@ def response_deserialize( # if not found, look for '1XX', '2XX', etc. response_type = response_types_map.get(str(response_data.status)[0] + "XX", None) - print(f"response status ----- {response_data.status}") if not 200 <= response_data.status <= 299: if response_data.status == 400: raise BadRequestException(http_resp=response_data) @@ -329,7 +328,6 @@ def response_deserialize( match = re.search(r"charset=([a-zA-Z\-\d]+)[\s;]?", content_type) encoding = match.group(1) if match else "utf-8" response_text = response_data.data.decode(encoding) - print(f"response_text ------ {response_text}") return_data = self.deserialize(response_text, response_type) return ApiResponse( From ee29be968fc445c9cf1d7b33f18e25433f0a0f27 Mon Sep 17 00:00:00 2001 From: Tio Pramayudi Date: Fri, 26 Jan 2024 08:35:22 +0700 Subject: [PATCH 3/5] Codegen using discriminator lookup --- openapi-api-codegen.yaml | 3 ++- openapi-sdk-codegen.yaml | 1 + .../models/binary_classification_output.py | 2 +- .../client/models/model_prediction_output.py | 20 +++++++++++++++++++ python/sdk/client/models/ranking_output.py | 2 +- python/sdk/client/models/regression_output.py | 2 +- swagger.yaml | 3 +++ 7 files changed, 29 insertions(+), 4 deletions(-) diff --git a/openapi-api-codegen.yaml b/openapi-api-codegen.yaml index 915f83148..e6766dd59 100644 --- a/openapi-api-codegen.yaml +++ b/openapi-api-codegen.yaml @@ -1,8 +1,9 @@ packageName: client enumClassPrefix: true +useOneOfDiscriminatorLookup: true # Global Properties globalProperties: apiTests: false modelTests: false apiDocs: false - modelDocs: false \ No newline at end of file + modelDocs: false diff --git a/openapi-sdk-codegen.yaml b/openapi-sdk-codegen.yaml index 2c6461cb7..bb05688f4 100644 --- a/openapi-sdk-codegen.yaml +++ b/openapi-sdk-codegen.yaml @@ -1,6 +1,7 @@ projectName: merlin-sdk packageName: client generateSourceCodeOnly: true +useOneOfDiscriminatorLookup: true # Global Properties globalProperties: apiTests: false diff --git a/python/sdk/client/models/binary_classification_output.py b/python/sdk/client/models/binary_classification_output.py index 1d829d260..b99ebf437 100644 --- a/python/sdk/client/models/binary_classification_output.py +++ b/python/sdk/client/models/binary_classification_output.py @@ -35,7 +35,7 @@ class BinaryClassificationOutput(BaseModel): positive_class_label: StrictStr negative_class_label: StrictStr score_threshold: Optional[Union[StrictFloat, StrictInt]] = None - output_class: Optional[ModelPredictionOutputClass] = None + output_class: ModelPredictionOutputClass __properties: ClassVar[List[str]] = ["prediction_score_column", "actual_label_column", "positive_class_label", "negative_class_label", "score_threshold", "output_class"] model_config = { diff --git a/python/sdk/client/models/model_prediction_output.py b/python/sdk/client/models/model_prediction_output.py index 8b1fbc247..2ac361e6a 100644 --- a/python/sdk/client/models/model_prediction_output.py +++ b/python/sdk/client/models/model_prediction_output.py @@ -104,6 +104,26 @@ def from_json(cls, json_str: str) -> Self: error_messages = [] match = 0 + # use oneOf discriminator to lookup the data type + _data_type = json.loads(json_str).get("output_class") + if not _data_type: + raise ValueError("Failed to lookup data type from the field `output_class` in the input.") + + # check if data type is `BinaryClassificationOutput` + if _data_type == "BinaryClassificationOutput": + instance.actual_instance = BinaryClassificationOutput.from_json(json_str) + return instance + + # check if data type is `RankingOutput` + if _data_type == "RankingOutput": + instance.actual_instance = RankingOutput.from_json(json_str) + return instance + + # check if data type is `RegressionOutput` + if _data_type == "RegressionOutput": + instance.actual_instance = RegressionOutput.from_json(json_str) + return instance + # deserialize data into BinaryClassificationOutput try: instance.actual_instance = BinaryClassificationOutput.from_json(json_str) diff --git a/python/sdk/client/models/ranking_output.py b/python/sdk/client/models/ranking_output.py index 426d80add..31f673a75 100644 --- a/python/sdk/client/models/ranking_output.py +++ b/python/sdk/client/models/ranking_output.py @@ -33,7 +33,7 @@ class RankingOutput(BaseModel): rank_score_column: StrictStr prediction_group_id_column: StrictStr relevance_score_column: Optional[StrictStr] = None - output_class: Optional[ModelPredictionOutputClass] = None + output_class: ModelPredictionOutputClass __properties: ClassVar[List[str]] = ["rank_score_column", "prediction_group_id_column", "relevance_score_column", "output_class"] model_config = { diff --git a/python/sdk/client/models/regression_output.py b/python/sdk/client/models/regression_output.py index 74c9fb4cc..ca12db692 100644 --- a/python/sdk/client/models/regression_output.py +++ b/python/sdk/client/models/regression_output.py @@ -32,7 +32,7 @@ class RegressionOutput(BaseModel): """ # noqa: E501 prediction_score_column: StrictStr actual_score_column: Optional[StrictStr] = None - output_class: Optional[ModelPredictionOutputClass] = None + output_class: ModelPredictionOutputClass __properties: ClassVar[List[str]] = ["prediction_score_column", "actual_score_column", "output_class"] model_config = { diff --git a/swagger.yaml b/swagger.yaml index 6727de5a7..3bf626b5a 100644 --- a/swagger.yaml +++ b/swagger.yaml @@ -1423,6 +1423,7 @@ components: - prediction_score_column - positive_class_label - negative_class_label + - output_class properties: prediction_score_column: type: string @@ -1442,6 +1443,7 @@ components: required: - rank_score_column - prediction_group_id_column + - output_class properties: rank_score_column: type: string @@ -1455,6 +1457,7 @@ components: type: object required: - prediction_score_column + - output_class properties: prediction_score_column: type: string From cd90b153809005ea7305f2c0db2c1f0c3cec0adb Mon Sep 17 00:00:00 2001 From: Tio Pramayudi Date: Fri, 26 Jan 2024 09:00:25 +0700 Subject: [PATCH 4/5] Fix sdk test --- python/sdk/test/model_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sdk/test/model_test.py b/python/sdk/test/model_test.py index 6c612c19b..62ff07ded 100644 --- a/python/sdk/test/model_test.py +++ b/python/sdk/test/model_test.py @@ -1333,7 +1333,8 @@ class TestModel: client.RankingOutput( rank_score_column="score", prediction_group_id_column="session_id", - relevance_score_column="relevance_score" + relevance_score_column="relevance_score", + output_class=client.ModelPredictionOutputClass.RANKINGOUTPUT ) ) ) From fc1c8b72ca04f01c7eb48867684fda2835a18296 Mon Sep 17 00:00:00 2001 From: Tio Pramayudi Date: Fri, 26 Jan 2024 09:26:12 +0700 Subject: [PATCH 5/5] Fix integration test --- python/sdk/test/integration_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sdk/test/integration_test.py b/python/sdk/test/integration_test.py index bac19c8f3..07bebd175 100644 --- a/python/sdk/test/integration_test.py +++ b/python/sdk/test/integration_test.py @@ -179,7 +179,7 @@ def test_model_schema( # Upload the serialized model to MLP merlin.log_model(model_dir=model_dir) - assert v.model_schema == model_schema + assert v.model_schema.spec == model_schema.spec endpoint = merlin.deploy(v, deployment_mode=deployment_mode) resp = requests.post(f"{endpoint.url}", json=request_json)