Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
tiopramayudi committed Jan 30, 2024
1 parent e97330c commit 2183242
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 100 deletions.
8 changes: 8 additions & 0 deletions api/api/model_schema_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import (
mErrors "github.com/caraml-dev/merlin/pkg/errors"
)

// ModelSchemaController
type ModelSchemaController struct {
*AppContext
}

// GetAllSchemas list all model schemas given model ID
func (m *ModelSchemaController) GetAllSchemas(r *http.Request, vars map[string]string, _ interface{}) *Response {
ctx := r.Context()
modelID, _ := models.ParseID(vars["model_id"])
Expand All @@ -26,6 +28,7 @@ func (m *ModelSchemaController) GetAllSchemas(r *http.Request, vars map[string]s
return Ok(modelSchemas)
}

// GetSchema get detail of a model schema given the schema id and model id
func (m *ModelSchemaController) GetSchema(r *http.Request, vars map[string]string, _ interface{}) *Response {
ctx := r.Context()
modelID, _ := models.ParseID(vars["model_id"])
Expand All @@ -41,6 +44,10 @@ func (m *ModelSchemaController) GetSchema(r *http.Request, vars map[string]strin
return Ok(modelSchema)
}

// CreateOrUpdateSchema upsert schema
// If ID is not defined it will create new model schema
// If ID is defined but not exist, it will create new model schema
// If ID is defined and exist, it will update the existing model schema associated with that ID
func (m *ModelSchemaController) CreateOrUpdateSchema(r *http.Request, vars map[string]string, body interface{}) *Response {
ctx := r.Context()
modelID, _ := models.ParseID(vars["model_id"])
Expand All @@ -62,6 +69,7 @@ func (m *ModelSchemaController) CreateOrUpdateSchema(r *http.Request, vars map[s
return Ok(schema)
}

// DeleteSchema delete model schema given schema id and model id
func (m *ModelSchemaController) DeleteSchema(r *http.Request, vars map[string]string, _ interface{}) *Response {
ctx := r.Context()
modelID, _ := models.ParseID(vars["model_id"])
Expand Down
15 changes: 13 additions & 2 deletions api/models/model_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ import (
"fmt"
)

type InferenceType string

// ModelPredictionOutputClass is type for kinds of model type
type ModelPredictionOutputClass string

const (
Expand All @@ -18,6 +17,7 @@ const (
Ranking ModelPredictionOutputClass = "RankingOutput"
)

// Value type is type that represent type of the value
type ValueType string

const (
Expand All @@ -27,23 +27,29 @@ const (
String ValueType = "string"
)

// ModelSchema
type ModelSchema struct {
ID ID `json:"id"`
Spec *SchemaSpec `json:"spec,omitempty"`
ModelID ID `json:"model_id"`
}

// SchemaSpec
type SchemaSpec struct {
PredictionIDColumn string `json:"prediction_id_column"`
ModelPredictionOutput *ModelPredictionOutput `json:"model_prediction_output"`
TagColumns []string `json:"tag_columns"`
FeatureTypes map[string]ValueType `json:"feature_types"`
}

// Value returning a value for `SchemaSpec` instance
// This is required to be implemented when this instance is treated as JSONB column
func (s SchemaSpec) Value() (driver.Value, error) {
return json.Marshal(s)
}

// Scan returning error when assigning value from db driver is failing
// This is required to be implemented when this instance is treated as JSONB column
func (s *SchemaSpec) Scan(value interface{}) error {
b, ok := value.([]byte)
if !ok {
Expand All @@ -65,6 +71,7 @@ func newStrictDecoder(data []byte) *json.Decoder {
return dec
}

// UnmarshalJSON custom deserialization of bytes into `ModelPredictionOutput`
func (m *ModelPredictionOutput) UnmarshalJSON(data []byte) error {
var err error
outputClassStruct := struct {
Expand Down Expand Up @@ -99,6 +106,7 @@ func (m *ModelPredictionOutput) UnmarshalJSON(data []byte) error {
return nil
}

// MarshalJSON custom serialization of `ModelPredictionOutput` into json byte
func (m ModelPredictionOutput) MarshalJSON() ([]byte, error) {
if m.BinaryClassificationOutput != nil {
return json.Marshal(&m.BinaryClassificationOutput)
Expand All @@ -115,6 +123,7 @@ func (m ModelPredictionOutput) MarshalJSON() ([]byte, error) {
return nil, nil
}

// BinaryClassificationOutput is specification for prediction of binary classification model
type BinaryClassificationOutput struct {
ActualLabelColumn string `json:"actual_label_column"`
NegativeClassLabel string `json:"negative_class_label"`
Expand All @@ -125,13 +134,15 @@ type BinaryClassificationOutput struct {
OutputClass ModelPredictionOutputClass `json:"output_class" validate:"required"`
}

// RankingOutput is specification for prediction of ranking model
type RankingOutput struct {
PredictionGroudIDColumn string `json:"prediction_group_id_column"`
RankScoreColumn string `json:"rank_score_column"`
RelevanceScoreColumn string `json:"relevance_score"`
OutputClass ModelPredictionOutputClass `json:"output_class" validate:"required"`
}

// Regression is specification for prediction of regression model
type RegressionOutput struct {
PredictionScoreColumn string `json:"prediction_score_column"`
ActualScoreColumn string `json:"actual_score_column"`
Expand Down
7 changes: 7 additions & 0 deletions api/models/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"gorm.io/gorm"
)

// Version
type Version struct {
ID ID `json:"id" gorm:"primary_key"`
ModelID ID `json:"model_id" gorm:"primary_key"`
Expand All @@ -40,18 +41,21 @@ type Version struct {
CreatedUpdated
}

// VersionPost contains all information that is used during version creation
type VersionPost struct {
Labels KV `json:"labels" gorm:"labels"`
PythonVersion string `json:"python_version" gorm:"python_version"`
ModelSchema *ModelSchema `json:"model_schema"`
}

// VersionPatch contains all information that is used during version update or patch
type VersionPatch struct {
Properties *KV `json:"properties,omitempty"`
CustomPredictor *CustomPredictor `json:"custom_predictor,omitempty"`
ModelSchema *ModelSchema `json:"model_schema"`
}

// CustomPredictor contains configuration for custom model
type CustomPredictor struct {
Image string `json:"image"`
Command string `json:"command"`
Expand Down Expand Up @@ -93,6 +97,7 @@ func (kv *KV) Scan(value interface{}) error {
return json.Unmarshal(b, &kv)
}

// Validate do validation on the value of version
func (v *Version) Validate() error {
if v.CustomPredictor != nil && v.Model.Type == ModelTypeCustom {
if err := v.CustomPredictor.IsValid(); err != nil {
Expand All @@ -107,6 +112,7 @@ func (v *Version) Validate() error {
return nil
}

// Patch version value
func (v *Version) Patch(patch *VersionPatch) {
if patch.Properties != nil {
v.Properties = *patch.Properties
Expand All @@ -119,6 +125,7 @@ func (v *Version) Patch(patch *VersionPatch) {
}
}

// BeforeCreate find the latest persisted ID from version DB and increament it and assign to the receiver
func (v *Version) BeforeCreate(db *gorm.DB) error {
if v.ID == 0 {
var maxModelVersionID int
Expand Down
6 changes: 6 additions & 0 deletions api/service/model_schema_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,23 @@ import (
"gorm.io/gorm"
)

// ModelSchemaService interface
type ModelSchemaService interface {
// List all the model schemas for a model
List(ctx context.Context, modelID models.ID) ([]*models.ModelSchema, error)
// Save model schema, it can be create or update existing schema
Save(ctx context.Context, modelSchema *models.ModelSchema) (*models.ModelSchema, error)
// Delete a model schema
Delete(ctx context.Context, modelSchema *models.ModelSchema) error
// FindByID get schema given it's schema id and model id
FindByID(ctx context.Context, modelSchemaID models.ID, modelID models.ID) (*models.ModelSchema, error)
}

type modelSchemaService struct {
modelSchemaStorage storage.ModelSchemaStorage
}

// NewModelSchemaService create an instance of `ModelSchemaService`
func NewModelSchemaService(storage storage.ModelSchemaStorage) ModelSchemaService {
return &modelSchemaService{
modelSchemaStorage: storage,
Expand Down
6 changes: 6 additions & 0 deletions api/storage/model_schema_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,23 @@ import (
"gorm.io/gorm"
)

// ModelSchemaStorage interface, layer that responsibles to communicate directly with database
type ModelSchemaStorage interface {
// Save create or update model schema to DB
Save(ctx context.Context, modelSchema *models.ModelSchema) (*models.ModelSchema, error)
// FindAll find all schemas givem model id from DB
FindAll(ctx context.Context, modelID models.ID) ([]*models.ModelSchema, error)
// FindByID find schema given it's id from DB
FindByID(ctx context.Context, modelSchemaID models.ID, modelID models.ID) (*models.ModelSchema, error)
// Delete delete schema give it's id from DB
Delete(ctx context.Context, modelSchema *models.ModelSchema) error
}

type modelSchemaStorage struct {
db *gorm.DB
}

// NewModelSchemaStorage create new instance of ModelSchemaStorage
func NewModelSchemaStorage(db *gorm.DB) ModelSchemaStorage {
return &modelSchemaStorage{db: db}
}
Expand Down
3 changes: 3 additions & 0 deletions docs/user/templates/09_model_observability.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
<!-- page-title: Model Observability -->
# Model Observability
Model observability enable model's owner to observe and analyze their model in production by look at the performance and drift metrics.
7 changes: 5 additions & 2 deletions python/sdk/merlin/fluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
from merlin.environment import Environment
from merlin.logger import Logger
from merlin.model import Model, ModelType, ModelVersion, Project
from merlin.model_schema import ModelSchema
from merlin.protocol import Protocol
from merlin.resource_request import ResourceRequest
from merlin.transformer import Transformer
from merlin.model_schema import ModelSchema

_merlin_client: Optional[MerlinClient] = None
_active_project: Optional[Project]
Expand Down Expand Up @@ -156,11 +156,14 @@ def active_model() -> Optional[Model]:


@contextmanager
def new_model_version(labels: Dict[str, str] = None, model_schema: Optional[ModelSchema] = None):
def new_model_version(
labels: Dict[str, str] = None, model_schema: Optional[ModelSchema] = None
):
"""
Create new model version under currently active model
:param labels: dictionary containing the label that will be stored in the new model version
:param model_schema: Detail schema specification of a model
:return: ModelVersion
"""
v = None
Expand Down
Loading

0 comments on commit 2183242

Please sign in to comment.