Skip to content

Commit

Permalink
fix: Incorrect request body for create and patch version (#522)
Browse files Browse the repository at this point in the history
<!--  Thanks for sending a pull request!  Here are some tips for you:

1. Run unit tests and ensure that they are passing
2. If your change introduces any API changes, make sure to update the
e2e tests
3. Make sure documentation is updated for your PR!

-->
# Description
<!-- Briefly describe the motivation for the change. Please include
illustrations where appropriate. -->
Previously in this #518 we
introduce CRUD API for model schema, also changed the version schema
that aim to update model schema during version creation. While CRUID API
is working correctly, create model version with model schema info is not
working properly, due to the request body type that is used for the
controller is not `Version` but `VersionPost`. Hence this PR try to fix
that issue
# Modifications
<!-- Summarize the key code changes. -->
* `api/api/version_api.go` -> Update the model version creation by
supplying model schema data from `VersionPost` to `Version`
* `api/models/version.go` -> Including `ModelSchema` field for
`VersionPost` and `VersionPatch` struct to supply model schema info
during model version creation or patch
* `python/sdk/merlin/model_schema.py` -> Add `output_class` field to
model prediction output
* `python/sdk/test/integration_test.py` -> Add new integration test for
model schema
# Tests
<!-- Besides the existing / updated automated tests, what specific
scenarios should be tested? Consider the backward compatibility of the
changes, whether corner cases are covered, etc. Please describe the
tests and check the ones that have been completed. Eg:
- [x] Deploying new and existing standard models
- [ ] Deploying PyFunc models
-->

# Checklist
- [ ] Added PR label
- [x] Added unit test, integration, and/or e2e tests
- [ ] Tested locally
- [ ] Updated documentation
- [ ] Update Swagger spec if the PR introduce API changes
- [ ] Regenerated Golang and Python client if the PR introduces API
changes

# Release Notes
<!--
Does this PR introduce a user-facing change?
If no, just write "NONE" in the release-note block below.
If yes, a release note is required. Enter your extended release note in
the block below.
If the PR requires additional action from users switching to the new
release, include the string "action required".

For more information about release notes, see kubernetes' guide here:
http://git.k8s.io/community/contributors/guide/release-notes.md
-->

```release-note

```
  • Loading branch information
tiopramayudi authored Jan 26, 2024
1 parent bfd99e2 commit 9f3b161
Show file tree
Hide file tree
Showing 18 changed files with 746 additions and 25 deletions.
5 changes: 5 additions & 0 deletions api/api/model_schema_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ func (m *ModelSchemaController) CreateOrUpdateSchema(r *http.Request, vars map[s
if !ok {
return BadRequest("Unable to parse request body")
}

if modelSchema.ModelID > 0 && modelSchema.ModelID != modelID {
return BadRequest("Mismatch model id between request path and body")
}

modelSchema.ModelID = modelID
schema, err := m.ModelSchemaService.Save(ctx, modelSchema)
if err != nil {
Expand Down
124 changes: 124 additions & 0 deletions api/api/model_schema_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,98 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
},
},
},
{
desc: "success create ranking schema, with specifying model id",
vars: map[string]string{
"model_id": "1",
},
body: []byte(`{
"spec": {
"prediction_id_column":"prediction_id",
"tag_columns": ["tags"],
"feature_types": {
"featureA": "float64",
"featureB": "int64",
"featureC": "boolean"
},
"model_prediction_output": {
"prediction_group_id_column": "session_id",
"rank_score_column": "score",
"relevance_score": "relevance_score",
"output_class": "RankingOutput"
}
},
"model_id": 1
}`),
modelSchemaService: func() *mocks.ModelSchemaService {
mockSvc := &mocks.ModelSchemaService{}
mockSvc.On("Save", mock.Anything, &models.ModelSchema{
ModelID: models.ID(1),
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
TagColumns: []string{"tags"},
FeatureTypes: map[string]models.ValueType{
"featureA": models.Float64,
"featureB": models.Int64,
"featureC": models.Boolean,
},
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroudIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
},
}).Return(&models.ModelSchema{
ID: models.ID(1),
ModelID: models.ID(1),
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
TagColumns: []string{"tags"},
FeatureTypes: map[string]models.ValueType{
"featureA": models.Float64,
"featureB": models.Int64,
"featureC": models.Boolean,
},
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroudIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
},
}, nil)
return mockSvc
},
expected: &Response{
code: http.StatusOK,
data: &models.ModelSchema{
ID: models.ID(1),
ModelID: models.ID(1),
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
TagColumns: []string{"tags"},
FeatureTypes: map[string]models.ValueType{
"featureA": models.Float64,
"featureB": models.Int64,
"featureC": models.Boolean,
},
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroudIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
},
},
},
},
{
desc: "success create binary classification schema",
vars: map[string]string{
Expand Down Expand Up @@ -607,6 +699,38 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
data: Error{Message: "Error save model schema: peer connection is reset"},
},
},
{
desc: "model id mismatch",
vars: map[string]string{
"model_id": "1",
},
body: []byte(`{
"spec": {
"prediction_id_column":"prediction_id",
"tag_columns": ["tags"],
"feature_types": {
"featureA": "float64",
"featureB": "int64",
"featureC": "boolean"
},
"model_prediction_output": {
"prediction_group_id_column": "session_id",
"rank_score_column": "score",
"relevance_score": "relevance_score",
"output_class": "RankingOutput"
}
},
"model_id": 2
}`),
modelSchemaService: func() *mocks.ModelSchemaService {
mockSvc := &mocks.ModelSchemaService{}
return mockSvc
},
expected: &Response{
code: http.StatusBadRequest,
data: Error{Message: "Mismatch model id between request path and body"},
},
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
Expand Down
14 changes: 11 additions & 3 deletions api/api/versions_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ func (c *VersionsController) PatchVersion(r *http.Request, vars map[string]strin
return InternalServerError("Unable to parse request body")
}

if err := v.Patch(versionPatch); err != nil {
v.Patch(versionPatch)
if err := v.Validate(); err != nil {
return BadRequest(fmt.Sprintf("Error validating version: %v", err))
}

Expand Down Expand Up @@ -105,7 +106,6 @@ func (c *VersionsController) ListVersions(r *http.Request, vars map[string]strin

func (c *VersionsController) CreateVersion(r *http.Request, vars map[string]string, body interface{}) *Response {
ctx := r.Context()

versionPost, ok := body.(*models.VersionPost)
if !ok {
return BadRequest("Unable to parse request body")
Expand Down Expand Up @@ -137,9 +137,17 @@ func (c *VersionsController) CreateVersion(r *http.Request, vars map[string]stri
ArtifactURI: run.Info.ArtifactURI,
Labels: versionPost.Labels,
PythonVersion: versionPost.PythonVersion,
ModelSchema: versionPost.ModelSchema,
}

if err := version.Validate(); err != nil {
return BadRequest(fmt.Sprintf("Error validating version: %v", err))
}

version, _ = c.VersionsService.Save(ctx, version, c.FeatureToggleConfig.MonitoringConfig)
version, err = c.VersionsService.Save(ctx, version, c.FeatureToggleConfig.MonitoringConfig)
if err != nil {
return InternalServerError(fmt.Sprintf("Failed to save version: %v", err))
}
return Created(version)
}

Expand Down
Loading

0 comments on commit 9f3b161

Please sign in to comment.