diff --git a/api/api/model_schema_api.go b/api/api/model_schema_api.go
index 800295414..719f849c2 100644
--- a/api/api/model_schema_api.go
+++ b/api/api/model_schema_api.go
@@ -9,10 +9,12 @@ import (
 	mErrors "github.com/caraml-dev/merlin/pkg/errors"
 )
 
+// ModelSchemaController handles HTTP requests related to model schemas
 type ModelSchemaController struct {
 	*AppContext
 }
 
+// GetAllSchemas lists all model schemas for a given model ID
 func (m *ModelSchemaController) GetAllSchemas(r *http.Request, vars map[string]string, _ interface{}) *Response {
 	ctx := r.Context()
 	modelID, _ := models.ParseID(vars["model_id"])
@@ -26,6 +28,7 @@ func (m *ModelSchemaController) GetAllSchemas(r *http.Request, vars map[string]s
 	return Ok(modelSchemas)
 }
 
+// GetSchema gets the details of a model schema given its schema ID and model ID
 func (m *ModelSchemaController) GetSchema(r *http.Request, vars map[string]string, _ interface{}) *Response {
 	ctx := r.Context()
 	modelID, _ := models.ParseID(vars["model_id"])
@@ -41,6 +44,10 @@ func (m *ModelSchemaController) GetSchema(r *http.Request, vars map[string]strin
 	return Ok(modelSchema)
 }
 
+// CreateOrUpdateSchema upserts a schema:
+// if the ID is not defined, it creates a new model schema;
+// if the ID is defined but does not exist, it creates a new model schema;
+// if the ID is defined and exists, it updates the existing model schema associated with that ID.
 func (m *ModelSchemaController) CreateOrUpdateSchema(r *http.Request, vars map[string]string, body interface{}) *Response {
 	ctx := r.Context()
 	modelID, _ := models.ParseID(vars["model_id"])
@@ -62,6 +69,7 @@ func (m *ModelSchemaController) CreateOrUpdateSchema(r *http.Request, vars map[s
 	return Ok(schema)
 }
 
+// DeleteSchema deletes a model schema given its schema ID and model ID
 func (m *ModelSchemaController) DeleteSchema(r *http.Request, vars map[string]string, _ interface{}) *Response {
 	ctx := r.Context()
 	modelID, _ := models.ParseID(vars["model_id"])
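The controller only wires HTTP handlers to the service layer; the route registration is not part of this diff. A rough sketch of how a client might exercise the upsert semantics described in the `CreateOrUpdateSchema` comment. The base URL and route here are assumptions for illustration; the field names come from the `SchemaSpec` JSON tags in the next file.

```python
import requests

# Hypothetical base URL and route; the real router registration is not in this diff.
MERLIN_API = "http://merlin.example.com/v1"

payload = {
    # "id" omitted: the server creates a new schema. Sending an existing
    # "id" updates the schema associated with that ID instead.
    "model_id": 1,
    "spec": {
        "prediction_id_column": "prediction_id",
        "tag_columns": ["tags"],
        "feature_types": {"feature_a": "float64", "feature_b": "int64"},
        "model_prediction_output": {
            "output_class": "RegressionOutput",
            "prediction_score_column": "score",
            "actual_score_column": "actual_score",
        },
    },
}

response = requests.put(f"{MERLIN_API}/models/1/schemas", json=payload)
response.raise_for_status()
print(response.json()["id"])  # ID assigned (or kept) by the upsert
```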
diff --git a/api/models/model_schema.go b/api/models/model_schema.go
index dd82c432e..451eff0aa 100644
--- a/api/models/model_schema.go
+++ b/api/models/model_schema.go
@@ -8,8 +8,7 @@ import (
 	"fmt"
 )
 
-type InferenceType string
-
+// ModelPredictionOutputClass is the type for the kinds of model prediction output
 type ModelPredictionOutputClass string
 
 const (
@@ -18,6 +17,7 @@ const (
 	Ranking ModelPredictionOutputClass = "RankingOutput"
 )
 
+// ValueType represents the type of a value
 type ValueType string
 
 const (
@@ -27,12 +27,14 @@ const (
 	String ValueType = "string"
 )
 
+// ModelSchema represents the schema of a model
 type ModelSchema struct {
 	ID      ID          `json:"id"`
 	Spec    *SchemaSpec `json:"spec,omitempty"`
 	ModelID ID          `json:"model_id"`
 }
 
+// SchemaSpec is the detailed specification of a model schema
 type SchemaSpec struct {
 	PredictionIDColumn    string                 `json:"prediction_id_column"`
 	ModelPredictionOutput *ModelPredictionOutput `json:"model_prediction_output"`
@@ -40,10 +42,14 @@ type SchemaSpec struct {
 	FeatureTypes map[string]ValueType `json:"feature_types"`
 }
 
+// Value returns the value of a `SchemaSpec` instance as JSON
+// This must be implemented so the instance can be stored as a JSONB column
 func (s SchemaSpec) Value() (driver.Value, error) {
 	return json.Marshal(s)
 }
 
+// Scan assigns a value from the DB driver into `SchemaSpec`, returning an error on failure
+// This must be implemented so the instance can be read from a JSONB column
 func (s *SchemaSpec) Scan(value interface{}) error {
 	b, ok := value.([]byte)
 	if !ok {
@@ -65,6 +71,7 @@ func newStrictDecoder(data []byte) *json.Decoder {
 	return dec
 }
 
+// UnmarshalJSON performs custom deserialization of JSON bytes into `ModelPredictionOutput`
 func (m *ModelPredictionOutput) UnmarshalJSON(data []byte) error {
 	var err error
 	outputClassStruct := struct {
@@ -99,6 +106,7 @@ func (m *ModelPredictionOutput) UnmarshalJSON(data []byte) error {
 	return nil
 }
 
+// MarshalJSON performs custom serialization of `ModelPredictionOutput` into JSON bytes
 func (m ModelPredictionOutput) MarshalJSON() ([]byte, error) {
 	if m.BinaryClassificationOutput != nil {
 		return json.Marshal(&m.BinaryClassificationOutput)
@@ -115,6 +123,7 @@ func (m ModelPredictionOutput) MarshalJSON() ([]byte, error) {
 	return nil, nil
 }
 
+// BinaryClassificationOutput is the specification for the prediction output of a binary classification model
 type BinaryClassificationOutput struct {
 	ActualLabelColumn  string `json:"actual_label_column"`
 	NegativeClassLabel string `json:"negative_class_label"`
@@ -125,6 +134,7 @@ type BinaryClassificationOutput struct {
 	OutputClass ModelPredictionOutputClass `json:"output_class" validate:"required"`
 }
 
+// RankingOutput is the specification for the prediction output of a ranking model
 type RankingOutput struct {
 	PredictionGroudIDColumn string `json:"prediction_group_id_column"`
 	RankScoreColumn         string `json:"rank_score_column"`
@@ -132,6 +142,7 @@ type RankingOutput struct {
 	OutputClass ModelPredictionOutputClass `json:"output_class" validate:"required"`
 }
 
+// RegressionOutput is the specification for the prediction output of a regression model
 type RegressionOutput struct {
 	PredictionScoreColumn string `json:"prediction_score_column"`
 	ActualScoreColumn     string `json:"actual_score_column"`
diff --git a/api/models/version.go b/api/models/version.go
index 9a212590a..686ea486a 100644
--- a/api/models/version.go
+++ b/api/models/version.go
@@ -23,6 +23,7 @@ import (
 	"gorm.io/gorm"
 )
 
+// Version represents a version of a model
 type Version struct {
 	ID      ID `json:"id" gorm:"primary_key"`
 	ModelID ID `json:"model_id" gorm:"primary_key"`
@@ -40,18 +41,21 @@ type Version struct {
 	CreatedUpdated
 }
 
+// VersionPost contains all information used during version creation
type VersionPost struct {
 	Labels        KV           `json:"labels" gorm:"labels"`
 	PythonVersion string       `json:"python_version" gorm:"python_version"`
 	ModelSchema   *ModelSchema `json:"model_schema"`
 }
 
+// VersionPatch contains all information used during a version update or patch
 type VersionPatch struct {
 	Properties      *KV              `json:"properties,omitempty"`
 	CustomPredictor *CustomPredictor `json:"custom_predictor,omitempty"`
 	ModelSchema     *ModelSchema     `json:"model_schema"`
 }
 
+// CustomPredictor contains the configuration for a custom model
 type CustomPredictor struct {
 	Image   string `json:"image"`
 	Command string `json:"command"`
@@ -93,6 +97,7 @@ func (kv *KV) Scan(value interface{}) error {
 	return json.Unmarshal(b, &kv)
 }
 
+// Validate validates the version value
 func (v *Version) Validate() error {
 	if v.CustomPredictor != nil && v.Model.Type == ModelTypeCustom {
 		if err := v.CustomPredictor.IsValid(); err != nil {
@@ -107,6 +112,7 @@ func (v *Version) Validate() error {
 	return nil
 }
 
+// Patch applies the given patch to the version
 func (v *Version) Patch(patch *VersionPatch) {
 	if patch.Properties != nil {
 		v.Properties = *patch.Properties
@@ -119,6 +125,7 @@ func (v *Version) Patch(patch *VersionPatch) {
 	}
 }
 
+// BeforeCreate finds the latest persisted version ID in the DB, increments it, and assigns it to the receiver
 func (v *Version) BeforeCreate(db *gorm.DB) error {
 	if v.ID == 0 {
 		var maxModelVersionID int
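`ModelPredictionOutput` is a polymorphic union on the Go side: `UnmarshalJSON` reads the `output_class` discriminator and decodes the payload strictly into exactly one of the three output structs, while `MarshalJSON` flattens whichever variant is set. A minimal sketch of the dispatch idea in Python, using the three `output_class` values shown in the constants above:

```python
import json

# The three discriminator values the Go decoder switches on.
KNOWN_OUTPUT_CLASSES = {
    "BinaryClassificationOutput",
    "RegressionOutput",
    "RankingOutput",
}

raw = """
{
    "output_class": "RankingOutput",
    "prediction_group_id_column": "session_id",
    "rank_score_column": "score"
}
"""

output = json.loads(raw)
if output["output_class"] not in KNOWN_OUTPUT_CLASSES:
    # Mirrors the error path of the Go UnmarshalJSON; the Go side additionally
    # uses a strict decoder that rejects fields unknown to the chosen struct.
    raise ValueError(f"unsupported output_class: {output['output_class']}")
```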
diff --git a/api/service/model_schema_service.go b/api/service/model_schema_service.go
index a9ee82100..d7f66016a 100644
--- a/api/service/model_schema_service.go
+++ b/api/service/model_schema_service.go
@@ -10,10 +10,15 @@ import (
 	"gorm.io/gorm"
 )
 
+// ModelSchemaService is the interface of the model schema service
 type ModelSchemaService interface {
+	// List all the model schemas for a model
 	List(ctx context.Context, modelID models.ID) ([]*models.ModelSchema, error)
+	// Save a model schema; this either creates a new schema or updates an existing one
 	Save(ctx context.Context, modelSchema *models.ModelSchema) (*models.ModelSchema, error)
+	// Delete a model schema
 	Delete(ctx context.Context, modelSchema *models.ModelSchema) error
+	// FindByID gets a schema given its schema ID and model ID
 	FindByID(ctx context.Context, modelSchemaID models.ID, modelID models.ID) (*models.ModelSchema, error)
 }
 
@@ -21,6 +26,7 @@ type modelSchemaService struct {
 	modelSchemaStorage storage.ModelSchemaStorage
 }
 
+// NewModelSchemaService creates an instance of `ModelSchemaService`
 func NewModelSchemaService(storage storage.ModelSchemaStorage) ModelSchemaService {
 	return &modelSchemaService{
 		modelSchemaStorage: storage,
diff --git a/api/storage/model_schema_storage.go b/api/storage/model_schema_storage.go
index f7833f086..3270d0f80 100644
--- a/api/storage/model_schema_storage.go
+++ b/api/storage/model_schema_storage.go
@@ -7,10 +7,15 @@ import (
 	"gorm.io/gorm"
 )
 
+// ModelSchemaStorage is the layer responsible for communicating directly with the database
 type ModelSchemaStorage interface {
+	// Save creates or updates a model schema in the DB
 	Save(ctx context.Context, modelSchema *models.ModelSchema) (*models.ModelSchema, error)
+	// FindAll finds all schemas for a given model ID in the DB
 	FindAll(ctx context.Context, modelID models.ID) ([]*models.ModelSchema, error)
+	// FindByID finds a schema in the DB given its ID
 	FindByID(ctx context.Context, modelSchemaID models.ID, modelID models.ID) (*models.ModelSchema, error)
+	// Delete deletes a schema from the DB given its ID
 	Delete(ctx context.Context, modelSchema *models.ModelSchema) error
 }
 
@@ -18,6 +23,7 @@ type modelSchemaStorage struct {
 	db *gorm.DB
 }
 
+// NewModelSchemaStorage creates a new instance of ModelSchemaStorage
 func NewModelSchemaStorage(db *gorm.DB) ModelSchemaStorage {
 	return &modelSchemaStorage{db: db}
 }
diff --git a/docs/user/templates/09_model_observability.md b/docs/user/templates/09_model_observability.md
new file mode 100644
index 000000000..d224fcffa
--- /dev/null
+++ b/docs/user/templates/09_model_observability.md
@@ -0,0 +1,3 @@
+
+# Model Observability
+Model observability enables model owners to observe and analyze their models in production by looking at performance and drift metrics.
\ No newline at end of file
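Both `ModelSchemaService.Save` and `ModelSchemaStorage.Save` follow the upsert contract stated on the controller: create when the ID is absent or unknown, overwrite when it exists. An in-memory sketch of that contract, as a toy stand-in for the gorm-backed table rather than the actual implementation:

```python
from typing import Dict, Optional

_store: Dict[int, dict] = {}
_next_id = 1


def save(schema: dict) -> dict:
    """Create the schema if its ID is missing or unknown, else update it."""
    global _next_id
    schema_id: Optional[int] = schema.get("id")
    if schema_id is None or schema_id not in _store:
        schema_id = schema_id or _next_id
        _next_id = max(_next_id, schema_id) + 1
        schema["id"] = schema_id
    _store[schema_id] = schema  # create or overwrite
    return schema
```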
diff --git a/python/sdk/merlin/fluent.py b/python/sdk/merlin/fluent.py
index ce175c310..72667a07b 100644
--- a/python/sdk/merlin/fluent.py
+++ b/python/sdk/merlin/fluent.py
@@ -25,10 +25,10 @@ from merlin.environment import Environment
 from merlin.logger import Logger
 from merlin.model import Model, ModelType, ModelVersion, Project
+from merlin.model_schema import ModelSchema
 from merlin.protocol import Protocol
 from merlin.resource_request import ResourceRequest
 from merlin.transformer import Transformer
-from merlin.model_schema import ModelSchema
 
 _merlin_client: Optional[MerlinClient] = None
 _active_project: Optional[Project]
@@ -156,11 +156,14 @@ def active_model() -> Optional[Model]:
 
 
 @contextmanager
-def new_model_version(labels: Dict[str, str] = None, model_schema: Optional[ModelSchema] = None):
+def new_model_version(
+    labels: Dict[str, str] = None, model_schema: Optional[ModelSchema] = None
+):
     """
     Create new model version under currently active model
 
     :param labels: dictionary containing the label that will be stored in the new model version
+    :param model_schema: detailed schema specification of the model
     :return: ModelVersion
     """
     v = None
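A sketch of the fluent API with the new `model_schema` parameter. The `new_model_version(labels, model_schema)` signature is taken from the diff, and `InferenceSchema`, `ValueType`, and `BinaryClassificationOutput` come from `merlin.observability.inference` as imported in `model_schema.py`; the URL, project, model names, and the exact enum members are assumptions:

```python
import merlin
from merlin.model import ModelType
from merlin.model_schema import ModelSchema
from merlin.observability.inference import (
    BinaryClassificationOutput,
    InferenceSchema,
    ValueType,
)

merlin.set_url("merlin.example.com")  # hypothetical deployment URL
merlin.set_project("sample")          # hypothetical project
merlin.set_model("my-classifier", ModelType.XGBOOST)

schema = ModelSchema(
    spec=InferenceSchema(
        feature_types={
            "feature_a": ValueType.FLOAT64,  # assumed enum members
            "feature_b": ValueType.INT64,
        },
        prediction_id_column="prediction_id",
        model_prediction_output=BinaryClassificationOutput(
            prediction_score_column="score",
            actual_label_column="actual_label",
            positive_class_label="positive",
            negative_class_label="negative",
            score_threshold=0.5,
        ),
    )
)

with merlin.new_model_version(model_schema=schema) as v:
    merlin.log_model(model_dir="model")  # upload artifacts for this version
```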
diff --git a/python/sdk/merlin/model.py b/python/sdk/merlin/model.py
index a3c6d3c09..d50428eef 100644
--- a/python/sdk/merlin/model.py
+++ b/python/sdk/merlin/model.py
@@ -26,7 +26,6 @@
 
 import client
 import docker
-import mlflow
 import pyprind
 import yaml
 from client import (
@@ -39,7 +38,6 @@
 )
 from docker import APIClient
 from docker.models.containers import Container
-from merlin import pyfunc
 from merlin.autoscaling import (
     RAW_DEPLOYMENT_DEFAULT_AUTOSCALING_POLICY,
     SERVERLESS_DEFAULT_AUTOSCALING_POLICY,
@@ -53,6 +51,7 @@ from merlin.docker.docker import copy_standard_dockerfile, wait_build_complete
 from merlin.endpoint import ModelEndpoint, Status, VersionEndpoint
 from merlin.logger import Logger
+from merlin.model_schema import ModelSchema
 from merlin.protocol import Protocol
 from merlin.pyfunc import run_pyfunc_local_server
 from merlin.resource_request import ResourceRequest
@@ -60,16 +59,18 @@ from merlin.util import (
     autostr,
     download_files_from_gcs,
+    extract_optional_value_with_default,
     guess_mlp_ui_url,
     valid_name_check,
 )
 from merlin.validation import validate_model_dir
+from merlin.version import VERSION
 from mlflow.entities import Run, RunData
 from mlflow.exceptions import MlflowException
 from mlflow.pyfunc import PythonModel
-from merlin.version import VERSION
-from merlin.model_schema import ModelSchema
-from merlin.util import extract_optional_value_with_default
+
+import mlflow
+from merlin import pyfunc
 
 # Ensure backward compatibility after moving PyFuncModel and PyFuncV2Model to pyfunc.py
 # This allows users to do following import statement
@@ -123,7 +124,9 @@ def __init__(
         self._url = mlp_url
         self._api_client = api_client
         self._readers = extract_optional_value_with_default(project.readers, [])
-        self._administrators = extract_optional_value_with_default(project.administrators, [])
+        self._administrators = extract_optional_value_with_default(
+            project.administrators, []
+        )
 
     @property
     def id(self) -> int:
         return self._id
@@ -336,6 +339,7 @@ def mlflow_experiment_id(self) -> Optional[int]:
         if self._mlflow_experiment_id is not None:
             return int(self._mlflow_experiment_id)
         return None
+
     @property
     def created_at(self) -> Optional[datetime]:
         return self._created_at
@@ -440,14 +444,16 @@ def _list_version_pagination(
         )
         versions = version_api_response.data
         headers = extract_optional_value_with_default(version_api_response.headers, {})
-
+
         next_cursor = headers.get("Next-Cursor") or ""
         result = []
         for v in versions:
             result.append(ModelVersion(v, self, self._api_client))
         return result, next_cursor
 
-    def new_model_version(self, labels: Dict[str, str] = None, model_schema: Optional[ModelSchema] = None) -> "ModelVersion":
+    def new_model_version(
+        self, labels: Dict[str, str] = None, model_schema: Optional[ModelSchema] = None
+    ) -> "ModelVersion":
         """
         Create a new version of this model
 
@@ -461,7 +467,12 @@ def new_model_version(self, labels: Dict[str, str] = None, model_schema: Optiona
             model_schema.model_id = self.id
             model_schema_payload = model_schema.to_client_model_schema()
         v = version_api.models_model_id_versions_post(
-            int(self.id), body=client.Version(labels=labels, python_version=python_version, model_schema=model_schema_payload)
+            int(self.id),
+            body=client.Version(
+                labels=labels,
+                python_version=python_version,
+                model_schema=model_schema_payload,
+            ),
         )
         return ModelVersion(v, self, self._api_client)
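The same entry point is available directly on a `Model` instance; a short sketch, reusing the `schema` object built in the earlier example and the active model from the fluent API:

```python
# `model` comes from the fluent API; `schema` is the ModelSchema built above.
model = merlin.active_model()
assert model is not None

version = model.new_model_version(
    labels={"team": "fraud-detection"},  # hypothetical label
    model_schema=schema,
)
print(version.id)
```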
@@ -522,18 +533,20 @@ def serve_traffic(
             ep = client.ModelEndpoint(
                 model_id=self.id, environment_name=target_env, rule=rule
             )
-            ep = mdl_epi_api.models_model_id_endpoints_post(
-                model_id=self.id, body=ep
-            )
+            ep = mdl_epi_api.models_model_id_endpoints_post(model_id=self.id, body=ep)
         elif prev_endpoint.id is not None:
             # update: GET and PUT
             ep = mdl_epi_api.models_model_id_endpoints_model_endpoint_id_get(
                 model_id=self.id, model_endpoint_id=prev_endpoint.id
             )
-            if ep.rule is not None and ep.rule.destinations is not None and len(ep.rule.destinations) > 0:
+            if (
+                ep.rule is not None
+                and ep.rule.destinations is not None
+                and len(ep.rule.destinations) > 0
+            ):
                 ep.rule.destinations[0].version_endpoint_id = version_endpoint.id
                 ep.rule.destinations[0].weight = 100
-
+
             ep = mdl_epi_api.models_model_id_endpoints_model_endpoint_id_put(
                 model_id=int(self.id),
                 model_endpoint_id=prev_endpoint.id,
@@ -581,9 +594,7 @@ def stop_serving_traffic(self, environment_name: str = None):
                 f"in {target_env} environment"
             )
         if target_endpoint.id is None:
-            raise ValueError(
-                f"model endpoint doesn't have id information"
-            )
+            raise ValueError(f"model endpoint doesn't have id information")
         mdl_epi_api.models_model_id_endpoints_model_endpoint_id_delete(
             self.id, target_endpoint.id
         )
@@ -646,7 +657,7 @@ def set_traffic(self, traffic_rule: Dict["ModelVersion", int]) -> ModelEndpoint:
                             weight=100,
                         )
                     ]
-                )
+                ),
             ),
             model_id=int(self.id),
         )
@@ -657,10 +668,14 @@ def set_traffic(self, traffic_rule: Dict["ModelVersion", int]) -> ModelEndpoint:
             ep = model_endpoint_api.models_model_id_endpoints_model_endpoint_id_get(
                 model_id=int(self.id), model_endpoint_id=def_model_endpoint.id
             )
-            if ep.rule is not None and ep.rule.destinations is not None and len(ep.rule.destinations) > 0:
+            if (
+                ep.rule is not None
+                and ep.rule.destinations is not None
+                and len(ep.rule.destinations) > 0
+            ):
                 ep.rule.destinations[0].version_endpoint_id = def_version_endpoint.id
                 ep.rule.destinations[0].weight = 100
-
+
             ep = model_endpoint_api.models_model_id_endpoints_model_endpoint_id_put(
                 model_id=int(self.id),
                 model_endpoint_id=def_model_endpoint.id,
@@ -710,8 +725,10 @@ def __init__(
         self._labels = version.labels
         self._custom_predictor = version.custom_predictor
         self._python_version = version.python_version
-        self._model_schema = ModelSchema.from_model_schema_response(version.model_schema)
-        mlflow.set_tracking_uri(model.project.mlflow_tracking_url)  # type: ignore # noqa
+        self._model_schema = ModelSchema.from_model_schema_response(
+            version.model_schema
+        )
+        mlflow.set_tracking_uri(model.project.mlflow_tracking_url)  # type: ignore # noqa
 
         endpoints = None
         if version.endpoints is not None:
@@ -720,7 +737,6 @@ def __init__(
                 endpoints.append(VersionEndpoint(ep))
             self._version_endpoints = endpoints
 
-
     @property
     def id(self) -> int:
         return self._id
@@ -783,7 +799,7 @@ def url(self) -> str:
         model_id = self.model.id
         base_url = guess_mlp_ui_url(self.model.project.url)
         return f"{base_url}/projects/{project_id}/models/{model_id}/versions"
-
+
     @property
     def model_schema(self) -> Optional[ModelSchema]:
         return self._model_schema
@@ -1173,16 +1189,21 @@ def deploy(
         for env in env_list:
             if env.gpus is None:
                 continue
-
+
             for gpu in env.gpus:
                 if resource_request.gpu_name == gpu.name:
-                    if gpu.values is not None and resource_request.gpu_request not in gpu.values:
+                    if (
+                        gpu.values is not None
+                        and resource_request.gpu_request not in gpu.values
+                    ):
                         raise ValueError(
                             f"Invalid GPU request count. Supported GPUs count for {resource_request.gpu_name} is {gpu.values}"
                         )
                     if target_resource_request is not None:
-                        target_resource_request.gpu_name = resource_request.gpu_name
+                        target_resource_request.gpu_name = (
+                            resource_request.gpu_name
+                        )
                         target_resource_request.gpu_request = (
                             resource_request.gpu_request
                         )
@@ -1240,7 +1261,7 @@ def deploy(
 
         if current_endpoint.id is None:
             raise ValueError("current endpoint must have id")
-
+
         endpoint = endpoint_api.models_model_id_versions_version_id_endpoint_endpoint_id_put(
             int(model.id),
             int(self.id),
@@ -1262,7 +1283,7 @@ def deploy(
         # started acting after the deployment job has been submitted.
         if endpoint.id is None:
             raise ValueError("endpoint id must be set")
-
+
         endpoint = endpoint_api.models_model_id_versions_version_id_endpoint_endpoint_id_get(
             model_id=model.id, version_id=self.id, endpoint_id=endpoint.id
         )
@@ -1272,7 +1293,11 @@ def deploy(
                 bar.update()
         bar.stop()
 
-        if endpoint.status != "running" and endpoint.status != "serving" and endpoint.message is not None:
+        if (
+            endpoint.status != "running"
+            and endpoint.status != "serving"
+            and endpoint.message is not None
+        ):
             raise ModelEndpointDeploymentError(model.name, self.id, endpoint.message)
 
         log_url = f"{self.url}/{self.id}/endpoints/{endpoint.id}/logs"
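The GPU validation above checks `resource_request.gpu_name` and `gpu_request` against the cluster's supported GPU configurations. A sketch of a deploy call that would go through that path; the GPU name and the exact way GPU fields are set on `ResourceRequest` are assumptions:

```python
from merlin.resource_request import ResourceRequest

resource_request = ResourceRequest(
    min_replica=1,
    max_replica=2,
    cpu_request="500m",
    memory_request="512Mi",
)
# GPU fields consulted by the validation loop in deploy(); assumed attributes.
resource_request.gpu_name = "nvidia-tesla-t4"  # hypothetical GPU name
resource_request.gpu_request = "1"

endpoint = version.deploy(resource_request=resource_request)
```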
@@ -1350,7 +1375,7 @@ def create_prediction_job(
                 result=client.PredictionJobConfigModelResult(
                     type=client.ResultType(job_config.result_type.value),
                     item_type=client.ResultType(job_config.item_type.value),
-                )
+                ),
             ),
         )
 
@@ -1412,7 +1437,7 @@ def create_prediction_job(
         job_id = j.id
         if job_id is None:
             raise ValueError("job id must be exist")
-
+
         if not sync:
             j = job_client.models_model_id_versions_version_id_jobs_job_id_get(
                 model_id=self.model.id, version_id=self.id, job_id=job_id
             )
@@ -1646,10 +1671,11 @@ def _get_default_resource_request(
         resource_request.validate()
 
         return client.ResourceRequest(
-            min_replica=resource_request.min_replica, 
+            min_replica=resource_request.min_replica,
             max_replica=resource_request.max_replica,
-            cpu_request=resource_request.cpu_request, 
-            memory_request=resource_request.memory_request)
+            cpu_request=resource_request.cpu_request,
+            memory_request=resource_request.memory_request,
+        )
 
     @staticmethod
     def _get_default_autoscaling_policy(
@@ -1660,8 +1686,9 @@ def _get_default_autoscaling_policy(
         else:
             autoscaling_policy = SERVERLESS_DEFAULT_AUTOSCALING_POLICY
         return client.AutoscalingPolicy(
-            metrics_type=client.MetricsType(autoscaling_policy.metrics_type.value),
-            target_value=autoscaling_policy.target_value)
+            metrics_type=client.MetricsType(autoscaling_policy.metrics_type.value),
+            target_value=autoscaling_policy.target_value,
+        )
 
     @staticmethod
     def _add_env_vars(target_env_vars, new_env_vars):
@@ -1689,10 +1716,11 @@ def _create_transformer_spec(
         else:
             resource_request.validate()
             target_resource_request = client.ResourceRequest(
-                min_replica=resource_request.min_replica, 
+                min_replica=resource_request.min_replica,
                 max_replica=resource_request.max_replica,
-                cpu_request=resource_request.cpu_request, 
-                memory_request=resource_request.memory_request)
+                cpu_request=resource_request.cpu_request,
+                memory_request=resource_request.memory_request,
+            )
 
         target_env_vars: List[client.EnvVar] = []
         if transformer.env_vars is not None:
@@ -1761,7 +1789,7 @@ def match_dependency(spec, name):
 
         if python_version is None:
             raise ValueError("python version must be set")
-
+
         new_conda_env = {}
 
         if isinstance(conda_env, str):
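`_get_default_autoscaling_policy` falls back to `RAW_DEPLOYMENT_DEFAULT_AUTOSCALING_POLICY` or `SERVERLESS_DEFAULT_AUTOSCALING_POLICY` when the caller passes none. A sketch of overriding it explicitly; `AutoscalingPolicy` and `MetricsType` are assumed to live in `merlin.autoscaling` alongside the defaults imported above, and the `DeploymentMode` module path and enum members are assumptions as well:

```python
from merlin.autoscaling import AutoscalingPolicy, MetricsType
from merlin.deployment_mode import DeploymentMode  # assumed module path

endpoint = version.deploy(
    deployment_mode=DeploymentMode.RAW_DEPLOYMENT,
    autoscaling_policy=AutoscalingPolicy(
        metrics_type=MetricsType.CPU_UTILIZATION,  # assumed enum member
        target_value=80,
    ),
)
```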
diff --git a/python/sdk/merlin/model_schema.py b/python/sdk/merlin/model_schema.py
index d5a9fbf86..821f59928 100644
--- a/python/sdk/merlin/model_schema.py
+++ b/python/sdk/merlin/model_schema.py
@@ -1,44 +1,57 @@
-
 from __future__ import annotations
-from typing import Dict, Optional, Any
-from dataclasses import dataclass
-from dataclasses_json import dataclass_json
-
-from merlin.util import autostr
-from merlin.util import extract_optional_value_with_default
-
-from enum import Enum
-from merlin.observability.inference import InferenceSchema, PredictionOutput, BinaryClassificationOutput, RegressionOutput, RankingOutput, ValueType
+
+from dataclasses import dataclass
+from typing import Optional
 
 import client
+from dataclasses_json import dataclass_json
+from merlin.observability.inference import (
+    BinaryClassificationOutput,
+    InferenceSchema,
+    PredictionOutput,
+    RankingOutput,
+    RegressionOutput,
+    ValueType,
+)
+from merlin.util import autostr, extract_optional_value_with_default
 
 
 @autostr
 @dataclass_json
 @dataclass
 class ModelSchema:
+    """
+    Representation of a model's schema
+    """
+
     spec: InferenceSchema
     id: Optional[int] = None
    model_id: Optional[int] = None
 
     @classmethod
-    def from_model_schema_response(cls, response: Optional[client.ModelSchema]=None) -> Optional[ModelSchema]:
+    def from_model_schema_response(
+        cls, response: Optional[client.ModelSchema] = None
+    ) -> Optional[ModelSchema]:
+        """
+        Convert a model schema payload from a server response into a `ModelSchema`
+
+        :param response: model schema payload from the server response, already deserialized into the OpenAPI ModelSchema
+        :type response: Optional[client.ModelSchema]
+        """
         if response is None:
             return None
-
+
         response_spec = response.spec
         if response_spec is None:
-            return ModelSchema(
-                id=response.id,
-                model_id=response.model_id
+            return ModelSchema(id=response.id, model_id=response.model_id)
+
+        prediction_output = cls.model_prediction_output_from_response(
+            response_spec.model_prediction_output
             )
-
-        prediction_output = cls.model_prediction_output_from_response(response_spec.model_prediction_output)
         feature_types = {}
         for key, val in response_spec.feature_types.items():
             feature_types[key] = ValueType(val.value)
-
+
         return ModelSchema(
             id=response.id,
             model_id=response.model_id,
@@ -46,65 +59,99 @@ def from_model_schema_response(cls, response: Optional[client.ModelSchema]=None)
                 feature_types=feature_types,
                 prediction_id_column=response_spec.prediction_id_column,
                 tag_columns=response_spec.tag_columns,
-                model_prediction_output=prediction_output
-            )
+                model_prediction_output=prediction_output,
+            ),
         )
+
     @classmethod
-    def model_prediction_output_from_response(cls, model_prediction_output: client.ModelPredictionOutput) -> PredictionOutput:
+    def model_prediction_output_from_response(
+        cls, model_prediction_output: client.ModelPredictionOutput
+    ) -> PredictionOutput:
+        """
+        Convert the model prediction output from the server payload into a `PredictionOutput`.
+
+        :param model_prediction_output: model prediction output information, part of the model schema server payload.
+        :type model_prediction_output: client.ModelPredictionOutput
+        """
         actual_instance = model_prediction_output.actual_instance
         if isinstance(actual_instance, client.BinaryClassificationOutput):
             return BinaryClassificationOutput(
                 prediction_score_column=actual_instance.prediction_score_column,
-                actual_label_column=extract_optional_value_with_default(actual_instance.actual_label_column, ""),
+                actual_label_column=extract_optional_value_with_default(
+                    actual_instance.actual_label_column, ""
+                ),
                 positive_class_label=actual_instance.positive_class_label,
                 negative_class_label=actual_instance.negative_class_label,
-                score_threshold=extract_optional_value_with_default(actual_instance.score_threshold, 0.5)
+                score_threshold=extract_optional_value_with_default(
+                    actual_instance.score_threshold, 0.5
+                ),
             )
         elif isinstance(actual_instance, client.RegressionOutput):
             return RegressionOutput(
-                actual_score_column=extract_optional_value_with_default(actual_instance.actual_score_column, ""),
-                prediction_score_column=actual_instance.prediction_score_column
+                actual_score_column=extract_optional_value_with_default(
+                    actual_instance.actual_score_column, ""
+                ),
+                prediction_score_column=actual_instance.prediction_score_column,
             )
         elif isinstance(actual_instance, client.RankingOutput):
             return RankingOutput(
-                relevance_score_column=extract_optional_value_with_default(actual_instance.relevance_score_column, ""),
+                relevance_score_column=extract_optional_value_with_default(
+                    actual_instance.relevance_score_column, ""
+                ),
                 prediction_group_id_column=actual_instance.prediction_group_id_column,
-                rank_score_column=actual_instance.rank_score_column
+                rank_score_column=actual_instance.rank_score_column,
             )
-        raise ValueError("model prediction output from server is not in acceptable type")
-
+        raise ValueError(
+            "model prediction output from server is not in acceptable type"
+        )
+
     def _to_client_prediction_output_spec(self) -> client.ModelPredictionOutput:
         prediction_output = self.spec.model_prediction_output
         if isinstance(prediction_output, BinaryClassificationOutput):
-            return client.ModelPredictionOutput(client.BinaryClassificationOutput(
-                prediction_score_column=prediction_output.prediction_score_column,
-                actual_label_column=prediction_output.actual_label_column,
-                positive_class_label=prediction_output.positive_class_label,
-                negative_class_label=prediction_output.negative_class_label,
-                score_threshold=prediction_output.score_threshold,
-                output_class=client.ModelPredictionOutputClass(BinaryClassificationOutput.__name__)
-            ))
+            return client.ModelPredictionOutput(
+                client.BinaryClassificationOutput(
+                    prediction_score_column=prediction_output.prediction_score_column,
+                    actual_label_column=prediction_output.actual_label_column,
+                    positive_class_label=prediction_output.positive_class_label,
+                    negative_class_label=prediction_output.negative_class_label,
+                    score_threshold=prediction_output.score_threshold,
+                    output_class=client.ModelPredictionOutputClass(
+                        BinaryClassificationOutput.__name__
+                    ),
+                )
+            )
         elif isinstance(prediction_output, RegressionOutput):
-            return client.ModelPredictionOutput(client.RegressionOutput(
-                actual_score_column=prediction_output.actual_score_column,
-                prediction_score_column=prediction_output.prediction_score_column,
-                output_class=client.ModelPredictionOutputClass(RegressionOutput.__name__)
-            ))
+            return client.ModelPredictionOutput(
+                client.RegressionOutput(
+                    actual_score_column=prediction_output.actual_score_column,
+                    prediction_score_column=prediction_output.prediction_score_column,
+                    output_class=client.ModelPredictionOutputClass(
+                        RegressionOutput.__name__
+                    ),
+                )
+            )
         elif isinstance(prediction_output, RankingOutput):
-            return client.ModelPredictionOutput(client.RankingOutput(
-                relevance_score_column=prediction_output.relevance_score_column,
-                prediction_group_id_column=prediction_output.prediction_group_id_column,
-                rank_score_column=prediction_output.rank_score_column,
-                output_class=client.ModelPredictionOutputClass(RankingOutput.__name__)
-            ))
-
+            return client.ModelPredictionOutput(
+                client.RankingOutput(
+                    relevance_score_column=prediction_output.relevance_score_column,
+                    prediction_group_id_column=prediction_output.prediction_group_id_column,
+                    rank_score_column=prediction_output.rank_score_column,
+                    output_class=client.ModelPredictionOutputClass(
+                        RankingOutput.__name__
+                    ),
+                )
+            )
+
         raise ValueError("model prediction output is not recognized")
-
+
     def to_client_model_schema(self) -> client.ModelSchema:
+        """
+        Convert this `ModelSchema` into the OpenAPI `ModelSchema` that the SDK uses to communicate with the Merlin server
+        """
         feature_types = {}
         for key, val in self.spec.feature_types.items():
             feature_types[key] = client.ValueType(val.value)
-
+
         return client.ModelSchema(
             id=self.id,
             model_id=self.model_id,
@@ -112,8 +159,6 @@ def to_client_model_schema(self) -> client.ModelSchema:
             prediction_id_column=self.spec.prediction_id_column,
             tag_columns=self.spec.tag_columns,
             feature_types=feature_types,
-            model_prediction_output=self._to_client_prediction_output_spec()
-        )
+            model_prediction_output=self._to_client_prediction_output_spec(),
+        ),
     )
-
-
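Taken together, `to_client_model_schema` and `from_model_schema_response` act as inverses across the OpenAPI boundary, which suggests a simple round-trip check. A sketch, assuming the generated client types accept the positional and keyword arguments exactly as used in the diff above:

```python
from merlin.model_schema import ModelSchema
from merlin.observability.inference import (
    InferenceSchema,
    RegressionOutput,
    ValueType,
)

original = ModelSchema(
    model_id=1,
    spec=InferenceSchema(
        feature_types={"feature_a": ValueType.FLOAT64},  # assumed enum member
        prediction_id_column="prediction_id",
        model_prediction_output=RegressionOutput(
            prediction_score_column="score",
            actual_score_column="actual_score",
        ),
    ),
)

payload = original.to_client_model_schema()                 # merlin -> OpenAPI client
restored = ModelSchema.from_model_schema_response(payload)  # and back

assert restored is not None
assert restored.model_id == original.model_id
assert restored.spec.prediction_id_column == "prediction_id"
```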