Skip to content

Commit

Permalink
fix model schema after refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
khorshuheng committed Feb 23, 2024
1 parent f8cc083 commit 5705004
Show file tree
Hide file tree
Showing 18 changed files with 300 additions and 219 deletions.
129 changes: 82 additions & 47 deletions api/api/model_schema_api_test.go

Large diffs are not rendered by default.

90 changes: 50 additions & 40 deletions api/api/versions_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -722,12 +722,13 @@ func TestPatchVersion(t *testing.T) {
ModelSchema: &models.ModelSchema{
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
SessionIDColumn: "session_id",
RowIDColumn: "row_id",
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroupIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
FeatureTypes: map[string]models.ValueType{
Expand Down Expand Up @@ -776,12 +777,13 @@ func TestPatchVersion(t *testing.T) {
ModelSchema: &models.ModelSchema{
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
SessionIDColumn: "session_id",
RowIDColumn: "row_id",
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroupIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
FeatureTypes: map[string]models.ValueType{
Expand Down Expand Up @@ -812,12 +814,13 @@ func TestPatchVersion(t *testing.T) {
ModelSchema: &models.ModelSchema{
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
SessionIDColumn: "session_id",
RowIDColumn: "row_id",
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroupIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
FeatureTypes: map[string]models.ValueType{
Expand Down Expand Up @@ -853,12 +856,13 @@ func TestPatchVersion(t *testing.T) {
ModelSchema: &models.ModelSchema{
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
SessionIDColumn: "session_id",
RowIDColumn: "row_id",
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroupIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
FeatureTypes: map[string]models.ValueType{
Expand Down Expand Up @@ -886,12 +890,13 @@ func TestPatchVersion(t *testing.T) {
ModelSchema: &models.ModelSchema{
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
SessionIDColumn: "session_id",
RowIDColumn: "row_id",
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroupIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
FeatureTypes: map[string]models.ValueType{
Expand Down Expand Up @@ -1446,12 +1451,13 @@ func TestCreateVersion(t *testing.T) {
ModelSchema: &models.ModelSchema{
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
SessionIDColumn: "session_id",
RowIDColumn: "row_id",
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroupIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
FeatureTypes: map[string]models.ValueType{
Expand Down Expand Up @@ -1499,12 +1505,13 @@ func TestCreateVersion(t *testing.T) {
ModelSchema: &models.ModelSchema{
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
SessionIDColumn: "session_id",
RowIDColumn: "row_id",
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroupIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
FeatureTypes: map[string]models.ValueType{
Expand Down Expand Up @@ -1532,12 +1539,13 @@ func TestCreateVersion(t *testing.T) {
ModelSchema: &models.ModelSchema{
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
SessionIDColumn: "session_id",
RowIDColumn: "row_id",
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroupIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
FeatureTypes: map[string]models.ValueType{
Expand Down Expand Up @@ -1570,12 +1578,13 @@ func TestCreateVersion(t *testing.T) {
ModelSchema: &models.ModelSchema{
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
SessionIDColumn: "session_id",
RowIDColumn: "row_id",
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroupIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
FeatureTypes: map[string]models.ValueType{
Expand All @@ -1598,12 +1607,13 @@ func TestCreateVersion(t *testing.T) {
ModelSchema: &models.ModelSchema{
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
SessionIDColumn: "session_id",
RowIDColumn: "row_id",
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroupIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
OutputClass: models.Ranking,
},
},
FeatureTypes: map[string]models.ValueType{
Expand Down
36 changes: 14 additions & 22 deletions api/client/model_binary_classification_output.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 8 additions & 6 deletions api/cluster/resource/templater_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ func TestCreateInferenceServiceSpec(t *testing.T) {
ModelID: models.ID(1),
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
SessionIDColumn: "session_id",
RowIDColumn: "row_id",
TagColumns: []string{"tags"},
FeatureTypes: map[string]models.ValueType{
"featureA": models.Float64,
Expand All @@ -241,9 +243,8 @@ func TestCreateInferenceServiceSpec(t *testing.T) {
},
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroupIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
},
},
},
Expand Down Expand Up @@ -859,6 +860,8 @@ func TestCreateInferenceServiceSpec(t *testing.T) {
ModelID: models.ID(1),
Spec: &models.SchemaSpec{
PredictionIDColumn: "prediction_id",
SessionIDColumn: "session_id",
RowIDColumn: "row_id",
TagColumns: []string{"tags"},
FeatureTypes: map[string]models.ValueType{
"featureA": models.Float64,
Expand All @@ -868,9 +871,8 @@ func TestCreateInferenceServiceSpec(t *testing.T) {
},
ModelPredictionOutput: &models.ModelPredictionOutput{
RankingOutput: &models.RankingOutput{
PredictionGroupIDColumn: "session_id",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
RankScoreColumn: "score",
RelevanceScoreColumn: "relevance_score",
},
},
},
Expand Down
11 changes: 6 additions & 5 deletions api/models/model_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type ModelSchema struct {
// SchemaSpec
type SchemaSpec struct {
PredictionIDColumn string `json:"prediction_id_column"`
SessionIDColumn string `json:"session_id_column"`
RowIDColumn string `json:"row_id_column"`
ModelPredictionOutput *ModelPredictionOutput `json:"model_prediction_output"`
TagColumns []string `json:"tag_columns"`
FeatureTypes map[string]ValueType `json:"feature_types"`
Expand Down Expand Up @@ -125,7 +127,7 @@ func (m ModelPredictionOutput) MarshalJSON() ([]byte, error) {

// BinaryClassificationOutput is specification for prediction of binary classification model
type BinaryClassificationOutput struct {
ActualLabelColumn string `json:"actual_label_column"`
ActualScoreColumn string `json:"actual_score_column"`
NegativeClassLabel string `json:"negative_class_label"`
PredictionScoreColumn string `json:"prediction_score_column"`
PredictionLabelColumn string `json:"prediction_label_column"`
Expand All @@ -136,10 +138,9 @@ type BinaryClassificationOutput struct {

// RankingOutput is specification for prediction of ranking model
type RankingOutput struct {
PredictionGroupIDColumn string `json:"prediction_group_id_column"`
RankScoreColumn string `json:"rank_score_column"`
RelevanceScoreColumn string `json:"relevance_score_column"`
OutputClass ModelPredictionOutputClass `json:"output_class" validate:"required"`
RankScoreColumn string `json:"rank_score_column"`
RelevanceScoreColumn string `json:"relevance_score_column"`
OutputClass ModelPredictionOutputClass `json:"output_class" validate:"required"`
}

// Regression is specification for prediction of regression model
Expand Down
Loading

0 comments on commit 5705004

Please sign in to comment.