Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Incorrect request body for create and patch version #522

Merged
merged 5 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions api/api/model_schema_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ func (m *ModelSchemaController) CreateOrUpdateSchema(r *http.Request, vars map[s
if !ok {
return BadRequest("Unable to parse request body")
}

if modelSchema.ModelID > 0 && modelSchema.ModelID != modelID {
return BadRequest("Mismatch model id between request path and body")
}

modelSchema.ModelID = modelID
schema, err := m.ModelSchemaService.Save(ctx, modelSchema)
if err != nil {
Expand Down
124 changes: 124 additions & 0 deletions api/api/model_schema_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,98 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
},
},
},
{
desc: "success create ranking schema, with specifying model id",
vars: map[string]string{
"model_id": "1",
},
body: []byte(`{
"spec": {
"prediction_id_column":"prediction_id",
"tag_columns": ["tags"],
"feature_types": {
"featureA": "float64",
"featureB": "int64",
"featureC": "boolean"
},
"model_prediction_output": {
"prediction_group_id_column": "session_id",
"rank_score_column": "score",
"relevance_score": "relevance_score",
"output_class": "RankingOutput"
}
},
"model_id": 1
}`),
modelSchemaService: func() *mocks.ModelSchemaService {
mockSvc := &mocks.ModelSchemaService{}
mockSvc.On("Save", mock.Anything, &models.ModelSchema{
ModelID: models.ID(1),
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
TagColumns: []string{"tags"},
FeatureTypes: map[string]models.ValueType{
"featureA": models.Float64,
"featureB": models.Int64,
"featureC": models.Boolean,
},
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroudIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
},
}).Return(&models.ModelSchema{
ID: models.ID(1),
ModelID: models.ID(1),
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
TagColumns: []string{"tags"},
FeatureTypes: map[string]models.ValueType{
"featureA": models.Float64,
"featureB": models.Int64,
"featureC": models.Boolean,
},
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroudIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
},
}, nil)
return mockSvc
},
expected: &Response{
code: http.StatusOK,
data: &models.ModelSchema{
ID: models.ID(1),
ModelID: models.ID(1),
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
TagColumns: []string{"tags"},
FeatureTypes: map[string]models.ValueType{
"featureA": models.Float64,
"featureB": models.Int64,
"featureC": models.Boolean,
},
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroudIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
},
},
},
},
{
desc: "success create binary classification schema",
vars: map[string]string{
Expand Down Expand Up @@ -607,6 +699,38 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
data: Error{Message: "Error save model schema: peer connection is reset"},
},
},
{
desc: "model id mismatch",
vars: map[string]string{
"model_id": "1",
},
body: []byte(`{
"spec": {
"prediction_id_column":"prediction_id",
"tag_columns": ["tags"],
"feature_types": {
"featureA": "float64",
"featureB": "int64",
"featureC": "boolean"
},
"model_prediction_output": {
"prediction_group_id_column": "session_id",
"rank_score_column": "score",
"relevance_score": "relevance_score",
"output_class": "RankingOutput"
}
},
"model_id": 2
}`),
modelSchemaService: func() *mocks.ModelSchemaService {
mockSvc := &mocks.ModelSchemaService{}
return mockSvc
},
expected: &Response{
code: http.StatusBadRequest,
data: Error{Message: "Mismatch model id between request path and body"},
},
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
Expand Down
14 changes: 11 additions & 3 deletions api/api/versions_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ func (c *VersionsController) PatchVersion(r *http.Request, vars map[string]strin
return InternalServerError("Unable to parse request body")
}

if err := v.Patch(versionPatch); err != nil {
v.Patch(versionPatch)
if err := v.Validate(); err != nil {
return BadRequest(fmt.Sprintf("Error validating version: %v", err))
}

Expand Down Expand Up @@ -105,7 +106,6 @@ func (c *VersionsController) ListVersions(r *http.Request, vars map[string]strin

func (c *VersionsController) CreateVersion(r *http.Request, vars map[string]string, body interface{}) *Response {
ctx := r.Context()

versionPost, ok := body.(*models.VersionPost)
if !ok {
return BadRequest("Unable to parse request body")
Expand Down Expand Up @@ -137,9 +137,17 @@ func (c *VersionsController) CreateVersion(r *http.Request, vars map[string]stri
ArtifactURI: run.Info.ArtifactURI,
Labels: versionPost.Labels,
PythonVersion: versionPost.PythonVersion,
ModelSchema: versionPost.ModelSchema,
}

if err := version.Validate(); err != nil {
return BadRequest(fmt.Sprintf("Error validating version: %v", err))
}

version, _ = c.VersionsService.Save(ctx, version, c.FeatureToggleConfig.MonitoringConfig)
version, err = c.VersionsService.Save(ctx, version, c.FeatureToggleConfig.MonitoringConfig)
if err != nil {
return InternalServerError(fmt.Sprintf("Failed to save version: %v", err))
}
return Created(version)
}

Expand Down
Loading
Loading