Skip to content

Commit

Permalink
Adding output class
Browse files Browse the repository at this point in the history
  • Loading branch information
tiopramayudi committed Jan 22, 2024
1 parent eeb03ba commit dcf04a9
Show file tree
Hide file tree
Showing 29 changed files with 1,714 additions and 1,059 deletions.
52 changes: 48 additions & 4 deletions api/api/model_schema_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/caraml-dev/merlin/models"
"github.com/caraml-dev/merlin/pkg/errors"
internalValidator "github.com/caraml-dev/merlin/pkg/validator"
"github.com/caraml-dev/merlin/service/mocks"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -295,7 +296,8 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
"model_prediction_output": {
"prediction_group_id_column": "session_id",
"rank_score_column": "score",
"relevance_score": "relevance_score"
"relevance_score": "relevance_score",
"output_class": "RankingOutput"
}
}
}`),
Expand All @@ -316,6 +318,7 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
PredictionGroudIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
},
Expand All @@ -335,6 +338,7 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
PredictionGroudIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
},
Expand All @@ -359,6 +363,7 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
PredictionGroudIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
},
Expand All @@ -384,7 +389,8 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
"negative_class_label": "negative",
"prediction_score_column": "prediction_score",
"prediction_label_column": "prediction_label",
"positive_class_label": "positive"
"positive_class_label": "positive",
"output_class": "BinaryClassificationOutput"
}
}
}`),
Expand All @@ -407,6 +413,7 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
PredictionScoreColumn: "prediction_score",
PredictionLabelColumn: "prediction_label",
PositiveClassLabel: "positive",
OutputClass: models.BinaryClassification,
},
},
},
Expand All @@ -428,6 +435,7 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
PredictionScoreColumn: "prediction_score",
PredictionLabelColumn: "prediction_label",
PositiveClassLabel: "positive",
OutputClass: models.BinaryClassification,
},
},
},
Expand All @@ -454,6 +462,7 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
PredictionScoreColumn: "prediction_score",
PredictionLabelColumn: "prediction_label",
PositiveClassLabel: "positive",
OutputClass: models.BinaryClassification,
},
},
},
Expand All @@ -476,7 +485,8 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
},
"model_prediction_output": {
"prediction_score_column": "prediction_score",
"actual_score_column": "actual_score"
"actual_score_column": "actual_score",
"output_class": "RegressionOutput"
}
}
}`),
Expand All @@ -496,6 +506,7 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
RegressionOutput: &models.RegressionOutput{
PredictionScoreColumn: "prediction_score",
ActualScoreColumn: "actual_score",
OutputClass: models.Regression,
},
},
},
Expand All @@ -514,6 +525,7 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
RegressionOutput: &models.RegressionOutput{
PredictionScoreColumn: "prediction_score",
ActualScoreColumn: "actual_score",
OutputClass: models.Regression,
},
},
},
Expand All @@ -537,6 +549,7 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
RegressionOutput: &models.RegressionOutput{
PredictionScoreColumn: "prediction_score",
ActualScoreColumn: "actual_score",
OutputClass: models.Regression,
},
},
},
Expand All @@ -560,7 +573,8 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
"model_prediction_output": {
"prediction_group_id_column": "session_id",
"rank_score_column": "score",
"relevance_score": "relevance_score"
"relevance_score": "relevance_score",
"output_class": "RankingOutput"
}
}
}`),
Expand All @@ -581,6 +595,7 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
PredictionGroudIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
},
Expand All @@ -603,12 +618,41 @@ func TestModelSchemaController_CreateOrUpdateSchema(t *testing.T) {
var modelSchema *models.ModelSchema
err := json.Unmarshal(tt.body, &modelSchema)
require.NoError(t, err)

validate, _ := internalValidator.NewValidator()
err = validate.Struct(modelSchema)
require.NoError(t, err)

resp := ctrl.CreateOrUpdateSchema(&http.Request{}, tt.vars, modelSchema)
assertEqualResponses(t, tt.expected, resp)
})
}
}

func Benchmark_Unmarshal(b *testing.B) {
data := []byte(` {
"prediction_id_column":"prediction_id",
"tag_columns": ["tags"],
"feature_types": {
"featureA": "float64",
"featureB": "int64",
"featureC": "boolean"
},
"model_prediction_output": {
"actual_label_column": "actual_label",
"negative_class_label": "negative",
"prediction_score_column": "prediction_score",
"prediction_label_column": "prediction_label",
"positive_class_label": "positive",
"output_class": "BinaryClassificationOutput"
}
}`)
for i := 0; i < b.N; i++ {
var schemaSpec models.SchemaSpec
_ = json.Unmarshal(data, &schemaSpec)
}
}

func TestModelSchemaController_DeleteSchema(t *testing.T) {
tests := []struct {
desc string
Expand Down
152 changes: 130 additions & 22 deletions api/client/api_model_schema.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit dcf04a9

Please sign in to comment.