diff --git a/api/cmd/api/main.go b/api/cmd/api/main.go index 551534eab..d444710d4 100644 --- a/api/cmd/api/main.go +++ b/api/cmd/api/main.go @@ -43,6 +43,7 @@ import ( mlflow "github.com/caraml-dev/merlin/mlflow" "github.com/caraml-dev/merlin/models" "github.com/caraml-dev/merlin/pkg/gitlab" + "github.com/caraml-dev/merlin/pkg/observability/event" "github.com/caraml-dev/merlin/queue" "github.com/caraml-dev/merlin/queue/work" "github.com/caraml-dev/merlin/service" @@ -138,7 +139,7 @@ func main() { dependencies := buildDependencies(ctx, cfg, db, dispatcher) - registerQueueJob(dispatcher, dependencies.modelDeployment, dependencies.batchDeployment) + registerQueueJob(dispatcher, dependencies.modelDeployment, dependencies.batchDeployment, dependencies.observabilityDeployment) dispatcher.Start() if err := initCronJob(dependencies, db); err != nil { @@ -253,9 +254,10 @@ func newPprofRouter() *mux.Router { return r } -func registerQueueJob(consumer queue.Consumer, modelServiceDepl *work.ModelServiceDeployment, batchDepl *work.BatchDeployment) { +func registerQueueJob(consumer queue.Consumer, modelServiceDepl *work.ModelServiceDeployment, batchDepl *work.BatchDeployment, obsDepl *work.ObservabilityPublisherDeployment) { consumer.RegisterJob(service.ModelServiceDeployment, modelServiceDepl.Deploy) consumer.RegisterJob(service.BatchDeployment, batchDepl.Deploy) + consumer.RegisterJob(event.ObservabilityPublisherDeployment, obsDepl.Deploy) } func buildDependencies(ctx context.Context, cfg *config.Config, db *gorm.DB, dispatcher *queue.Dispatcher) deps { @@ -268,10 +270,14 @@ func buildDependencies(ctx context.Context, cfg *config.Config, db *gorm.DB, dis webServiceBuilder, predJobBuilder, imageBuilderJanitor := initImageBuilder(cfg) + observabilityPublisherStorage := storage.NewObservabilityPublisherStorage(db) + observabilityPublisherDeployment := initObservabilityPublisherDeployment(cfg, observabilityPublisherStorage) + versionStorage := storage.NewVersionStorage(db) + observabilityEvent := event.NewEventProducer(dispatcher, observabilityPublisherStorage, versionStorage) clusterControllers := initClusterControllers(cfg) - modelServiceDeployment := initModelServiceDeployment(cfg, webServiceBuilder, clusterControllers, db) + modelServiceDeployment := initModelServiceDeployment(cfg, webServiceBuilder, clusterControllers, db, observabilityEvent) versionEndpointService := initVersionEndpointService(cfg, webServiceBuilder, clusterControllers, db, coreClient, dispatcher) - modelEndpointService := initModelEndpointService(cfg, db) + modelEndpointService := initModelEndpointService(cfg, db, observabilityEvent) batchControllers := initBatchControllers(cfg, db, mlpAPIClient) batchDeployment := initBatchDeployment(cfg, db, batchControllers, predJobBuilder) @@ -356,9 +362,10 @@ func buildDependencies(ctx context.Context, cfg *config.Config, db *gorm.DB, dis MlflowClient: mlflowClient, } return deps{ - apiContext: apiContext, - modelDeployment: modelServiceDeployment, - batchDeployment: batchDeployment, - imageBuilderJanitor: imageBuilderJanitor, + apiContext: apiContext, + modelDeployment: modelServiceDeployment, + batchDeployment: batchDeployment, + observabilityDeployment: observabilityPublisherDeployment, + imageBuilderJanitor: imageBuilderJanitor, } } diff --git a/api/cmd/api/setup.go b/api/cmd/api/setup.go index 2252ac716..74f4b2e4c 100644 --- a/api/cmd/api/setup.go +++ b/api/cmd/api/setup.go @@ -28,6 +28,8 @@ import ( "github.com/caraml-dev/merlin/mlp" "github.com/caraml-dev/merlin/models" "github.com/caraml-dev/merlin/pkg/imagebuilder" + "github.com/caraml-dev/merlin/pkg/observability/deployment" + "github.com/caraml-dev/merlin/pkg/observability/event" "github.com/caraml-dev/merlin/queue" "github.com/caraml-dev/merlin/queue/work" "github.com/caraml-dev/merlin/service" @@ -36,10 +38,11 @@ import ( ) type deps struct { - apiContext api.AppContext - modelDeployment *work.ModelServiceDeployment - batchDeployment *work.BatchDeployment - imageBuilderJanitor *imagebuilder.Janitor + apiContext api.AppContext + modelDeployment *work.ModelServiceDeployment + batchDeployment *work.BatchDeployment + observabilityDeployment *work.ObservabilityPublisherDeployment + imageBuilderJanitor *imagebuilder.Janitor } func initMLPAPIClient(ctx context.Context, cfg config.MlpAPIConfig) mlp.APIClient { @@ -332,7 +335,7 @@ func initEnvironmentService(cfg *config.Config, db *gorm.DB) service.Environment return svc } -func initModelEndpointService(cfg *config.Config, db *gorm.DB) service.ModelEndpointsService { +func initModelEndpointService(cfg *config.Config, db *gorm.DB, observabilityEvent event.EventProducer) service.ModelEndpointsService { istioClients := make(map[string]istio.Client) for _, env := range cfg.ClusterConfig.EnvironmentConfigs { creds := mlpcluster.NewK8sClusterCreds(env.K8sConfig) @@ -348,7 +351,7 @@ func initModelEndpointService(cfg *config.Config, db *gorm.DB) service.ModelEndp istioClients[env.Name] = istioClient } - return service.NewModelEndpointsService(istioClients, storage.NewModelEndpointStorage(db), storage.NewVersionEndpointStorage(db), cfg.Environment) + return service.NewModelEndpointsService(istioClients, storage.NewModelEndpointStorage(db), storage.NewVersionEndpointStorage(db), cfg.Environment, observabilityEvent) } func initBatchDeployment(cfg *config.Config, db *gorm.DB, controllers map[string]batch.Controller, builder imagebuilder.ImageBuilder) *work.BatchDeployment { @@ -415,13 +418,56 @@ func initPredictionJobService(cfg *config.Config, controllers map[string]batch.C return service.NewPredictionJobService(controllers, builder, predictionJobStorage, clock.RealClock{}, cfg.Environment, producer) } -func initModelServiceDeployment(cfg *config.Config, builder imagebuilder.ImageBuilder, controllers map[string]cluster.Controller, db *gorm.DB) *work.ModelServiceDeployment { +func initModelServiceDeployment(cfg *config.Config, builder imagebuilder.ImageBuilder, controllers map[string]cluster.Controller, db *gorm.DB, observabilityEvent event.EventProducer) *work.ModelServiceDeployment { return &work.ModelServiceDeployment{ - ClusterControllers: controllers, - ImageBuilder: builder, - Storage: storage.NewVersionEndpointStorage(db), - DeploymentStorage: storage.NewDeploymentStorage(db), - LoggerDestinationURL: cfg.LoggerDestinationURL, + ClusterControllers: controllers, + ImageBuilder: builder, + Storage: storage.NewVersionEndpointStorage(db), + DeploymentStorage: storage.NewDeploymentStorage(db), + LoggerDestinationURL: cfg.LoggerDestinationURL, + ObservabilityEventProducer: observabilityEvent, + } +} + +func initObservabilityPublisherDeployment(cfg *config.Config, observabilityPublisherStorage storage.ObservabilityPublisherStorage) *work.ObservabilityPublisherDeployment { + var envCfg *config.EnvironmentConfig + for _, env := range cfg.ClusterConfig.EnvironmentConfigs { + if env.Name == cfg.ObservabilityPublisher.EnvironmentName { + envCfg = env + break + } + } + if envCfg == nil { + log.Panicf("could not find destination environment for observability publisher") + } + + clusterCfg := cluster.Config{ + ClusterName: envCfg.Cluster, + GcpProject: envCfg.GcpProject, + } + + var restConfig *rest.Config + var err error + if cfg.ClusterConfig.InClusterConfig { + restConfig, err = rest.InClusterConfig() + if err != nil { + log.Panicf("unable to get in cluster configs: %v", err) + } + } else { + creds := mlpcluster.NewK8sClusterCreds(envCfg.K8sConfig) + restConfig, err = creds.ToRestConfig() + if err != nil { + log.Panicf("unable to get cluster config of cluster: %s %v", clusterCfg.ClusterName, err) + } + } + deployer, err := deployment.New(restConfig, cfg.ObservabilityPublisher) + if err != nil { + log.Panicf("unable to initialize observability deployer with err: %w", err) + } + + return &work.ObservabilityPublisherDeployment{ + Deployer: deployer, + ObservabilityPublisherStorage: observabilityPublisherStorage, } } diff --git a/api/config/config.go b/api/config/config.go index 0074efee5..3cf7c94b7 100644 --- a/api/config/config.go +++ b/api/config/config.go @@ -68,6 +68,7 @@ type Config struct { MlflowConfig MlflowConfig PyFuncPublisherConfig PyFuncPublisherConfig InferenceServiceDefaults InferenceServiceDefaults + ObservabilityPublisher ObservabilityPublisher } // UIConfig stores the configuration for the UI. diff --git a/api/config/observability.go b/api/config/observability.go new file mode 100644 index 000000000..26b485e72 --- /dev/null +++ b/api/config/observability.go @@ -0,0 +1,38 @@ +package config + +import "time" + +// ObservabilityPublisher +type ObservabilityPublisher struct { + ArizeSink ArizeSink + BigQuerySink BigQuerySink + KafkaConsumer KafkaConsumer + ImageName string + DefaultResources ResourceRequestsLimits + EnvironmentName string + Replicas int32 + TargetNamespace string + ServiceAccountName string + DeploymentTimeout time.Duration `default:"30m"` +} + +// KafkaConsumer +type KafkaConsumer struct { + Brokers string `validate:"required"` + BatchSize int + GroupID string + AdditionalConsumerConfig map[string]string +} + +// ArizeSink +type ArizeSink struct { + APIKey string + SpaceKey string +} + +// BigQuerySink +type BigQuerySink struct { + Project string + Dataset string + TTLDays int +} diff --git a/api/models/model.go b/api/models/model.go index e0877a0fb..fb0893f6a 100644 --- a/api/models/model.go +++ b/api/models/model.go @@ -53,13 +53,14 @@ type CreatedUpdated struct { } type Model struct { - ID ID `json:"id"` - Name string `json:"name" validate:"required,min=3,max=25,subdomain_rfc1123"` - ProjectID ID `json:"project_id"` - Project mlp.Project `json:"-" gorm:"-"` - ExperimentID ID `json:"mlflow_experiment_id" gorm:"column:mlflow_experiment_id"` - Type string `json:"type" gorm:"type"` - MlflowURL string `json:"mlflow_url" gorm:"-"` + ID ID `json:"id"` + Name string `json:"name" validate:"required,min=3,max=25,subdomain_rfc1123"` + ProjectID ID `json:"project_id"` + Project mlp.Project `json:"-" gorm:"-"` + ExperimentID ID `json:"mlflow_experiment_id" gorm:"column:mlflow_experiment_id"` + Type string `json:"type" gorm:"type"` + MlflowURL string `json:"mlflow_url" gorm:"-"` + ObservabilitySupported bool `json:"observability_supported" gorm:"column:observability_supported"` Endpoints []*ModelEndpoint `json:"endpoints" gorm:"foreignkey:ModelID;"` diff --git a/api/models/model_endpoint.go b/api/models/model_endpoint.go index 5dd1241be..a1486ec7d 100644 --- a/api/models/model_endpoint.go +++ b/api/models/model_endpoint.go @@ -37,6 +37,14 @@ type ModelEndpoint struct { CreatedUpdated } +func (me *ModelEndpoint) GetVersionEndpoint() *VersionEndpoint { + if me.Rule == nil || len(me.Rule.Destination) == 0 { + return nil + } + destination := me.Rule.Destination[0] + return destination.VersionEndpoint +} + // ModelEndpointRule describes model's endpoint traffic rule. type ModelEndpointRule struct { Destination []*ModelEndpointRuleDestination `json:"destinations"` diff --git a/api/models/model_schema.go b/api/models/model_schema.go index deede46e3..96696f2f5 100644 --- a/api/models/model_schema.go +++ b/api/models/model_schema.go @@ -36,12 +36,12 @@ type ModelSchema struct { // SchemaSpec type SchemaSpec struct { - SessionIDColumn string `json:"session_id_column"` - RowIDColumn string `json:"row_id_column"` - ModelPredictionOutput *ModelPredictionOutput `json:"model_prediction_output"` - TagColumns []string `json:"tag_columns"` - FeatureTypes map[string]ValueType `json:"feature_types"` - FeatureOrders []string `json:"feature_orders"` + SessionIDColumn string `json:"session_id_column" yaml:"session_id_column"` + RowIDColumn string `json:"row_id_column" yaml:"row_id_column"` + ModelPredictionOutput *ModelPredictionOutput `json:"model_prediction_output" yaml:"model_prediction_output"` + TagColumns []string `json:"tag_columns" yaml:"tag_columns"` + FeatureTypes map[string]ValueType `json:"feature_types" yaml:"feature_types"` + FeatureOrders []string `json:"feature_orders" yaml:"feature_orders"` } // Value returning a value for `SchemaSpec` instance @@ -125,27 +125,41 @@ func (m ModelPredictionOutput) MarshalJSON() ([]byte, error) { return nil, nil } +func (m ModelPredictionOutput) MarshalYAML() (interface{}, error) { + var in interface{} + if m.BinaryClassificationOutput != nil { + in = m.BinaryClassificationOutput + } else if m.RankingOutput != nil { + in = m.RankingOutput + } else if m.RegressionOutput != nil { + in = m.RegressionOutput + } else { + return nil, fmt.Errorf("not valid model prediction output") + } + return in, nil +} + // BinaryClassificationOutput is specification for prediction of binary classification model type BinaryClassificationOutput struct { - ActualScoreColumn string `json:"actual_score_column"` - NegativeClassLabel string `json:"negative_class_label"` - PredictionScoreColumn string `json:"prediction_score_column"` - PredictionLabelColumn string `json:"prediction_label_column"` - PositiveClassLabel string `json:"positive_class_label"` - ScoreThreshold *float64 `json:"score_threshold,omitempty"` - OutputClass ModelPredictionOutputClass `json:"output_class" validate:"required"` + ActualScoreColumn string `json:"actual_score_column" yaml:"actual_score_column"` + NegativeClassLabel string `json:"negative_class_label" yaml:"negative_class_label"` + PredictionScoreColumn string `json:"prediction_score_column" yaml:"prediction_score_column"` + PredictionLabelColumn string `json:"prediction_label_column" yaml:"prediction_label_column"` + PositiveClassLabel string `json:"positive_class_label" yaml:"positive_class_label"` + ScoreThreshold *float64 `json:"score_threshold,omitempty" yaml:"score_threshold"` + OutputClass ModelPredictionOutputClass `json:"output_class" yaml:"output_class" validate:"required"` } // RankingOutput is specification for prediction of ranking model type RankingOutput struct { - RankScoreColumn string `json:"rank_score_column"` - RelevanceScoreColumn string `json:"relevance_score_column"` - OutputClass ModelPredictionOutputClass `json:"output_class" validate:"required"` + RankScoreColumn string `json:"rank_score_column" yaml:"rank_score_column"` + RelevanceScoreColumn string `json:"relevance_score_column" yaml:"relevance_score_column"` + OutputClass ModelPredictionOutputClass `json:"output_class" yaml:"output_class" validate:"required"` } // Regression is specification for prediction of regression model type RegressionOutput struct { - PredictionScoreColumn string `json:"prediction_score_column"` - ActualScoreColumn string `json:"actual_score_column"` - OutputClass ModelPredictionOutputClass `json:"output_class" validate:"required"` + PredictionScoreColumn string `json:"prediction_score_column" yaml:"prediction_score_column"` + ActualScoreColumn string `json:"actual_score_column" yaml:"actual_score_column"` + OutputClass ModelPredictionOutputClass `json:"output_class" yaml:"output_class" validate:"required"` } diff --git a/api/models/observability_publisher.go b/api/models/observability_publisher.go new file mode 100644 index 000000000..f1878f7ee --- /dev/null +++ b/api/models/observability_publisher.go @@ -0,0 +1,67 @@ +package models + +import ( + "fmt" +) + +// PublisherStatus +type PublisherStatus string + +const ( + Pending PublisherStatus = "pending" + Running PublisherStatus = "running" + Failed PublisherStatus = "failed" + Terminated PublisherStatus = "terminated" +) + +// ObservabilityPublisher +type ObservabilityPublisher struct { + ID ID `gorm:"id"` + VersionModelID ID `gorm:"version_model_id"` + VersionID ID `gorm:"version_id"` + Revision int `gorm:"revision"` + Status PublisherStatus `gorm:"status"` + ModelSchemaSpec *SchemaSpec `gorm:"model_schema_spec"` + CreatedUpdated +} + +type ActionType string + +const ( + DeployPublisher ActionType = "deploy" + UndeployPublisher ActionType = "delete" +) + +type WorkerData struct { + Project string + ModelSchemaSpec *SchemaSpec + Metadata Metadata + ModelName string + ModelVersion string + Revision int + TopicSource string +} + +func NewWorkerData(modelVersion *Version, model *Model, observabilityPublisher *ObservabilityPublisher) *WorkerData { + return &WorkerData{ + ModelName: model.Name, + Project: model.Project.Name, + ModelSchemaSpec: observabilityPublisher.ModelSchemaSpec, + Metadata: Metadata{ + App: fmt.Sprintf("%s-observability-publisher", model.Name), + Component: "worker", + Stream: model.Project.Stream, + Team: model.Project.Team, + Labels: model.Project.Labels, + }, + ModelVersion: modelVersion.ID.String(), + Revision: observabilityPublisher.Revision, + TopicSource: getPredictionLogTopicForVersion(model.Project.Name, model.Name, modelVersion.ID.String()), + } +} + +type ObservabilityPublisherJob struct { + ActionType ActionType + Publisher *ObservabilityPublisher + WorkerData *WorkerData +} diff --git a/api/models/service.go b/api/models/service.go index ca070c69b..11d5106d5 100644 --- a/api/models/service.go +++ b/api/models/service.go @@ -118,29 +118,35 @@ func (svc *Service) GetPredictionLogTopic() string { } func (svc *Service) GetPredictionLogTopicForVersion() string { - return fmt.Sprintf("caraml-%s-%s-%s-prediction-log", svc.Namespace, svc.ModelName, svc.ModelVersion) + return getPredictionLogTopicForVersion(svc.Namespace, svc.ModelName, svc.ModelVersion) +} + +func getPredictionLogTopicForVersion(project string, modelName string, modelVersion string) string { + return fmt.Sprintf("caraml-%s-%s-%s-prediction-log", project, modelName, modelVersion) } func MergeProjectVersionLabels(projectLabels mlp.Labels, versionLabels KV) mlp.Labels { projectLabelsMap := map[string]int{} + updatedLabels := make(mlp.Labels, 0) for index, projectLabel := range projectLabels { projectLabelsMap[projectLabel.Key] = index + updatedLabels = append(updatedLabels, projectLabel) } for versionLabelKey, versionLabelValue := range versionLabels { if _, labelExists := projectLabelsMap[versionLabelKey]; labelExists { index := projectLabelsMap[versionLabelKey] - projectLabels[index].Value = fmt.Sprint(versionLabelValue) + updatedLabels[index].Value = fmt.Sprint(versionLabelValue) continue } - projectLabels = append(projectLabels, mlpclient.Label{ + updatedLabels = append(updatedLabels, mlpclient.Label{ Key: versionLabelKey, Value: fmt.Sprint(versionLabelValue), }) } - return projectLabels + return updatedLabels } func CreateInferenceServiceName(modelName, versionID, revisionID string) string { diff --git a/api/pkg/observability/deployment/config.go b/api/pkg/observability/deployment/config.go new file mode 100644 index 000000000..f0685acc5 --- /dev/null +++ b/api/pkg/observability/deployment/config.go @@ -0,0 +1,54 @@ +package deployment + +import ( + "github.com/caraml-dev/merlin/models" +) + +type ConsumerConfig struct { + ModelID string `yaml:"model_id"` + ModelVersion string `yaml:"model_version"` + InferenceSchema *models.SchemaSpec `yaml:"inference_schema"` + ObservationSinks []ObservationSink `yaml:"observation_sinks"` + ObservationSource *ObserVationSource `yaml:"observation_source"` +} + +type ObserVationSource struct { + Type SourceType `yaml:"type"` + Config any `yaml:"config"` +} + +type KafkaSource struct { + Topic string `yaml:"topic"` + BootstrapServers string `yaml:"bootstrap_servers"` + GroupID string `yaml:"group_id"` + BatchSize int `yaml:"batch_size"` + AdditionalConsumerConfig map[string]string `yaml:"additional_consumer_config"` +} + +type SinkType string +type SourceType string + +const ( + Arize SinkType = "ARIZE" + BQ SinkType = "BIGQUERY" + + Kafka SourceType = "KAFKA" + + PublisherRevisionAnnotationKey = "publisher-revision" +) + +type ObservationSink struct { + Type SinkType `yaml:"type"` + Config any `yaml:"config"` +} + +type ArizeSink struct { + APIKey string `yaml:"api_key"` + SpaceKey string `yaml:"space_key"` +} + +type BigQuerySink struct { + Project string `yaml:"project"` + Dataset string `yaml:"dataset"` + TTLDays int `yaml:"ttl_days"` +} diff --git a/api/pkg/observability/deployment/deployment.go b/api/pkg/observability/deployment/deployment.go new file mode 100644 index 000000000..4c58287c4 --- /dev/null +++ b/api/pkg/observability/deployment/deployment.go @@ -0,0 +1,457 @@ +package deployment + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/caraml-dev/merlin/config" + "github.com/caraml-dev/merlin/log" + "github.com/caraml-dev/merlin/models" + "gopkg.in/yaml.v2" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" +) + +const ( + appLabelKey = "app" +) + +type Manifest struct { + Deployment *appsv1.Deployment + Secret *corev1.Secret + OnProgress bool +} + +type Deployer interface { + Deploy(ctx context.Context, data *models.WorkerData) error + GetDeployedManifest(ctx context.Context, data *models.WorkerData) (*Manifest, error) + Undeploy(ctx context.Context, data *models.WorkerData) error +} + +type deployer struct { + kubeClient kubernetes.Interface + consumerConfig config.ObservabilityPublisher + + resourceRequest corev1.ResourceList + resourceLimit corev1.ResourceList +} + +func New(restConfig *rest.Config, consumerConfig config.ObservabilityPublisher) (*deployer, error) { + kubeClient, err := kubernetes.NewForConfig(restConfig) + if err != nil { + return nil, err + } + + resourceRequest, err := parseResourceList(consumerConfig.DefaultResources.Requests) + if err != nil { + return nil, err + } + resourceLimit, err := parseResourceList(consumerConfig.DefaultResources.Limits) + if err != nil { + return nil, err + } + + return &deployer{ + kubeClient: kubeClient, + consumerConfig: consumerConfig, + resourceRequest: resourceRequest, + resourceLimit: resourceLimit, + }, nil +} + +func parseResourceList(resourceCfg config.Resource) (corev1.ResourceList, error) { + resourceList := corev1.ResourceList{} + if resourceCfg.CPU != "" { + quantity, err := resource.ParseQuantity(resourceCfg.CPU) + if err != nil { + return corev1.ResourceList{}, err + } + resourceList[corev1.ResourceCPU] = quantity + } + + if resourceCfg.Memory != "" { + quantity, err := resource.ParseQuantity(resourceCfg.Memory) + if err != nil { + return corev1.ResourceList{}, err + } + resourceList[corev1.ResourceMemory] = quantity + } + + return resourceList, nil +} + +func (c *deployer) targetNamespace() string { + return c.consumerConfig.TargetNamespace +} + +func (c *deployer) GetDeployedManifest(ctx context.Context, data *models.WorkerData) (*Manifest, error) { + secretName := c.getSecretName(data) + secret, err := c.getSecret(ctx, secretName, c.targetNamespace()) + if err != nil { + return nil, err + } + deploymentName := c.getDeploymentName(data) + depl, err := c.getDeployment(ctx, deploymentName, c.targetNamespace()) + if err != nil { + return nil, err + } + if depl == nil && secret == nil { + return nil, err + } + + isDeploymentRolledOut, err := deploymentRolledOut(depl, data.Revision, false) + if err != nil { + return nil, err + } + return &Manifest{Secret: secret, Deployment: depl, OnProgress: !isDeploymentRolledOut}, nil +} + +func deploymentRolledOut(depl *appsv1.Deployment, revision int, strictCheck bool) (bool, error) { + deploymentRev, err := getDeploymentRevision(depl) + if err != nil { + return false, err + } + + if strictCheck && deploymentRev != int64(revision) { + return false, fmt.Errorf("revision is not matched, requested: %d - actual: %d", revision, deploymentRev) + } + + if depl.Generation <= depl.Status.ObservedGeneration { + cond := getDeploymentCondition(depl.Status, appsv1.DeploymentProgressing) + if cond != nil && cond.Reason == timeoutReason { + return false, fmt.Errorf("deployment %q exceeded its progress deadline", depl.Name) + } + if depl.Spec.Replicas != nil && depl.Status.UpdatedReplicas < *depl.Spec.Replicas { + return false, nil + } + if depl.Status.Replicas > depl.Status.UpdatedReplicas { + return false, nil + } + if depl.Status.AvailableReplicas < depl.Status.UpdatedReplicas { + return false, nil + } + return true, nil + } + + return false, nil +} + +func (c *deployer) Deploy(ctx context.Context, data *models.WorkerData) (err error) { + secret, previousSecret, err := c.applySecret(ctx, data) + if err != nil { + return err + } + + defer func() { + if err != nil { + // meaning that we need to rollback to previous secret + if previousSecret != nil { + if _, err := c.rollbackSecret(ctx, previousSecret); err != nil { + log.Warnf("failed rollback secret to previous with err: %v", err) + } + } else { + // delete current secret + if err := c.deleteSecret(ctx, secret.Name, secret.Namespace); err != nil { + log.Warnf("failed delete secret with err: %v", err) + } + // delete current deployment + if err := c.deleteDeployment(ctx, c.getDeploymentName(data), c.targetNamespace()); err != nil { + log.Warnf("failed delete deployment with err: %v", err) + } + } + } + }() + deployment, err := c.applyDeployment(ctx, data, secret.Name) + if err != nil { + return err + } + + if err := c.waitUntilDeploymentReady(ctx, deployment, data.Revision); err != nil { + return err + } + + return nil +} + +func (c *deployer) rollbackSecret(ctx context.Context, secret *corev1.Secret) (*corev1.Secret, error) { + coreV1 := c.kubeClient.CoreV1() + secretV1 := coreV1.Secrets(secret.Namespace) + return secretV1.Update(ctx, secret, metav1.UpdateOptions{}) +} + +func (c *deployer) getSecret(ctx context.Context, secretName string, namespace string) (*corev1.Secret, error) { + coreV1 := c.kubeClient.CoreV1() + secretV1 := coreV1.Secrets(namespace) + secret, err := secretV1.Get(ctx, secretName, metav1.GetOptions{}) + if err != nil { + if !k8serrors.IsNotFound(err) { + return nil, err + } + return nil, nil + } + return secret, nil +} + +func (c *deployer) deleteSecret(ctx context.Context, secretName string, namespace string) error { + coreV1 := c.kubeClient.CoreV1() + secretV1 := coreV1.Secrets(namespace) + return secretV1.Delete(ctx, secretName, metav1.DeleteOptions{}) +} + +func (c *deployer) getDeployment(ctx context.Context, deploymentName string, namespace string) (*appsv1.Deployment, error) { + appV1 := c.kubeClient.AppsV1() + deploymentV1 := appV1.Deployments(namespace) + deployment, err := deploymentV1.Get(ctx, deploymentName, metav1.GetOptions{}) + if err != nil { + if !k8serrors.IsNotFound(err) { + return nil, err + } + return nil, nil + } + return deployment, nil +} + +func (c *deployer) deleteDeployment(ctx context.Context, deploymentName string, namespace string) error { + appV1 := c.kubeClient.AppsV1() + deploymentV1 := appV1.Deployments(namespace) + return deploymentV1.Delete(ctx, deploymentName, metav1.DeleteOptions{}) +} + +func (c *deployer) applySecret(ctx context.Context, data *models.WorkerData) (secret *corev1.Secret, previousSecret *corev1.Secret, err error) { + // Create secret + coreV1 := c.kubeClient.CoreV1() + secretV1 := coreV1.Secrets(c.targetNamespace()) + secretName := c.getSecretName(data) + applySecretFunc := func(data *models.WorkerData, isExistingSecret bool) (*corev1.Secret, error) { + secretSpec, err := c.createSecretSpec(data) + if err != nil { + return nil, err + } + if isExistingSecret { + return secretV1.Update(ctx, secretSpec, metav1.UpdateOptions{}) + } + return secretV1.Create(ctx, secretSpec, metav1.CreateOptions{}) + } + previousSecret, err = secretV1.Get(ctx, secretName, metav1.GetOptions{}) + if err != nil { + if !k8serrors.IsNotFound(err) { + return nil, nil, err + } + secret, err = applySecretFunc(data, false) + return secret, nil, err + } + secret, err = applySecretFunc(data, true) + return secret, previousSecret, err +} + +func (c *deployer) createSecretSpec(data *models.WorkerData) (*corev1.Secret, error) { + consumerCfg := &ConsumerConfig{ + ModelID: data.ModelName, + ModelVersion: data.ModelVersion, + InferenceSchema: data.ModelSchemaSpec, + ObservationSinks: []ObservationSink{ + { + Type: Arize, + Config: ArizeSink{ + APIKey: c.consumerConfig.ArizeSink.APIKey, + SpaceKey: c.consumerConfig.ArizeSink.SpaceKey, + }, + }, + { + Type: BQ, + Config: BigQuerySink{ + Project: c.consumerConfig.BigQuerySink.Project, + Dataset: c.consumerConfig.BigQuerySink.Dataset, + TTLDays: c.consumerConfig.BigQuerySink.TTLDays, + }, + }, + }, + ObservationSource: &ObserVationSource{ + Type: Kafka, + Config: &KafkaSource{ + Topic: data.TopicSource, + BootstrapServers: c.consumerConfig.KafkaConsumer.Brokers, + GroupID: c.consumerConfig.KafkaConsumer.GroupID, + BatchSize: c.consumerConfig.KafkaConsumer.BatchSize, + AdditionalConsumerConfig: c.consumerConfig.KafkaConsumer.AdditionalConsumerConfig, + }, + }, + } + consumerCfgStr, err := yaml.Marshal(consumerCfg) + if err != nil { + return nil, err + } + return &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: c.getSecretName(data), + Namespace: c.targetNamespace(), + Labels: c.getLabels(data), + }, + StringData: map[string]string{ + "config.yaml": string(consumerCfgStr), + }, + }, nil +} + +func (c *deployer) applyDeployment(ctx context.Context, data *models.WorkerData, secretName string) (*appsv1.Deployment, error) { + appV1 := c.kubeClient.AppsV1() + deploymentName := c.getDeploymentName(data) + deploymentV1 := appV1.Deployments(c.targetNamespace()) + + applyDeploymentFunc := func(data *models.WorkerData, secretName string, isExistingDeployment bool) (*appsv1.Deployment, error) { + deployment, err := c.createDeploymentSpec(ctx, data, secretName) + if err != nil { + return nil, err + } + if isExistingDeployment { + return deploymentV1.Update(ctx, deployment, metav1.UpdateOptions{}) + } + return deploymentV1.Create(ctx, deployment, metav1.CreateOptions{}) + } + _, err := deploymentV1.Get(ctx, deploymentName, metav1.GetOptions{}) + if err != nil { + if !k8serrors.IsNotFound(err) { + return nil, err + } + return applyDeploymentFunc(data, secretName, false) + } + + return applyDeploymentFunc(data, secretName, true) +} + +func (c *deployer) getLabels(data *models.WorkerData) map[string]string { + labels := data.Metadata.ToLabel() + labels[appLabelKey] = data.Metadata.App + return labels +} + +func (c *deployer) createDeploymentSpec(ctx context.Context, data *models.WorkerData, secretName string) (*appsv1.Deployment, error) { + labels := c.getLabels(data) + + cfgVolName := "config-volume" + return &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: c.getDeploymentName(data), + Namespace: c.targetNamespace(), + Labels: labels, + Annotations: map[string]string{ + PublisherRevisionAnnotationKey: strconv.Itoa(data.Revision), + }, + }, + Spec: appsv1.DeploymentSpec{ + Selector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + appLabelKey: data.Metadata.App, + }, + }, + Replicas: &c.consumerConfig.Replicas, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: labels, + Annotations: map[string]string{ + PublisherRevisionAnnotationKey: strconv.Itoa(data.Revision), + }, + }, + + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "worker", + Image: c.consumerConfig.ImageName, + Command: []string{ + "python", + "-m", + "publisher", + "+environment=config", + }, + ImagePullPolicy: corev1.PullIfNotPresent, + + Resources: corev1.ResourceRequirements{ + Requests: c.resourceRequest, + Limits: c.resourceLimit, + }, + VolumeMounts: []corev1.VolumeMount{ + { + Name: cfgVolName, + MountPath: "/mlobs/observation-publisher/conf/environment", + ReadOnly: true, + }, + }, + }, + }, + Volumes: []corev1.Volume{ + { + Name: cfgVolName, + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: c.getSecretName(data), + }, + }, + }, + }, + ServiceAccountName: c.consumerConfig.ServiceAccountName, + }, + }, + }, + }, nil +} + +func (c *deployer) waitUntilDeploymentReady(ctx context.Context, deployment *appsv1.Deployment, revision int) error { + deploymentv1 := c.kubeClient.AppsV1().Deployments(deployment.Namespace) + timeout := time.After(c.consumerConfig.DeploymentTimeout) + watcher, err := deploymentv1.Watch(ctx, metav1.ListOptions{ + FieldSelector: fmt.Sprintf("metadata.name=%s", deployment.Name), + }) + if err != nil { + return err + } + for { + select { + case <-timeout: + watcher.Stop() + return fmt.Errorf("timeout waiting deployment ready") + case watchRes := <-watcher.ResultChan(): + deployManifest, ok := watchRes.Object.(*appsv1.Deployment) + if !ok { + return fmt.Errorf("watch result is not deployment") + } + + rolledOut, err := deploymentRolledOut(deployManifest, revision, true) + if err != nil { + return err + } + if rolledOut { + return nil + } + } + } +} + +func (c *deployer) getDeploymentName(data *models.WorkerData) string { + return fmt.Sprintf("%s-%s-mlobs", data.Project, data.ModelName) +} + +func (c *deployer) getSecretName(data *models.WorkerData) string { + return fmt.Sprintf("%s-%s-config", data.Project, data.ModelName) +} + +func (c *deployer) Undeploy(ctx context.Context, data *models.WorkerData) error { + deploymentName := c.getDeploymentName(data) + if err := c.deleteDeployment(ctx, deploymentName, c.targetNamespace()); err != nil { + return err + } + + secretName := c.getSecretName(data) + if err := c.deleteSecret(ctx, secretName, c.targetNamespace()); err != nil { + return err + } + + return nil +} diff --git a/api/pkg/observability/deployment/deployment_test.go b/api/pkg/observability/deployment/deployment_test.go new file mode 100644 index 000000000..68f78eeb7 --- /dev/null +++ b/api/pkg/observability/deployment/deployment_test.go @@ -0,0 +1,1030 @@ +package deployment + +import ( + "context" + "fmt" + "time" + + "net/http" + "reflect" + "strconv" + "testing" + + "github.com/caraml-dev/merlin/config" + "github.com/caraml-dev/merlin/models" + "github.com/stretchr/testify/assert" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/watch" + + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/resource" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/fake" + + fakeappsv1 "k8s.io/client-go/kubernetes/typed/apps/v1/fake" + fakecorev1 "k8s.io/client-go/kubernetes/typed/core/v1/fake" + + ktesting "k8s.io/client-go/testing" +) + +const ( + getMethod = "get" + createMethod = "create" + updateMethod = "update" + deleteMethod = "delete" + + secretResource = "secrets" + deploymentResource = "deployments" +) + +type deploymentStatus string + +const ( + noStatus deploymentStatus = "no_status" + onProgress deploymentStatus = "on_progress" + ready deploymentStatus = "ready" + timeoutError deploymentStatus = "timeout_error" + + namespace = "caraml-observability" + serviceAccountName = "caraml-observability-sa" +) + +func createDeploymentSpec(data *models.WorkerData, resourceRequest corev1.ResourceList, resourceLimit corev1.ResourceList, imageName string) *appsv1.Deployment { + labels := data.Metadata.ToLabel() + labels[appLabelKey] = data.Metadata.App + numReplicas := int32(2) + cfgVolName := "config-volume" + depl := &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("%s-%s-mlobs", data.Project, data.ModelName), + Namespace: namespace, + Labels: labels, + Annotations: map[string]string{ + PublisherRevisionAnnotationKey: strconv.Itoa(data.Revision), + }, + }, + Spec: appsv1.DeploymentSpec{ + Selector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "app": data.Metadata.App, + }, + }, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: labels, + Annotations: map[string]string{ + PublisherRevisionAnnotationKey: strconv.Itoa(data.Revision), + }, + }, + Spec: corev1.PodSpec{ + ServiceAccountName: serviceAccountName, + Containers: []corev1.Container{ + { + Name: "worker", + Image: imageName, + Command: []string{ + "python", + "-m", + "publisher", + "+environment=config", + }, + ImagePullPolicy: corev1.PullIfNotPresent, + + Resources: corev1.ResourceRequirements{ + Requests: resourceRequest, + Limits: resourceLimit, + }, + VolumeMounts: []corev1.VolumeMount{ + { + Name: cfgVolName, + MountPath: "/mlobs/observation-publisher/conf/environment", + ReadOnly: true, + }, + }, + }, + }, + Volumes: []corev1.Volume{ + { + Name: cfgVolName, + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: fmt.Sprintf("%s-%s-config", data.Project, data.ModelName), + }, + }, + }, + }, + }, + }, + Replicas: &numReplicas, + }, + } + + return depl +} + +func changeDeploymentStatus(depl *appsv1.Deployment, status deploymentStatus, revision int) *appsv1.Deployment { + updatedDepl := depl.DeepCopy() + numReplicas := int32(2) + var detailStatus appsv1.DeploymentStatus + if status == onProgress { + detailStatus = appsv1.DeploymentStatus{ + Replicas: numReplicas + 1, + UnavailableReplicas: numReplicas, + UpdatedReplicas: 1, + } + } else if status == timeoutError { + detailStatus = appsv1.DeploymentStatus{ + Conditions: []appsv1.DeploymentCondition{ + { + Type: appsv1.DeploymentProgressing, + Reason: timeoutReason, + }, + }, + Replicas: numReplicas, + UpdatedReplicas: 1, + } + } else { + detailStatus = appsv1.DeploymentStatus{ + Replicas: numReplicas, + UpdatedReplicas: numReplicas, + AvailableReplicas: numReplicas, + } + } + updatedDepl.Status = detailStatus + updatedDepl.Annotations[k8sRevisionAnnotation] = strconv.Itoa(revision) + return updatedDepl +} + +type deploymentWatchReactor struct { + result chan watch.Event +} + +func newDeploymentWatchReactor(depl *appsv1.Deployment) *deploymentWatchReactor { + w := &deploymentWatchReactor{result: make(chan watch.Event, 1)} + w.result <- watch.Event{Type: watch.Added, Object: depl} + return w +} + +func (w *deploymentWatchReactor) Handles(action ktesting.Action) bool { + return action.GetResource().Resource == deploymentResource +} + +func (w *deploymentWatchReactor) React(action ktesting.Action) (handled bool, ret watch.Interface, err error) { + return true, watch.NewProxyWatcher(w.result), nil +} + +func Test_deployer_Deploy(t *testing.T) { + consumerConfig := config.ObservabilityPublisher{ + ArizeSink: config.ArizeSink{ + APIKey: "api-key", + SpaceKey: "space-key", + }, + BigQuerySink: config.BigQuerySink{ + Project: "bq-project", + Dataset: "dataset", + TTLDays: 10, + }, + KafkaConsumer: config.KafkaConsumer{ + Brokers: "broker-1", + GroupID: "group-id", + BatchSize: 100, + AdditionalConsumerConfig: map[string]string{ + "auto.offset.reset": "latest", + "fetch.min.bytes": "1024000", + }, + }, + ImageName: "observability-publisher:v0.0", + DefaultResources: config.ResourceRequestsLimits{ + Requests: config.Resource{ + CPU: "1", + Memory: "1Gi", + }, + Limits: config.Resource{ + Memory: "1Gi", + }, + }, + EnvironmentName: "dev", + Replicas: 2, + TargetNamespace: namespace, + ServiceAccountName: serviceAccountName, + DeploymentTimeout: 5 * time.Second, + } + requestResource := corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + } + limitResource := corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("1Gi"), + } + schemaSpec := &models.SchemaSpec{ + SessionIDColumn: "session_id", + RowIDColumn: "row_id", + TagColumns: []string{"tag"}, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Float64, + "featureC": models.Int64, + "featureD": models.Boolean, + }, + ModelPredictionOutput: &models.ModelPredictionOutput{ + BinaryClassificationOutput: &models.BinaryClassificationOutput{ + NegativeClassLabel: "negative", + PositiveClassLabel: "positive", + PredictionLabelColumn: "prediction_label", + PredictionScoreColumn: "prediction_score", + OutputClass: models.BinaryClassification, + }, + }, + } + tests := []struct { + name string + data *models.WorkerData + kubeClient kubernetes.Interface + consumerConfig config.ObservabilityPublisher + resourceRequest corev1.ResourceList + resourceLimit corev1.ResourceList + expectedErr error + }{ + { + name: "fresh deployment", + data: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + consumerConfig: consumerConfig, + + resourceRequest: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + resourceLimit: corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + + kubeClient: func() kubernetes.Interface { + clientSet := fake.NewSimpleClientset() + secretAPI := clientSet.CoreV1().Secrets(namespace).(*fakecorev1.FakeSecrets) + prependGetSecretReactor(t, secretAPI, nil, nil) + prependUpsertSecretReactor(t, secretAPI, []*corev1.Secret{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "project-1-model-1-config", + Namespace: namespace, + Labels: map[string]string{ + "app": "model-1-observability-publisher", + "component": "worker", + "orchestrator": "merlin", + "stream": "stream", + "team": "team", + "environment": "", + }, + }, + StringData: map[string]string{ + "config.yaml": "model_id: model-1\nmodel_version: \"1\"\ninference_schema:\n session_id_column: session_id\n row_id_column: row_id\n model_prediction_output:\n actual_score_column: \"\"\n negative_class_label: negative\n prediction_score_column: prediction_score\n prediction_label_column: prediction_label\n positive_class_label: positive\n score_threshold: null\n output_class: BinaryClassificationOutput\n tag_columns:\n - tag\n feature_types:\n featureA: float64\n featureB: float64\n featureC: int64\n featureD: boolean\n feature_orders: []\nobservation_sinks:\n- type: ARIZE\n config:\n api_key: api-key\n space_key: space-key\n- type: BIGQUERY\n config:\n project: bq-project\n dataset: dataset\n ttl_days: 10\nobservation_source:\n type: KAFKA\n config:\n topic: caraml-project-1-model-1-1-prediction-log\n bootstrap_servers: broker-1\n group_id: group-id\n batch_size: 100\n additional_consumer_config:\n auto.offset.reset: latest\n fetch.min.bytes: \"1024000\"\n", + }, + }}, nil, false) + deploymentAPI := clientSet.AppsV1().Deployments(namespace).(*fakeappsv1.FakeDeployments) + preprendGetDeploymentReactor(t, deploymentAPI, nil, nil) + depl := createDeploymentSpec(&models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + }, requestResource, limitResource, consumerConfig.ImageName) + prependUpsertDeploymentReactor(t, deploymentAPI, depl, nil, false) + + updatedDepl := changeDeploymentStatus(depl, ready, 1) + deplWatchReactor := newDeploymentWatchReactor(updatedDepl) + clientSet.WatchReactionChain = []ktesting.WatchReactor{deplWatchReactor} + + return clientSet + }(), + }, + { + name: "fresh deployment failed; failed create secret", + data: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + consumerConfig: consumerConfig, + + resourceRequest: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + resourceLimit: corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + + kubeClient: func() kubernetes.Interface { + clientSet := fake.NewSimpleClientset() + secretAPI := clientSet.CoreV1().Secrets(namespace).(*fakecorev1.FakeSecrets) + prependGetSecretReactor(t, secretAPI, nil, nil) + prependUpsertSecretReactor(t, secretAPI, []*corev1.Secret{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "project-1-model-1-config", + Namespace: namespace, + Labels: map[string]string{ + "app": "model-1-observability-publisher", + "component": "worker", + "orchestrator": "merlin", + "stream": "stream", + "team": "team", + "environment": "", + }, + }, + StringData: map[string]string{ + "config.yaml": "model_id: model-1\nmodel_version: \"1\"\ninference_schema:\n session_id_column: session_id\n row_id_column: row_id\n model_prediction_output:\n actual_score_column: \"\"\n negative_class_label: negative\n prediction_score_column: prediction_score\n prediction_label_column: prediction_label\n positive_class_label: positive\n score_threshold: null\n output_class: BinaryClassificationOutput\n tag_columns:\n - tag\n feature_types:\n featureA: float64\n featureB: float64\n featureC: int64\n featureD: boolean\n feature_orders: []\nobservation_sinks:\n- type: ARIZE\n config:\n api_key: api-key\n space_key: space-key\n- type: BIGQUERY\n config:\n project: bq-project\n dataset: dataset\n ttl_days: 10\nobservation_source:\n type: KAFKA\n config:\n topic: caraml-project-1-model-1-1-prediction-log\n bootstrap_servers: broker-1\n group_id: group-id\n batch_size: 100\n additional_consumer_config:\n auto.offset.reset: latest\n fetch.min.bytes: \"1024000\"\n", + }, + }}, fmt.Errorf("deployment control plane is down"), false) + return clientSet + }(), + expectedErr: fmt.Errorf("deployment control plane is down"), + }, + { + name: "fresh deployment failed; failed during deployment", + data: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + consumerConfig: consumerConfig, + + resourceRequest: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + resourceLimit: corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + + kubeClient: func() kubernetes.Interface { + clientSet := fake.NewSimpleClientset() + secretAPI := clientSet.CoreV1().Secrets(namespace).(*fakecorev1.FakeSecrets) + prependGetSecretReactor(t, secretAPI, nil, nil) + prependUpsertSecretReactor(t, secretAPI, []*corev1.Secret{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "project-1-model-1-config", + Namespace: namespace, + Labels: map[string]string{ + "app": "model-1-observability-publisher", + "component": "worker", + "orchestrator": "merlin", + "stream": "stream", + "team": "team", + "environment": "", + }, + }, + StringData: map[string]string{ + "config.yaml": "model_id: model-1\nmodel_version: \"1\"\ninference_schema:\n session_id_column: session_id\n row_id_column: row_id\n model_prediction_output:\n actual_score_column: \"\"\n negative_class_label: negative\n prediction_score_column: prediction_score\n prediction_label_column: prediction_label\n positive_class_label: positive\n score_threshold: null\n output_class: BinaryClassificationOutput\n tag_columns:\n - tag\n feature_types:\n featureA: float64\n featureB: float64\n featureC: int64\n featureD: boolean\n feature_orders: []\nobservation_sinks:\n- type: ARIZE\n config:\n api_key: api-key\n space_key: space-key\n- type: BIGQUERY\n config:\n project: bq-project\n dataset: dataset\n ttl_days: 10\nobservation_source:\n type: KAFKA\n config:\n topic: caraml-project-1-model-1-1-prediction-log\n bootstrap_servers: broker-1\n group_id: group-id\n batch_size: 100\n additional_consumer_config:\n auto.offset.reset: latest\n fetch.min.bytes: \"1024000\"\n", + }, + }}, nil, false) + deploymentAPI := clientSet.AppsV1().Deployments(namespace).(*fakeappsv1.FakeDeployments) + preprendGetDeploymentReactor(t, deploymentAPI, nil, nil) + depl := createDeploymentSpec(&models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + }, requestResource, limitResource, consumerConfig.ImageName) + prependUpsertDeploymentReactor(t, deploymentAPI, depl, fmt.Errorf("control plane is down"), false) + prependDeleteSecretReactor(t, secretAPI, "project-1-model-1-config", nil) + prependDeleteDeploymentReactor(t, deploymentAPI, "project-1-model-1-mlobs", nil) + return clientSet + }(), + expectedErr: fmt.Errorf("control plane is down"), + }, + { + name: "redeployment", + data: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "2", + Revision: 2, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + TopicSource: "caraml-project-1-model-1-2-prediction-log", + }, + consumerConfig: consumerConfig, + + resourceRequest: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + resourceLimit: corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + + kubeClient: func() kubernetes.Interface { + clientSet := fake.NewSimpleClientset() + secretAPI := clientSet.CoreV1().Secrets(namespace).(*fakecorev1.FakeSecrets) + prependGetSecretReactor(t, secretAPI, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "project-1-model-1-config", + Namespace: namespace, + Labels: map[string]string{ + "app": "model-1-observability-publisher", + "component": "worker", + "orchestrator": "merlin", + "stream": "stream", + "team": "team", + "environment": "", + }, + }, + StringData: map[string]string{ + "config.yaml": "model_id: model-1\nmodel_version: \"1\"\ninference_schema:\n session_id_column: session_id\n row_id_column: row_id\n model_prediction_output:\n actual_score_column: \"\"\n negative_class_label: negative\n prediction_score_column: prediction_score\n prediction_label_column: prediction_label\n positive_class_label: positive\n score_threshold: null\n output_class: BinaryClassificationOutput\n tag_columns:\n - tag\n feature_types:\n featureA: float64\n featureB: float64\n featureC: int64\n featureD: boolean\n feature_orders: []\nobservation_sinks:\n- type: ARIZE\n config:\n api_key: api-key\n space_key: space-key\n- type: BIGQUERY\n config:\n project: bq-project\n dataset: dataset\n ttl_days: 10\nobservation_source:\n type: KAFKA\n config:\n topic: caraml-project-1-model-1-1-prediction-log\n bootstrap_servers: broker-1\n group_id: group-id\n batch_size: 100\n additional_consumer_config:\n auto.offset.reset: latest\n fetch.min.bytes: \"1024000\"\n", + }, + }, nil) + prependUpsertSecretReactor(t, secretAPI, []*corev1.Secret{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "project-1-model-1-config", + Namespace: namespace, + Labels: map[string]string{ + "app": "model-1-observability-publisher", + "component": "worker", + "orchestrator": "merlin", + "stream": "stream", + "team": "team", + "environment": "", + }, + }, + StringData: map[string]string{ + "config.yaml": "model_id: model-1\nmodel_version: \"2\"\ninference_schema:\n session_id_column: session_id\n row_id_column: row_id\n model_prediction_output:\n actual_score_column: \"\"\n negative_class_label: negative\n prediction_score_column: prediction_score\n prediction_label_column: prediction_label\n positive_class_label: positive\n score_threshold: null\n output_class: BinaryClassificationOutput\n tag_columns:\n - tag\n feature_types:\n featureA: float64\n featureB: float64\n featureC: int64\n featureD: boolean\n feature_orders: []\nobservation_sinks:\n- type: ARIZE\n config:\n api_key: api-key\n space_key: space-key\n- type: BIGQUERY\n config:\n project: bq-project\n dataset: dataset\n ttl_days: 10\nobservation_source:\n type: KAFKA\n config:\n topic: caraml-project-1-model-1-2-prediction-log\n bootstrap_servers: broker-1\n group_id: group-id\n batch_size: 100\n additional_consumer_config:\n auto.offset.reset: latest\n fetch.min.bytes: \"1024000\"\n", + }, + }}, nil, true) + deploymentAPI := clientSet.AppsV1().Deployments(namespace).(*fakeappsv1.FakeDeployments) + depl := createDeploymentSpec(&models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "2", + Revision: 2, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + }, requestResource, limitResource, consumerConfig.ImageName) + preprendGetDeploymentReactor(t, deploymentAPI, depl, nil) + prependUpsertDeploymentReactor(t, deploymentAPI, depl, nil, true) + + updatedDepl := changeDeploymentStatus(depl, ready, 2) + deplWatchReactor := newDeploymentWatchReactor(updatedDepl) + clientSet.WatchReactionChain = []ktesting.WatchReactor{deplWatchReactor} + + return clientSet + }(), + }, + { + name: "redeployment failed; timeout waiting for deployment", + data: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "2", + Revision: 2, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + TopicSource: "caraml-project-1-model-1-2-prediction-log", + }, + consumerConfig: consumerConfig, + + resourceRequest: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + resourceLimit: corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + + kubeClient: func() kubernetes.Interface { + clientSet := fake.NewSimpleClientset() + secretAPI := clientSet.CoreV1().Secrets(namespace).(*fakecorev1.FakeSecrets) + prependGetSecretReactor(t, secretAPI, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "project-1-model-1-config", + Namespace: namespace, + Labels: map[string]string{ + "app": "model-1-observability-publisher", + "component": "worker", + "orchestrator": "merlin", + "stream": "stream", + "team": "team", + "environment": "", + }, + }, + StringData: map[string]string{ + "config.yaml": "model_id: model-1\nmodel_version: \"1\"\ninference_schema:\n session_id_column: session_id\n row_id_column: row_id\n model_prediction_output:\n actual_score_column: \"\"\n negative_class_label: negative\n prediction_score_column: prediction_score\n prediction_label_column: prediction_label\n positive_class_label: positive\n score_threshold: null\n output_class: BinaryClassificationOutput\n tag_columns:\n - tag\n feature_types:\n featureA: float64\n featureB: float64\n featureC: int64\n featureD: boolean\n feature_orders: []\nobservation_sinks:\n- type: ARIZE\n config:\n api_key: api-key\n space_key: space-key\n- type: BIGQUERY\n config:\n project: bq-project\n dataset: dataset\n ttl_days: 10\nobservation_source:\n type: KAFKA\n config:\n topic: caraml-project-1-model-1-1-prediction-log\n bootstrap_servers: broker-1\n group_id: group-id\n batch_size: 100\n additional_consumer_config:\n auto.offset.reset: latest\n fetch.min.bytes: \"1024000\"\n", + }, + }, nil) + prependUpsertSecretReactor(t, secretAPI, []*corev1.Secret{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "project-1-model-1-config", + Namespace: namespace, + Labels: map[string]string{ + "app": "model-1-observability-publisher", + "component": "worker", + "orchestrator": "merlin", + "stream": "stream", + "team": "team", + "environment": "", + }, + }, + StringData: map[string]string{ + "config.yaml": "model_id: model-1\nmodel_version: \"2\"\ninference_schema:\n session_id_column: session_id\n row_id_column: row_id\n model_prediction_output:\n actual_score_column: \"\"\n negative_class_label: negative\n prediction_score_column: prediction_score\n prediction_label_column: prediction_label\n positive_class_label: positive\n score_threshold: null\n output_class: BinaryClassificationOutput\n tag_columns:\n - tag\n feature_types:\n featureA: float64\n featureB: float64\n featureC: int64\n featureD: boolean\n feature_orders: []\nobservation_sinks:\n- type: ARIZE\n config:\n api_key: api-key\n space_key: space-key\n- type: BIGQUERY\n config:\n project: bq-project\n dataset: dataset\n ttl_days: 10\nobservation_source:\n type: KAFKA\n config:\n topic: caraml-project-1-model-1-2-prediction-log\n bootstrap_servers: broker-1\n group_id: group-id\n batch_size: 100\n additional_consumer_config:\n auto.offset.reset: latest\n fetch.min.bytes: \"1024000\"\n", + }, + }, { + ObjectMeta: metav1.ObjectMeta{ + Name: "project-1-model-1-config", + Namespace: namespace, + Labels: map[string]string{ + "app": "model-1-observability-publisher", + "component": "worker", + "orchestrator": "merlin", + "stream": "stream", + "team": "team", + "environment": "", + }, + }, + StringData: map[string]string{ + "config.yaml": "model_id: model-1\nmodel_version: \"1\"\ninference_schema:\n session_id_column: session_id\n row_id_column: row_id\n model_prediction_output:\n actual_score_column: \"\"\n negative_class_label: negative\n prediction_score_column: prediction_score\n prediction_label_column: prediction_label\n positive_class_label: positive\n score_threshold: null\n output_class: BinaryClassificationOutput\n tag_columns:\n - tag\n feature_types:\n featureA: float64\n featureB: float64\n featureC: int64\n featureD: boolean\n feature_orders: []\nobservation_sinks:\n- type: ARIZE\n config:\n api_key: api-key\n space_key: space-key\n- type: BIGQUERY\n config:\n project: bq-project\n dataset: dataset\n ttl_days: 10\nobservation_source:\n type: KAFKA\n config:\n topic: caraml-project-1-model-1-1-prediction-log\n bootstrap_servers: broker-1\n group_id: group-id\n batch_size: 100\n additional_consumer_config:\n auto.offset.reset: latest\n fetch.min.bytes: \"1024000\"\n", + }, + }, + }, nil, true) + deploymentAPI := clientSet.AppsV1().Deployments(namespace).(*fakeappsv1.FakeDeployments) + depl := createDeploymentSpec(&models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "2", + Revision: 2, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + }, requestResource, limitResource, consumerConfig.ImageName) + preprendGetDeploymentReactor(t, deploymentAPI, depl, nil) + prependUpsertDeploymentReactor(t, deploymentAPI, depl, nil, true) + + updatedDepl := changeDeploymentStatus(depl, timeoutError, 2) + deplWatchReactor := newDeploymentWatchReactor(updatedDepl) + clientSet.WatchReactionChain = []ktesting.WatchReactor{deplWatchReactor} + + return clientSet + }(), + expectedErr: fmt.Errorf(`deployment "project-1-model-1-mlobs" exceeded its progress deadline`), + }, + } + for _, tt := range tests { + depl := &deployer{ + kubeClient: tt.kubeClient, + consumerConfig: tt.consumerConfig, + resourceRequest: tt.resourceRequest, + resourceLimit: tt.resourceLimit, + } + err := depl.Deploy(context.Background(), tt.data) + assert.Equal(t, tt.expectedErr, err) + } +} + +func Test_deployer_Undeploy(t *testing.T) { + + testCases := []struct { + name string + data *models.WorkerData + kubeClient kubernetes.Interface + expectedErr error + }{ + { + name: "success undeploy", + data: &models.WorkerData{ + Project: "project-1", + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + }, + kubeClient: func() kubernetes.Interface { + clientSet := fake.NewSimpleClientset() + deploymentAPI := clientSet.AppsV1().Deployments(namespace).(*fakeappsv1.FakeDeployments) + prependDeleteDeploymentReactor(t, deploymentAPI, "project-1-model-1-mlobs", nil) + + secretAPI := clientSet.CoreV1().Secrets(namespace).(*fakecorev1.FakeSecrets) + prependDeleteSecretReactor(t, secretAPI, "project-1-model-1-config", nil) + return clientSet + }(), + }, + { + name: "failed undeploy; error when delete deployment", + data: &models.WorkerData{ + Project: "project-1", + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + }, + kubeClient: func() kubernetes.Interface { + clientSet := fake.NewSimpleClientset() + deploymentAPI := clientSet.AppsV1().Deployments(namespace).(*fakeappsv1.FakeDeployments) + prependDeleteDeploymentReactor(t, deploymentAPI, "project-1-model-1-mlobs", fmt.Errorf("control plane is down")) + return clientSet + }(), + expectedErr: fmt.Errorf("control plane is down"), + }, + { + name: "faile undeploy; error when delete secret", + data: &models.WorkerData{ + Project: "project-1", + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + }, + kubeClient: func() kubernetes.Interface { + clientSet := fake.NewSimpleClientset() + deploymentAPI := clientSet.AppsV1().Deployments(namespace).(*fakeappsv1.FakeDeployments) + prependDeleteDeploymentReactor(t, deploymentAPI, "project-1-model-1-mlobs", nil) + + secretAPI := clientSet.CoreV1().Secrets(namespace).(*fakecorev1.FakeSecrets) + prependDeleteSecretReactor(t, secretAPI, "project-1-model-1-config", fmt.Errorf("control plane is down")) + return clientSet + }(), + expectedErr: fmt.Errorf("control plane is down"), + }, + } + for _, tC := range testCases { + t.Run(tC.name, func(t *testing.T) { + consumerConfig := config.ObservabilityPublisher{ + TargetNamespace: namespace, + } + depl := &deployer{ + kubeClient: tC.kubeClient, + consumerConfig: consumerConfig, + } + err := depl.Undeploy(context.Background(), tC.data) + assert.Equal(t, tC.expectedErr, err) + }) + } +} + +func Test_deployer_GetDeployedManifest(t *testing.T) { + requestResource := corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + } + limitResource := corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("1Gi"), + } + schemaSpec := &models.SchemaSpec{ + SessionIDColumn: "session_id", + RowIDColumn: "row_id", + TagColumns: []string{"tag"}, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Float64, + "featureC": models.Int64, + "featureD": models.Boolean, + }, + ModelPredictionOutput: &models.ModelPredictionOutput{ + BinaryClassificationOutput: &models.BinaryClassificationOutput{ + NegativeClassLabel: "negative", + PositiveClassLabel: "positive", + PredictionLabelColumn: "prediction_label", + PredictionScoreColumn: "prediction_score", + OutputClass: models.BinaryClassification, + }, + }, + } + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "project-1-model-1-config", + Namespace: namespace, + Labels: map[string]string{ + "app": "model-1-observability-publisher", + "component": "worker", + "orchestrator": "merlin", + "stream": "stream", + "team": "team", + "environment": "", + }, + }, + StringData: map[string]string{ + "config.yaml": "model_id: model-1\nmodel_version: \"1\"\ninference_schema:\n session_id_column: session_id\n row_id_column: row_id\n model_prediction_output:\n actual_score_column: \"\"\n negative_class_label: negative\n prediction_score_column: prediction_score\n prediction_label_column: prediction_label\n positive_class_label: positive\n score_threshold: null\n output_class: BinaryClassificationOutput\n tag_columns:\n - tag\n feature_types:\n featureA: float64\n featureB: float64\n featureC: int64\n featureD: boolean\n feature_orders: []\nobservation_sinks:\n- type: ARIZE\n config:\n api_key: api-key\n space_key: space-key\n- type: BIGQUERY\n config:\n project: bq-project\n dataset: dataset\n ttl_days: 10\nobservation_source:\n type: KAFKA\n config:\n topic: caraml-project-1-model-1-1-prediction-log\n bootstrap_servers: broker-1\n group_id: group-id\n batch_size: 100\n additional_consumer_config:\n auto.offset.reset: latest\n fetch.min.bytes: \"1024000\"\n", + }, + } + workerData := &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + } + depl := createDeploymentSpec(workerData, requestResource, limitResource, "image:v0.1") + testCases := []struct { + name string + data *models.WorkerData + kubeClient kubernetes.Interface + expectedErr error + expectedManifest *Manifest + }{ + { + name: "success get manifest", + data: workerData, + kubeClient: func() kubernetes.Interface { + clientSet := fake.NewSimpleClientset() + secretAPI := clientSet.CoreV1().Secrets(namespace).(*fakecorev1.FakeSecrets) + prependGetSecretReactor(t, secretAPI, secret, nil) + + deploymentAPI := clientSet.AppsV1().Deployments(namespace).(*fakeappsv1.FakeDeployments) + preprendGetDeploymentReactor(t, deploymentAPI, depl, nil) + return clientSet + }(), + expectedManifest: &Manifest{ + Secret: secret, + Deployment: depl, + OnProgress: true, + }, + }, + { + name: "success get manifest; rolled out deployment", + data: workerData, + kubeClient: func() kubernetes.Interface { + clientSet := fake.NewSimpleClientset() + secretAPI := clientSet.CoreV1().Secrets(namespace).(*fakecorev1.FakeSecrets) + prependGetSecretReactor(t, secretAPI, secret, nil) + + updatedDeployment := changeDeploymentStatus(depl, ready, workerData.Revision) + deploymentAPI := clientSet.AppsV1().Deployments(namespace).(*fakeappsv1.FakeDeployments) + preprendGetDeploymentReactor(t, deploymentAPI, updatedDeployment, nil) + return clientSet + }(), + expectedManifest: func() *Manifest { + updatedDeployment := changeDeploymentStatus(depl, ready, workerData.Revision) + manifest := &Manifest{ + Secret: secret, + Deployment: updatedDeployment, + OnProgress: false, + } + return manifest + }(), + }, + { + name: "failed get manifest; error fetching secret", + data: workerData, + kubeClient: func() kubernetes.Interface { + clientSet := fake.NewSimpleClientset() + secretAPI := clientSet.CoreV1().Secrets(namespace).(*fakecorev1.FakeSecrets) + prependGetSecretReactor(t, secretAPI, secret, fmt.Errorf("control plane is down")) + + return clientSet + }(), + expectedErr: fmt.Errorf("control plane is down"), + }, + { + name: "failed get manifest; error fetching deployment", + data: workerData, + kubeClient: func() kubernetes.Interface { + clientSet := fake.NewSimpleClientset() + deploymentAPI := clientSet.AppsV1().Deployments(namespace).(*fakeappsv1.FakeDeployments) + preprendGetDeploymentReactor(t, deploymentAPI, depl, fmt.Errorf("control plane is down")) + + return clientSet + }(), + expectedErr: fmt.Errorf("control plane is down"), + }, + } + for _, tC := range testCases { + t.Run(tC.name, func(t *testing.T) { + consumerConfig := config.ObservabilityPublisher{ + TargetNamespace: namespace, + } + depl := &deployer{ + kubeClient: tC.kubeClient, + consumerConfig: consumerConfig, + } + manifest, err := depl.GetDeployedManifest(context.Background(), tC.data) + assert.Equal(t, tC.expectedErr, err) + assert.Equal(t, tC.expectedManifest, manifest) + }) + } +} + +func prependGetSecretReactor(t *testing.T, secretAPI *fakecorev1.FakeSecrets, secretRet *corev1.Secret, expectedErr error) { + secretAPI.Fake.PrependReactor(getMethod, secretResource, func(action ktesting.Action) (handled bool, ret runtime.Object, err error) { + actualAction, ok := action.(ktesting.GetAction) + if !ok { + t.Fatalf("unexpected actual action") + } + if secretRet == nil { + return true, nil, &errors.StatusError{ + ErrStatus: metav1.Status{ + Code: http.StatusNotFound, + Reason: metav1.StatusReasonNotFound, + }} + } + + if actualAction.GetNamespace() != secretRet.GetNamespace() { + t.Fatalf("different namespace") + } + if actualAction.GetName() != secretRet.GetName() { + t.Fatalf("requested different secret name") + } + + return true, secretRet, expectedErr + }) + +} + +func prependUpsertSecretReactor(t *testing.T, secretAPI *fakecorev1.FakeSecrets, requestedSecrets []*corev1.Secret, expectedErr error, updateOperation bool) { + method := createMethod + if updateOperation { + method = updateMethod + } + secretAPI.Fake.PrependReactor(method, secretResource, func(action ktesting.Action) (handled bool, ret runtime.Object, err error) { + var actualReqSecret *corev1.Secret + if updateOperation { + actualAction := action.(ktesting.UpdateAction) + actualReqSecret = actualAction.GetObject().(*corev1.Secret) + } else { + actualAction := action.(ktesting.CreateAction) + actualReqSecret = actualAction.GetObject().(*corev1.Secret) + } + + foundAction := false + var secret *corev1.Secret + for _, requestedSecret := range requestedSecrets { + foundAction = actualReqSecret.Namespace == requestedSecret.Namespace && reflect.DeepEqual(requestedSecret.ObjectMeta, actualReqSecret.ObjectMeta) && reflect.DeepEqual(requestedSecret.StringData, actualReqSecret.StringData) + if foundAction { + secret = requestedSecret + break + } + } + + if !foundAction { + t.Fatalf("actual and expected secret is different") + } + + return true, secret, expectedErr + }) +} + +func prependDeleteSecretReactor(t *testing.T, secretAPI *fakecorev1.FakeSecrets, secretName string, expectedErr error) { + secretAPI.Fake.PrependReactor(deleteMethod, secretResource, func(action ktesting.Action) (handled bool, ret runtime.Object, err error) { + actualAction, ok := action.(ktesting.DeleteAction) + if !ok { + t.Fatalf("unexpected actual action") + } + if actualAction.GetName() != secretName { + t.Fatalf("requested and actual secret name is not the same") + } + return true, nil, expectedErr + }) +} + +func preprendGetDeploymentReactor(t *testing.T, deploymentAPI *fakeappsv1.FakeDeployments, deploymentRet *appsv1.Deployment, expectedErr error) { + deploymentAPI.Fake.PrependReactor(getMethod, deploymentResource, func(action ktesting.Action) (handled bool, ret runtime.Object, err error) { + actualAction, ok := action.(ktesting.GetAction) + if !ok { + t.Fatalf("unexpected actual action") + } + if deploymentRet == nil { + return true, nil, &errors.StatusError{ + ErrStatus: metav1.Status{ + Code: http.StatusNotFound, + Reason: metav1.StatusReasonNotFound, + }} + } + if actualAction.GetName() != deploymentRet.GetName() { + t.Fatalf("requested and actual deployment name is different") + } + return true, deploymentRet, expectedErr + }) +} + +func prependUpsertDeploymentReactor(t *testing.T, deploymentAPI *fakeappsv1.FakeDeployments, requestedDepl *appsv1.Deployment, expectedErr error, updateOperation bool) { + method := createMethod + if updateOperation { + method = updateMethod + } + + deploymentAPI.Fake.PrependReactor(method, deploymentResource, func(action ktesting.Action) (handled bool, ret runtime.Object, err error) { + var actualReqDepl *appsv1.Deployment + if updateOperation { + actualAction := action.(ktesting.CreateAction) + actualReqDepl = actualAction.GetObject().(*appsv1.Deployment) + } else { + actualAction := action.(ktesting.UpdateAction) + actualReqDepl = actualAction.GetObject().(*appsv1.Deployment) + } + + if actualReqDepl.Namespace != requestedDepl.GetNamespace() { + t.Fatalf("different namespace") + } + + assert.Equal(t, requestedDepl.ObjectMeta, actualReqDepl.ObjectMeta) + assert.Equal(t, requestedDepl.Spec, actualReqDepl.Spec) + if !reflect.DeepEqual(requestedDepl.ObjectMeta, actualReqDepl.ObjectMeta) || !reflect.DeepEqual(requestedDepl.Spec, actualReqDepl.Spec) { + t.Fatalf("actual and expected requested deployment is different") + } + + return true, requestedDepl, expectedErr + }) +} + +func prependDeleteDeploymentReactor(t *testing.T, deploymentAPI *fakeappsv1.FakeDeployments, deploymentName string, expectedErr error) { + deploymentAPI.Fake.PrependReactor(deleteMethod, deploymentResource, func(action ktesting.Action) (handled bool, ret runtime.Object, err error) { + actualAction, ok := action.(ktesting.DeleteAction) + if !ok { + t.Fatalf("unexpected actual action") + } + + if actualAction.GetName() != deploymentName { + t.Fatalf("requested and actual deployment name is not the same") + } + return true, nil, expectedErr + }) +} diff --git a/api/pkg/observability/deployment/mocks/deployer.go b/api/pkg/observability/deployment/mocks/deployer.go new file mode 100644 index 000000000..d3e7bd490 --- /dev/null +++ b/api/pkg/observability/deployment/mocks/deployer.go @@ -0,0 +1,97 @@ +// Code generated by mockery v2.39.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + deployment "github.com/caraml-dev/merlin/pkg/observability/deployment" + mock "github.com/stretchr/testify/mock" + + models "github.com/caraml-dev/merlin/models" +) + +// Deployer is an autogenerated mock type for the Deployer type +type Deployer struct { + mock.Mock +} + +// Deploy provides a mock function with given fields: ctx, data +func (_m *Deployer) Deploy(ctx context.Context, data *models.WorkerData) error { + ret := _m.Called(ctx, data) + + if len(ret) == 0 { + panic("no return value specified for Deploy") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *models.WorkerData) error); ok { + r0 = rf(ctx, data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetDeployedManifest provides a mock function with given fields: ctx, data +func (_m *Deployer) GetDeployedManifest(ctx context.Context, data *models.WorkerData) (*deployment.Manifest, error) { + ret := _m.Called(ctx, data) + + if len(ret) == 0 { + panic("no return value specified for GetDeployedManifest") + } + + var r0 *deployment.Manifest + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *models.WorkerData) (*deployment.Manifest, error)); ok { + return rf(ctx, data) + } + if rf, ok := ret.Get(0).(func(context.Context, *models.WorkerData) *deployment.Manifest); ok { + r0 = rf(ctx, data) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*deployment.Manifest) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *models.WorkerData) error); ok { + r1 = rf(ctx, data) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Undeploy provides a mock function with given fields: ctx, data +func (_m *Deployer) Undeploy(ctx context.Context, data *models.WorkerData) error { + ret := _m.Called(ctx, data) + + if len(ret) == 0 { + panic("no return value specified for Undeploy") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *models.WorkerData) error); ok { + r0 = rf(ctx, data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewDeployer creates a new instance of Deployer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewDeployer(t interface { + mock.TestingT + Cleanup(func()) +}) *Deployer { + mock := &Deployer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/api/pkg/observability/deployment/util.go b/api/pkg/observability/deployment/util.go new file mode 100644 index 000000000..9584c3f7d --- /dev/null +++ b/api/pkg/observability/deployment/util.go @@ -0,0 +1,30 @@ +package deployment + +import ( + "strconv" + + appsv1 "k8s.io/api/apps/v1" +) + +const ( + timeoutReason = "ProgressDeadlineExceeded" + k8sRevisionAnnotation = "deployment.kubernetes.io/revision" +) + +func getDeploymentCondition(status appsv1.DeploymentStatus, condType appsv1.DeploymentConditionType) *appsv1.DeploymentCondition { + for i := range status.Conditions { + c := status.Conditions[i] + if c.Type == condType { + return &c + } + } + return nil +} + +func getDeploymentRevision(depl *appsv1.Deployment) (int64, error) { + v, ok := depl.GetAnnotations()[PublisherRevisionAnnotationKey] + if !ok { + return 0, nil + } + return strconv.ParseInt(v, 10, 64) +} diff --git a/api/pkg/observability/event/event.go b/api/pkg/observability/event/event.go new file mode 100644 index 000000000..3a6a38ac2 --- /dev/null +++ b/api/pkg/observability/event/event.go @@ -0,0 +1,189 @@ +package event + +import ( + "context" + "fmt" + + "github.com/caraml-dev/merlin/models" + "github.com/caraml-dev/merlin/queue" + "github.com/caraml-dev/merlin/storage" +) + +const ( + ObservabilityPublisherDeployment = "observability_publisher_deployment" + dataArgKey = "data" +) + +type EventProducer interface { + ModelEndpointChangeEvent(modelEndpoint *models.ModelEndpoint, model *models.Model) error + VersionEndpointChangeEvent(versionEndpoint *models.VersionEndpoint, model *models.Model) error +} + +type eventProducer struct { + jobProducer queue.Producer + observabilityPublisherStorage storage.ObservabilityPublisherStorage + versionStorage storage.VersionStorage +} + +func NewEventProducer(jobProducer queue.Producer, observabilityPublisherStorage storage.ObservabilityPublisherStorage, versionStorage storage.VersionStorage) *eventProducer { + return &eventProducer{ + jobProducer: jobProducer, + observabilityPublisherStorage: observabilityPublisherStorage, + versionStorage: versionStorage, + } +} + +func (e *eventProducer) ModelEndpointChangeEvent(modelEndpoint *models.ModelEndpoint, model *models.Model) error { + if !model.ObservabilitySupported { + return nil + } + + ctx := context.Background() + publisher, err := e.observabilityPublisherStorage.GetByModelID(ctx, model.ID) + if err != nil { + return err + } + + // undeploy if + // model endpoint is nil or + // version endpoint observability is false + if isUndeployAction(modelEndpoint) { + if publisher == nil || publisher.Status == models.Terminated { + return nil + } + + var versionID models.ID + if modelEndpoint == nil { + versionID = publisher.VersionID + } else { + vEndpoint := modelEndpoint.GetVersionEndpoint() + versionID = vEndpoint.VersionID + } + + version, err := e.findVersionWithModelSchema(ctx, versionID, model.ID) + if err != nil { + return err + } + + return e.enqueueJob(version, model, publisher, models.UndeployPublisher) + } + + versionEndpoint := modelEndpoint.GetVersionEndpoint() + version, err := e.findVersionWithModelSchema(ctx, versionEndpoint.VersionID, model.ID) + if err != nil { + return err + } + + if publisher == nil { + publisher = &models.ObservabilityPublisher{ + VersionModelID: modelEndpoint.ModelID, + Revision: 1, + } + } + + publisher.VersionID = versionEndpoint.VersionID + publisher.ModelSchemaSpec = version.ModelSchema.Spec + + return e.enqueueJob(version, model, publisher, models.DeployPublisher) +} + +func (e *eventProducer) VersionEndpointChangeEvent(versionEndpoint *models.VersionEndpoint, model *models.Model) error { + if !model.ObservabilitySupported { + return nil + } + + // check if version endpoint is used by the model endpoint + // if version endpoint is not serving skipping deployment + if versionEndpoint.Status != models.EndpointServing { + return nil + } + + ctx := context.Background() + publisher, err := e.observabilityPublisherStorage.GetByModelID(ctx, model.ID) + if err != nil { + return err + } + + // Undeploy if version endpoint observability is false + if !versionEndpoint.EnableModelObservability { + if publisher == nil || publisher.Status == models.Terminated { + return nil + } + version, err := e.findVersionWithModelSchema(ctx, versionEndpoint.VersionID, model.ID) + if err != nil { + return err + } + return e.enqueueJob(version, model, publisher, models.UndeployPublisher) + } + + version, err := e.findVersionWithModelSchema(ctx, versionEndpoint.VersionID, model.ID) + if err != nil { + return err + } + + if publisher == nil { + publisher = &models.ObservabilityPublisher{ + VersionModelID: versionEndpoint.VersionModelID, + Revision: 1, + } + } + + publisher.VersionID = versionEndpoint.VersionID + publisher.ModelSchemaSpec = version.ModelSchema.Spec + return e.enqueueJob(version, model, publisher, models.DeployPublisher) +} + +func isUndeployAction(modelEndpoint *models.ModelEndpoint) bool { + if modelEndpoint == nil { + return true + } + if len(modelEndpoint.Rule.Destination) == 0 { + return false + } + destination := modelEndpoint.Rule.Destination[0] + return !destination.VersionEndpoint.EnableModelObservability +} + +func (e *eventProducer) findVersionWithModelSchema(ctx context.Context, versionID models.ID, modelID models.ID) (*models.Version, error) { + version, err := e.versionStorage.FindByID(ctx, versionID, modelID) + if err != nil { + return nil, err + } + if version.ModelSchema == nil { + return nil, fmt.Errorf("versionID: %d in modelID: %d doesn't have model schema", versionID, modelID) + } + return version, nil +} + +func (e *eventProducer) enqueueJob(version *models.Version, model *models.Model, publisher *models.ObservabilityPublisher, actionType models.ActionType) error { + publisher.Status = models.Pending + if version.ModelSchema != nil { + publisher.ModelSchemaSpec = version.ModelSchema.Spec + } + ctx := context.Background() + if publisher.ID > 0 { + increaseRevision := actionType == models.DeployPublisher + updatedPublisher, err := e.observabilityPublisherStorage.Update(ctx, publisher, increaseRevision) + if err != nil { + return err + } + publisher = updatedPublisher + } else { + updatedPublisher, err := e.observabilityPublisherStorage.Create(ctx, publisher) + if err != nil { + return err + } + publisher = updatedPublisher + } + + return e.jobProducer.EnqueueJob(&queue.Job{ + Name: ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: actionType, + Publisher: publisher, + WorkerData: models.NewWorkerData(version, model, publisher), + }, + }, + }) +} diff --git a/api/pkg/observability/event/event_test.go b/api/pkg/observability/event/event_test.go new file mode 100644 index 000000000..db84d527d --- /dev/null +++ b/api/pkg/observability/event/event_test.go @@ -0,0 +1,1115 @@ +package event + +import ( + "fmt" + "testing" + + "github.com/caraml-dev/merlin/mlp" + "github.com/caraml-dev/merlin/models" + "github.com/caraml-dev/merlin/queue" + queueMock "github.com/caraml-dev/merlin/queue/mocks" + storageMock "github.com/caraml-dev/merlin/storage/mocks" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_eventProducer_ModelEndpointChangeEvent(t *testing.T) { + model := &models.Model{ + ID: models.ID(1), + Name: "model-1", + Project: mlp.Project{ + Name: "project-1", + Stream: "stream", + Team: "team", + }, + ObservabilitySupported: true, + } + schemaSpec := &models.SchemaSpec{ + TagColumns: []string{"tag"}, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Float64, + "featureC": models.Int64, + "featureD": models.Boolean, + }, + ModelPredictionOutput: &models.ModelPredictionOutput{ + BinaryClassificationOutput: &models.BinaryClassificationOutput{ + NegativeClassLabel: "negative", + PositiveClassLabel: "positive", + PredictionLabelColumn: "prediction_label", + PredictionScoreColumn: "prediction_score", + OutputClass: models.BinaryClassification, + }, + }, + } + + regresionSchemaSpec := &models.SchemaSpec{ + TagColumns: []string{"tag"}, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Float64, + "featureC": models.Int64, + "featureD": models.Boolean, + }, + ModelPredictionOutput: &models.ModelPredictionOutput{ + RegressionOutput: &models.RegressionOutput{ + PredictionScoreColumn: "prediction_score", + ActualScoreColumn: "actual_score", + OutputClass: models.Regression, + }, + }, + } + + modelSchema := &models.ModelSchema{ + ModelID: model.ID, + ID: models.ID(1), + Spec: schemaSpec, + } + tests := []struct { + name string + jobProducer *queueMock.Producer + observabilityPublisherStorage *storageMock.ObservabilityPublisherStorage + versionStorage *storageMock.VersionStorage + modelEndpoint *models.ModelEndpoint + model *models.Model + expectedError error + }{ + { + name: "do nothing if model doesn't support model observability", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + return mockStorage + }(), + model: &models.Model{ + ID: models.ID(2), + ObservabilitySupported: false, + }, + modelEndpoint: &models.ModelEndpoint{ + ID: models.ID(1), + ModelID: model.ID, + Model: model, + Status: models.EndpointServing, + Rule: &models.ModelEndpointRule{ + Destination: []*models.ModelEndpointRuleDestination{ + { + VersionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(2), + Status: models.EndpointServing, + EnableModelObservability: true, + }, + }, + }, + }, + }, + }, + { + name: "no deployment; version endpoint model observability is disabled and never been deployed before", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(nil, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + return mockStorage + }(), + model: model, + modelEndpoint: &models.ModelEndpoint{ + ID: models.ID(1), + ModelID: model.ID, + Model: model, + Status: models.EndpointServing, + Rule: &models.ModelEndpointRule{ + Destination: []*models.ModelEndpointRuleDestination{ + { + VersionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(2), + Status: models.EndpointServing, + EnableModelObservability: false, + }, + }, + }, + }, + }, + }, + { + name: "fresh deployment", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + producer.On("EnqueueJob", &queue.Job{ + Name: ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + ModelSchemaSpec: schemaSpec, + Revision: 1, + Status: models.Pending, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + ModelName: "model-1", + ModelVersion: "2", + Revision: 1, + TopicSource: "caraml-project-1-model-1-2-prediction-log", + }, + }, + }, + }).Return(nil) + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(nil, nil) + mockStorage.On("Create", mock.Anything, &models.ObservabilityPublisher{ + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Revision: 1, + Status: models.Pending, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + mockStorage.On("FindByID", mock.Anything, models.ID(2), model.ID).Return(&models.Version{ + ID: models.ID(2), + ModelID: model.ID, + ModelSchema: modelSchema, + Model: model, + }, nil) + return mockStorage + }(), + model: model, + modelEndpoint: &models.ModelEndpoint{ + ID: models.ID(1), + ModelID: model.ID, + Model: model, + Status: models.EndpointServing, + Rule: &models.ModelEndpointRule{ + Destination: []*models.ModelEndpointRuleDestination{ + { + VersionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(2), + Status: models.EndpointServing, + EnableModelObservability: true, + }, + }, + }, + }, + }, + }, + { + name: "fresh deployment request failed - version doesn't have schema ", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(nil, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + mockStorage.On("FindByID", mock.Anything, models.ID(3), model.ID).Return(&models.Version{ + ID: models.ID(3), + ModelID: model.ID, + ModelSchema: nil, + Model: model, + }, nil) + return mockStorage + }(), + model: model, + modelEndpoint: &models.ModelEndpoint{ + ID: models.ID(1), + ModelID: model.ID, + Model: model, + Status: models.EndpointServing, + Rule: &models.ModelEndpointRule{ + Destination: []*models.ModelEndpointRuleDestination{ + { + VersionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(3), + Status: models.EndpointServing, + EnableModelObservability: true, + }, + }, + }, + }, + }, + expectedError: fmt.Errorf("versionID: 3 in modelID: 1 doesn't have model schema"), + }, + { + name: "redeployment - model endpoint change version", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + producer.On("EnqueueJob", &queue.Job{ + Name: ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(3), + VersionModelID: models.ID(1), + ModelSchemaSpec: regresionSchemaSpec, + Revision: 2, + Status: models.Pending, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: regresionSchemaSpec, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + ModelName: "model-1", + ModelVersion: "3", + Revision: 2, + TopicSource: "caraml-project-1-model-1-3-prediction-log", + }, + }, + }, + }).Return(nil) + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Status: models.Pending, + ModelSchemaSpec: schemaSpec, + }, nil) + mockStorage.On("Update", mock.Anything, &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(3), + VersionModelID: models.ID(1), + Status: models.Pending, + ModelSchemaSpec: regresionSchemaSpec, + }, true).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(3), + VersionModelID: models.ID(1), + Revision: 2, + Status: models.Pending, + ModelSchemaSpec: regresionSchemaSpec, + }, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + mockStorage.On("FindByID", mock.Anything, models.ID(3), model.ID).Return(&models.Version{ + ID: models.ID(3), + ModelID: model.ID, + ModelSchema: &models.ModelSchema{ + ID: models.ID(1), + ModelID: model.ID, + Spec: regresionSchemaSpec, + }, + Model: model, + }, nil) + return mockStorage + }(), + model: model, + modelEndpoint: &models.ModelEndpoint{ + ID: models.ID(1), + ModelID: model.ID, + Model: model, + Status: models.EndpointServing, + Rule: &models.ModelEndpointRule{ + Destination: []*models.ModelEndpointRuleDestination{ + { + VersionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(3), + Status: models.EndpointServing, + EnableModelObservability: true, + }, + }, + }, + }, + }, + }, + { + name: "redeployment request failed - failed get version ", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: model.ID, + Revision: 1, + Status: models.Running, + }, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + mockStorage.On("FindByID", mock.Anything, models.ID(4), model.ID).Return(nil, fmt.Errorf("connection is broken")) + return mockStorage + }(), + model: model, + modelEndpoint: &models.ModelEndpoint{ + ID: models.ID(1), + ModelID: model.ID, + Model: model, + Status: models.EndpointServing, + Rule: &models.ModelEndpointRule{ + Destination: []*models.ModelEndpointRuleDestination{ + { + VersionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(4), + Status: models.EndpointServing, + EnableModelObservability: true, + }, + }, + }, + }, + }, + expectedError: fmt.Errorf("connection is broken"), + }, + { + name: "undeployment - model endpoint is nil", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + producer.On("EnqueueJob", &queue.Job{ + Name: ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.UndeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + ModelSchemaSpec: schemaSpec, + Revision: 1, + Status: models.Pending, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + ModelName: "model-1", + ModelVersion: "2", + Revision: 1, + TopicSource: "caraml-project-1-model-1-2-prediction-log", + }, + }, + }, + }).Return(nil) + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Status: models.Running, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + mockStorage.On("Update", mock.Anything, &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, false).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Revision: 1, + Status: models.Pending, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + mockStorage.On("FindByID", mock.Anything, models.ID(2), model.ID).Return(&models.Version{ + ID: models.ID(2), + ModelID: model.ID, + ModelSchema: modelSchema, + Model: model, + }, nil) + return mockStorage + }(), + model: model, + modelEndpoint: nil, + }, + { + name: "undeployment - model observability is disabled", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + producer.On("EnqueueJob", &queue.Job{ + Name: ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.UndeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + ModelSchemaSpec: schemaSpec, + Revision: 1, + Status: models.Pending, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + ModelName: "model-1", + ModelVersion: "2", + Revision: 1, + TopicSource: "caraml-project-1-model-1-2-prediction-log", + }, + }, + }, + }).Return(nil) + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Status: models.Running, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + mockStorage.On("Update", mock.Anything, &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, false).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Revision: 1, + Status: models.Pending, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + mockStorage.On("FindByID", mock.Anything, models.ID(2), model.ID).Return(&models.Version{ + ID: models.ID(2), + ModelID: model.ID, + ModelSchema: modelSchema, + Model: model, + }, nil) + return mockStorage + }(), + model: model, + modelEndpoint: &models.ModelEndpoint{ + ID: models.ID(1), + ModelID: model.ID, + Model: model, + Status: models.EndpointServing, + Rule: &models.ModelEndpointRule{ + Destination: []*models.ModelEndpointRuleDestination{ + { + VersionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(2), + Status: models.EndpointServing, + EnableModelObservability: false, + }, + }, + }, + }, + }, + }, + { + name: "do nothing; version endpoint model observability is disabled and last state of publisher is terminated", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: model.ID, + Status: models.Terminated, + Revision: 1, + }, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + return mockStorage + }(), + model: model, + modelEndpoint: &models.ModelEndpoint{ + ID: models.ID(1), + ModelID: model.ID, + Model: model, + Status: models.EndpointServing, + Rule: &models.ModelEndpointRule{ + Destination: []*models.ModelEndpointRuleDestination{ + { + VersionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(3), + Status: models.EndpointServing, + EnableModelObservability: false, + }, + }, + }, + }, + }, + }, + { + name: "undeployment request failed; fail fetch version", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Status: models.Running, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + mockStorage.On("FindByID", mock.Anything, models.ID(2), model.ID).Return(nil, fmt.Errorf("connection error")) + return mockStorage + }(), + model: model, + modelEndpoint: &models.ModelEndpoint{ + ID: models.ID(1), + ModelID: model.ID, + Model: model, + Status: models.EndpointServing, + Rule: &models.ModelEndpointRule{ + Destination: []*models.ModelEndpointRuleDestination{ + { + VersionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(2), + Status: models.EndpointServing, + EnableModelObservability: false, + }, + }, + }, + }, + }, + expectedError: fmt.Errorf("connection error"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + eventProducer := NewEventProducer(tt.jobProducer, tt.observabilityPublisherStorage, tt.versionStorage) + err := eventProducer.ModelEndpointChangeEvent(tt.modelEndpoint, tt.model) + assert.Equal(t, tt.expectedError, err) + tt.observabilityPublisherStorage.AssertExpectations(t) + tt.versionStorage.AssertExpectations(t) + tt.jobProducer.AssertExpectations(t) + }) + } +} + +func Test_eventProducer_VersionEndpointChangeEvent(t *testing.T) { + model := &models.Model{ + ID: models.ID(1), + Name: "model-1", + Project: mlp.Project{ + Name: "project-1", + Stream: "stream", + Team: "team", + }, + ObservabilitySupported: true, + } + schemaSpec := &models.SchemaSpec{ + TagColumns: []string{"tag"}, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Float64, + "featureC": models.Int64, + "featureD": models.Boolean, + }, + ModelPredictionOutput: &models.ModelPredictionOutput{ + BinaryClassificationOutput: &models.BinaryClassificationOutput{ + NegativeClassLabel: "negative", + PositiveClassLabel: "positive", + PredictionLabelColumn: "prediction_label", + PredictionScoreColumn: "prediction_score", + OutputClass: models.BinaryClassification, + }, + }, + } + + modelSchema := &models.ModelSchema{ + ID: models.ID(1), + ModelID: model.ID, + Spec: schemaSpec, + } + + regresionSchemaSpec := &models.SchemaSpec{ + TagColumns: []string{"tag"}, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Float64, + "featureC": models.Int64, + "featureD": models.Boolean, + }, + ModelPredictionOutput: &models.ModelPredictionOutput{ + RegressionOutput: &models.RegressionOutput{ + PredictionScoreColumn: "prediction_score", + ActualScoreColumn: "actual_score", + OutputClass: models.Regression, + }, + }, + } + tests := []struct { + name string + jobProducer *queueMock.Producer + observabilityPublisherStorage *storageMock.ObservabilityPublisherStorage + versionStorage *storageMock.VersionStorage + model *models.Model + versionEndpoint *models.VersionEndpoint + expectedError error + }{ + { + name: "do nothing if model not supported observability", + jobProducer: &queueMock.Producer{}, + observabilityPublisherStorage: &storageMock.ObservabilityPublisherStorage{}, + versionStorage: &storageMock.VersionStorage{}, + model: &models.Model{ + ID: models.ID(1), + ObservabilitySupported: false, + }, + }, + { + name: "no deployment; version endpoint model observability is disabled and never been deployed before", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(nil, nil) + return mockStorage + }(), + versionStorage: &storageMock.VersionStorage{}, + model: model, + versionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(2), + Status: models.EndpointServing, + EnableModelObservability: false, + }, + }, + { + name: "fresh deployment", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + producer.On("EnqueueJob", &queue.Job{ + Name: ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + ModelSchemaSpec: schemaSpec, + Revision: 1, + Status: models.Pending, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + ModelName: "model-1", + ModelVersion: "2", + Revision: 1, + TopicSource: "caraml-project-1-model-1-2-prediction-log", + }, + }, + }, + }).Return(nil) + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(nil, nil) + mockStorage.On("Create", mock.Anything, &models.ObservabilityPublisher{ + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Revision: 1, + Status: models.Pending, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + mockStorage.On("FindByID", mock.Anything, models.ID(2), model.ID).Return(&models.Version{ + ID: models.ID(2), + ModelID: model.ID, + ModelSchema: modelSchema, + Model: model, + }, nil) + return mockStorage + }(), + model: model, + versionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(2), + VersionModelID: model.ID, + Status: models.EndpointServing, + EnableModelObservability: true, + }, + }, + { + name: "fresh deployment request failed - version doesn't have schema ", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(nil, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + mockStorage.On("FindByID", mock.Anything, models.ID(3), model.ID).Return(&models.Version{ + ID: models.ID(3), + ModelID: model.ID, + ModelSchema: nil, + Model: model, + }, nil) + return mockStorage + }(), + model: model, + versionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(3), + VersionModelID: model.ID, + Status: models.EndpointServing, + EnableModelObservability: true, + }, + expectedError: fmt.Errorf("versionID: 3 in modelID: 1 doesn't have model schema"), + }, + { + name: "redeployment - model endpoint change version", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + producer.On("EnqueueJob", &queue.Job{ + Name: ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(3), + VersionModelID: models.ID(1), + ModelSchemaSpec: regresionSchemaSpec, + Revision: 2, + Status: models.Pending, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: regresionSchemaSpec, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + ModelName: "model-1", + ModelVersion: "3", + Revision: 2, + TopicSource: "caraml-project-1-model-1-3-prediction-log", + }, + }, + }, + }).Return(nil) + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Status: models.Pending, + ModelSchemaSpec: schemaSpec, + }, nil) + mockStorage.On("Update", mock.Anything, &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(3), + VersionModelID: models.ID(1), + Status: models.Pending, + ModelSchemaSpec: regresionSchemaSpec, + }, true).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(3), + VersionModelID: models.ID(1), + Revision: 2, + Status: models.Pending, + ModelSchemaSpec: regresionSchemaSpec, + }, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + mockStorage.On("FindByID", mock.Anything, models.ID(3), model.ID).Return(&models.Version{ + ID: models.ID(3), + ModelID: model.ID, + ModelSchema: &models.ModelSchema{ + ID: models.ID(1), + ModelID: model.ID, + Spec: regresionSchemaSpec, + }, + Model: model, + }, nil) + return mockStorage + }(), + model: model, + versionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(3), + VersionModelID: model.ID, + Status: models.EndpointServing, + EnableModelObservability: true, + }, + }, + { + name: "undeployment - model observability is disabled", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + producer.On("EnqueueJob", &queue.Job{ + Name: ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.UndeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + ModelSchemaSpec: schemaSpec, + Revision: 1, + Status: models.Pending, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + Metadata: models.Metadata{ + App: "model-1-observability-publisher", + Component: "worker", + Stream: "stream", + Team: "team", + }, + ModelName: "model-1", + ModelVersion: "2", + Revision: 1, + TopicSource: "caraml-project-1-model-1-2-prediction-log", + }, + }, + }, + }).Return(nil) + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Status: models.Running, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + mockStorage.On("Update", mock.Anything, &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, false).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Revision: 1, + Status: models.Pending, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + mockStorage.On("FindByID", mock.Anything, models.ID(2), model.ID).Return(&models.Version{ + ID: models.ID(2), + ModelID: model.ID, + ModelSchema: modelSchema, + Model: model, + }, nil) + return mockStorage + }(), + model: model, + versionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(2), + VersionModelID: model.ID, + Status: models.EndpointServing, + EnableModelObservability: false, + }, + }, + { + name: "do nothing; version endpoint model observability is disabled and last state of publisher is terminated", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: model.ID, + Status: models.Terminated, + Revision: 1, + }, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + return mockStorage + }(), + model: model, + versionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(3), + VersionModelID: model.ID, + Status: models.EndpointServing, + EnableModelObservability: false, + }, + }, + { + name: "undeployment request failed; fail fetch version", + jobProducer: func() *queueMock.Producer { + producer := &queueMock.Producer{} + return producer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("GetByModelID", mock.Anything, models.ID(1)).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionID: models.ID(2), + VersionModelID: models.ID(1), + Status: models.Running, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + versionStorage: func() *storageMock.VersionStorage { + mockStorage := &storageMock.VersionStorage{} + mockStorage.On("FindByID", mock.Anything, models.ID(2), model.ID).Return(nil, fmt.Errorf("connection error")) + return mockStorage + }(), + model: model, + versionEndpoint: &models.VersionEndpoint{ + ID: uuid.UUID{}, + VersionID: models.ID(2), + VersionModelID: model.ID, + Status: models.EndpointServing, + EnableModelObservability: false, + }, + expectedError: fmt.Errorf("connection error"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + eventProducer := NewEventProducer(tt.jobProducer, tt.observabilityPublisherStorage, tt.versionStorage) + err := eventProducer.VersionEndpointChangeEvent(tt.versionEndpoint, tt.model) + assert.Equal(t, tt.expectedError, err) + }) + } +} diff --git a/api/pkg/observability/event/mocks/event_producer.go b/api/pkg/observability/event/mocks/event_producer.go new file mode 100644 index 000000000..82233c0aa --- /dev/null +++ b/api/pkg/observability/event/mocks/event_producer.go @@ -0,0 +1,63 @@ +// Code generated by mockery v2.39.2. DO NOT EDIT. + +package mocks + +import ( + models "github.com/caraml-dev/merlin/models" + mock "github.com/stretchr/testify/mock" +) + +// EventProducer is an autogenerated mock type for the EventProducer type +type EventProducer struct { + mock.Mock +} + +// ModelEndpointChangeEvent provides a mock function with given fields: modelEndpoint, model +func (_m *EventProducer) ModelEndpointChangeEvent(modelEndpoint *models.ModelEndpoint, model *models.Model) error { + ret := _m.Called(modelEndpoint, model) + + if len(ret) == 0 { + panic("no return value specified for ModelEndpointChangeEvent") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*models.ModelEndpoint, *models.Model) error); ok { + r0 = rf(modelEndpoint, model) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// VersionEndpointChangeEvent provides a mock function with given fields: versionEndpoint, model +func (_m *EventProducer) VersionEndpointChangeEvent(versionEndpoint *models.VersionEndpoint, model *models.Model) error { + ret := _m.Called(versionEndpoint, model) + + if len(ret) == 0 { + panic("no return value specified for VersionEndpointChangeEvent") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*models.VersionEndpoint, *models.Model) error); ok { + r0 = rf(versionEndpoint, model) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewEventProducer creates a new instance of EventProducer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewEventProducer(t interface { + mock.TestingT + Cleanup(func()) +}) *EventProducer { + mock := &EventProducer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/api/queue/work/model_service_deployment.go b/api/queue/work/model_service_deployment.go index 167a4899c..ce54f16b5 100644 --- a/api/queue/work/model_service_deployment.go +++ b/api/queue/work/model_service_deployment.go @@ -12,6 +12,7 @@ import ( "github.com/caraml-dev/merlin/mlp" "github.com/caraml-dev/merlin/models" "github.com/caraml-dev/merlin/pkg/imagebuilder" + "github.com/caraml-dev/merlin/pkg/observability/event" "github.com/caraml-dev/merlin/queue" "github.com/caraml-dev/merlin/storage" "github.com/prometheus/client_golang/prometheus" @@ -34,11 +35,12 @@ func init() { } type ModelServiceDeployment struct { - ClusterControllers map[string]cluster.Controller - ImageBuilder imagebuilder.ImageBuilder - Storage storage.VersionEndpointStorage - DeploymentStorage storage.DeploymentStorage - LoggerDestinationURL string + ClusterControllers map[string]cluster.Controller + ImageBuilder imagebuilder.ImageBuilder + Storage storage.VersionEndpointStorage + DeploymentStorage storage.DeploymentStorage + LoggerDestinationURL string + ObservabilityEventProducer event.EventProducer } type EndpointJob struct { @@ -195,6 +197,12 @@ func (depl *ModelServiceDeployment) Deploy(job *queue.Job) error { log.Errorf("unable to update endpoint status for model: %s, version: %s, reason: %v", model.Name, version.ID, err) } + if model.ObservabilitySupported { + if err := depl.ObservabilityEventProducer.VersionEndpointChangeEvent(endpoint, model); err != nil { + log.Errorf("error publishing event for observability deployment for model: %s, version: %s with error: %w", model.Name, version.ID, err) + } + } + return nil } diff --git a/api/queue/work/model_service_deployment_test.go b/api/queue/work/model_service_deployment_test.go index 7cac4903f..f91ed8dae 100644 --- a/api/queue/work/model_service_deployment_test.go +++ b/api/queue/work/model_service_deployment_test.go @@ -11,6 +11,7 @@ import ( "github.com/caraml-dev/merlin/mlp" "github.com/caraml-dev/merlin/models" imageBuilderMock "github.com/caraml-dev/merlin/pkg/imagebuilder/mocks" + eventMock "github.com/caraml-dev/merlin/pkg/observability/event/mocks" "github.com/caraml-dev/merlin/queue" "github.com/caraml-dev/merlin/storage/mocks" "github.com/stretchr/testify/assert" @@ -55,7 +56,7 @@ func TestExecuteDeployment(t *testing.T) { } project := mlp.Project{Name: "project", Labels: mlpLabels} - model := &models.Model{Name: "model", Project: project} + model := &models.Model{Name: "model", Project: project, ObservabilitySupported: false} version := &models.Version{ID: 1, Labels: versionLabels} iSvcName := fmt.Sprintf("%s-%d-1", model.Name, version.ID) svcName := fmt.Sprintf("%s-%d-1.project.svc.cluster.local", model.Name, version.ID) @@ -71,6 +72,7 @@ func TestExecuteDeployment(t *testing.T) { storage func() *mocks.VersionEndpointStorage controller func() *clusterMock.Controller imageBuilder func() *imageBuilderMock.ImageBuilder + eventProducer *eventMock.EventProducer }{ { name: "Success: Default", @@ -116,6 +118,126 @@ func TestExecuteDeployment(t *testing.T) { return mockImgBuilder }, }, + { + name: "Success: Default - Model Observability Supported", + model: &models.Model{Name: "model", Project: project, ObservabilitySupported: true}, + version: version, + endpoint: &models.VersionEndpoint{ + EnvironmentName: env.Name, + ResourceRequest: env.DefaultResourceRequest, + VersionID: version.ID, + Namespace: project.Name, + EnableModelObservability: true, + }, + deploymentStorage: func() *mocks.DeploymentStorage { + mockStorage := createDefaultMockDeploymentStorage() + mockStorage.On("OnDeploymentSuccess", mock.Anything).Return(nil) + return mockStorage + }, + storage: func() *mocks.VersionEndpointStorage { + mockStorage := &mocks.VersionEndpointStorage{} + mockStorage.On("Save", mock.Anything).Return(nil) + mockStorage.On("Get", mock.Anything).Return(&models.VersionEndpoint{ + Environment: env, + EnvironmentName: env.Name, + ResourceRequest: env.DefaultResourceRequest, + VersionID: version.ID, + Namespace: project.Name, + }, nil) + return mockStorage + }, + controller: func() *clusterMock.Controller { + ctrl := &clusterMock.Controller{} + ctrl.On("Deploy", mock.Anything, mock.Anything). + Return(&models.Service{ + Name: iSvcName, + Namespace: project.Name, + ServiceName: svcName, + URL: url, + Metadata: svcMetadata, + }, nil) + return ctrl + }, + imageBuilder: func() *imageBuilderMock.ImageBuilder { + mockImgBuilder := &imageBuilderMock.ImageBuilder{} + return mockImgBuilder + }, + eventProducer: func() *eventMock.EventProducer { + producer := &eventMock.EventProducer{} + producer.On("VersionEndpointChangeEvent", &models.VersionEndpoint{ + EnvironmentName: env.Name, + ResourceRequest: env.DefaultResourceRequest, + VersionID: version.ID, + Namespace: project.Name, + RevisionID: models.ID(1), + Status: models.EndpointRunning, + URL: fmt.Sprintf("%s-%d-1.example.com", model.Name, version.ID), + ServiceName: fmt.Sprintf("%s-%d-1.project.svc.cluster.local", model.Name, version.ID), + EnableModelObservability: true, + }, &models.Model{Name: "model", Project: project, ObservabilitySupported: true}).Return(nil) + return producer + }(), + }, + { + name: "Success eventhough error when produce event", + model: &models.Model{Name: "model", Project: project, ObservabilitySupported: true}, + version: version, + endpoint: &models.VersionEndpoint{ + EnvironmentName: env.Name, + ResourceRequest: env.DefaultResourceRequest, + VersionID: version.ID, + Namespace: project.Name, + EnableModelObservability: true, + }, + deploymentStorage: func() *mocks.DeploymentStorage { + mockStorage := createDefaultMockDeploymentStorage() + mockStorage.On("OnDeploymentSuccess", mock.Anything).Return(nil) + return mockStorage + }, + storage: func() *mocks.VersionEndpointStorage { + mockStorage := &mocks.VersionEndpointStorage{} + mockStorage.On("Save", mock.Anything).Return(nil) + mockStorage.On("Get", mock.Anything).Return(&models.VersionEndpoint{ + Environment: env, + EnvironmentName: env.Name, + ResourceRequest: env.DefaultResourceRequest, + VersionID: version.ID, + Namespace: project.Name, + }, nil) + return mockStorage + }, + controller: func() *clusterMock.Controller { + ctrl := &clusterMock.Controller{} + ctrl.On("Deploy", mock.Anything, mock.Anything). + Return(&models.Service{ + Name: iSvcName, + Namespace: project.Name, + ServiceName: svcName, + URL: url, + Metadata: svcMetadata, + }, nil) + return ctrl + }, + imageBuilder: func() *imageBuilderMock.ImageBuilder { + mockImgBuilder := &imageBuilderMock.ImageBuilder{} + return mockImgBuilder + }, + eventProducer: func() *eventMock.EventProducer { + producer := &eventMock.EventProducer{} + producer.On("VersionEndpointChangeEvent", &models.VersionEndpoint{ + EnvironmentName: env.Name, + ResourceRequest: env.DefaultResourceRequest, + VersionID: version.ID, + Namespace: project.Name, + RevisionID: models.ID(1), + Status: models.EndpointRunning, + URL: fmt.Sprintf("%s-%d-1.example.com", model.Name, version.ID), + ServiceName: fmt.Sprintf("%s-%d-1.project.svc.cluster.local", model.Name, version.ID), + EnableModelObservability: true, + }, &models.Model{Name: "model", Project: project, ObservabilitySupported: true}).Return(fmt.Errorf("producer error")) + return producer + }(), + }, { name: "Success: Latest deployment entry in storage stuck in pending", model: model, @@ -555,11 +677,12 @@ func TestExecuteDeployment(t *testing.T) { }, } svc := &ModelServiceDeployment{ - ClusterControllers: controllers, - ImageBuilder: imgBuilder, - Storage: mockStorage, - DeploymentStorage: mockDeploymentStorage, - LoggerDestinationURL: loggerDestinationURL, + ClusterControllers: controllers, + ImageBuilder: imgBuilder, + Storage: mockStorage, + DeploymentStorage: mockDeploymentStorage, + LoggerDestinationURL: loggerDestinationURL, + ObservabilityEventProducer: tt.eventProducer, } err := svc.Deploy(job) diff --git a/api/queue/work/observability_publisher_deployment.go b/api/queue/work/observability_publisher_deployment.go new file mode 100644 index 000000000..03387019c --- /dev/null +++ b/api/queue/work/observability_publisher_deployment.go @@ -0,0 +1,100 @@ +package work + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + + "github.com/caraml-dev/merlin/log" + "github.com/caraml-dev/merlin/models" + "github.com/caraml-dev/merlin/pkg/observability/deployment" + "github.com/caraml-dev/merlin/queue" + "github.com/caraml-dev/merlin/storage" +) + +type ObservabilityPublisherDeployment struct { + Deployer deployment.Deployer + ObservabilityPublisherStorage storage.ObservabilityPublisherStorage +} + +func (op *ObservabilityPublisherDeployment) Deploy(job *queue.Job) (err error) { + ctx := context.Background() + + data := job.Arguments[dataArgKey] + byte, _ := json.Marshal(data) + + var obsPublisherJobData models.ObservabilityPublisherJob + if err := json.Unmarshal(byte, &obsPublisherJobData); err != nil { + return fmt.Errorf("job data for ID: %d is not in ObservabilityPublisherJob type", job.ID) + } + + publisherRecord := obsPublisherJobData.Publisher + actualPublisherRecord, err := op.ObservabilityPublisherStorage.Get(ctx, publisherRecord.ID) + if err != nil { + return queue.RetryableError{Message: err.Error()} + } + + // new deployment request already queued, hence this process can be skip + if actualPublisherRecord.Revision > publisherRecord.Revision { + log.Infof("publisher deployment for model: %s is skip because newer deployment request already submitted", obsPublisherJobData.WorkerData.ModelName) + return nil + } + + if actualPublisherRecord.Revision < publisherRecord.Revision { + return fmt.Errorf("actual publisher revision should not be lower than the one from submitted job") + } + + if err := op.deploymentIsOngoing(ctx, &obsPublisherJobData); err != nil { + return err + } + + defer func() { + publisherRecord.Status = models.Running + if obsPublisherJobData.ActionType == models.UndeployPublisher { + publisherRecord.Status = models.Terminated + } + + if err != nil { + publisherRecord.Status = models.Failed + } + + if _, updateError := op.ObservabilityPublisherStorage.Update(ctx, publisherRecord, false); updateError != nil { + log.Warnf("fail to update state of observability publisher with error %w", updateError) + err = queue.RetryableError{Message: updateError.Error()} + } + }() + + if obsPublisherJobData.ActionType == models.UndeployPublisher { + return op.Deployer.Undeploy(ctx, obsPublisherJobData.WorkerData) + } + + return op.Deployer.Deploy(ctx, obsPublisherJobData.WorkerData) +} + +func (op *ObservabilityPublisherDeployment) deploymentIsOngoing(ctx context.Context, jobData *models.ObservabilityPublisherJob) error { + deployedManifest, err := op.Deployer.GetDeployedManifest(ctx, jobData.WorkerData) + if err != nil { + return queue.RetryableError{Message: err.Error()} + } + if deployedManifest == nil { + return nil + } + + currentRevision := deployedManifest.Deployment.Annotations[deployment.PublisherRevisionAnnotationKey] + if currentRevision == "" { + return fmt.Errorf("deployed manifest doesn't have revision annotation") + } + currentRevisionInt, err := strconv.Atoi(currentRevision) + if err != nil { + return err + } + + if currentRevisionInt > jobData.Publisher.Revision { + return fmt.Errorf("latest deployment already being deployed") + } + if currentRevisionInt < jobData.Publisher.Revision && deployedManifest.OnProgress { + return queue.RetryableError{Message: "there is on going deployment for previous revision, this deployment request must wait until previous deployment success"} + } + return nil +} diff --git a/api/queue/work/observability_publisher_deployment_test.go b/api/queue/work/observability_publisher_deployment_test.go new file mode 100644 index 000000000..7435c9b18 --- /dev/null +++ b/api/queue/work/observability_publisher_deployment_test.go @@ -0,0 +1,821 @@ +package work + +import ( + "fmt" + "testing" + + "github.com/caraml-dev/merlin/models" + "github.com/caraml-dev/merlin/pkg/observability/deployment" + deploymentMock "github.com/caraml-dev/merlin/pkg/observability/deployment/mocks" + "github.com/caraml-dev/merlin/pkg/observability/event" + "github.com/caraml-dev/merlin/queue" + storageMock "github.com/caraml-dev/merlin/storage/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + v1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestDeploy(t *testing.T) { + schemaSpec := &models.SchemaSpec{ + SessionIDColumn: "session_id", + RowIDColumn: "row_id", + TagColumns: []string{"tag"}, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Float64, + "featureC": models.Int64, + "featureD": models.Boolean, + }, + ModelPredictionOutput: &models.ModelPredictionOutput{ + BinaryClassificationOutput: &models.BinaryClassificationOutput{ + ActualScoreColumn: "actual_score", + NegativeClassLabel: "negative", + PositiveClassLabel: "positive", + PredictionLabelColumn: "prediction_label", + PredictionScoreColumn: "prediction_score", + OutputClass: models.BinaryClassification, + }, + }, + } + testCases := []struct { + desc string + deployer *deploymentMock.Deployer + observabilityPublisherStorage *storageMock.ObservabilityPublisherStorage + jobData *queue.Job + expectedError error + }{ + { + desc: "deployment completed; fresh deployment", + deployer: func() *deploymentMock.Deployer { + mockDeployer := &deploymentMock.Deployer{} + mockDeployer.On("GetDeployedManifest", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(nil, nil) + mockDeployer.On("Deploy", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(nil) + return mockDeployer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("Get", mock.Anything, model.ID).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + mockStorage.On("Update", mock.Anything, &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: models.ID(1), + VersionID: model.ID, + Status: models.Running, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, false).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Running, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + jobData: &queue.Job{ + Name: event.ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + }, + }, + }, + }, + { + desc: "deployment failed; fresh deployment", + deployer: func() *deploymentMock.Deployer { + mockDeployer := &deploymentMock.Deployer{} + mockDeployer.On("GetDeployedManifest", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(nil, nil) + mockDeployer.On("Deploy", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(fmt.Errorf("control plane is down")) + return mockDeployer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("Get", mock.Anything, model.ID).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + mockStorage.On("Update", mock.Anything, &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: models.ID(1), + VersionID: model.ID, + Status: models.Failed, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, false).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Failed, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + jobData: &queue.Job{ + Name: event.ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + }, + }, + }, + expectedError: fmt.Errorf("control plane is down"), + }, + { + desc: "deployment requeue due to fail fetch from db", + deployer: &deploymentMock.Deployer{}, + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("Get", mock.Anything, model.ID).Return(nil, fmt.Errorf("database is down")) + return mockStorage + }(), + jobData: &queue.Job{ + Name: event.ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + }, + }, + }, + expectedError: queue.RetryableError{Message: "database is down"}, + }, + { + desc: "undeployment completed", + deployer: func() *deploymentMock.Deployer { + mockDeployer := &deploymentMock.Deployer{} + mockDeployer.On("GetDeployedManifest", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(nil, nil) + mockDeployer.On("Undeploy", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(nil) + return mockDeployer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("Get", mock.Anything, model.ID).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Running, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + mockStorage.On("Update", mock.Anything, &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: models.ID(1), + VersionID: model.ID, + Status: models.Terminated, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, false).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Terminated, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + jobData: &queue.Job{ + Name: event.ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.UndeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + }, + }, + }, + }, + { + desc: "undeployment failed; error during deployment to control plane", + deployer: func() *deploymentMock.Deployer { + mockDeployer := &deploymentMock.Deployer{} + mockDeployer.On("GetDeployedManifest", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(nil, nil) + mockDeployer.On("Undeploy", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(fmt.Errorf("control plane is down")) + return mockDeployer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("Get", mock.Anything, model.ID).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + mockStorage.On("Update", mock.Anything, &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: models.ID(1), + VersionID: model.ID, + Status: models.Failed, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, false).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Failed, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + jobData: &queue.Job{ + Name: event.ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.UndeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + }, + }, + }, + expectedError: fmt.Errorf("control plane is down"), + }, + { + desc: "redeployment completed", + deployer: func() *deploymentMock.Deployer { + mockDeployer := &deploymentMock.Deployer{} + mockDeployer.On("GetDeployedManifest", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 2, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(&deployment.Manifest{ + Deployment: &v1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "model-1-observability-publisher", + Namespace: "project-1", + Annotations: map[string]string{ + deployment.PublisherRevisionAnnotationKey: "1", + }, + }, + Status: v1.DeploymentStatus{ + UnavailableReplicas: 0, + AvailableReplicas: 1, + }, + }, + Secret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "model-1-config", + }, + }, + OnProgress: false, + }, nil) + mockDeployer.On("Deploy", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 2, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(nil) + return mockDeployer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("Get", mock.Anything, model.ID).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, nil) + mockStorage.On("Update", mock.Anything, &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: models.ID(1), + VersionID: model.ID, + Status: models.Running, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, false).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Running, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + jobData: &queue.Job{ + Name: event.ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 2, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + }, + }, + }, + }, + { + desc: "redeployment requeued due to failed save state", + deployer: func() *deploymentMock.Deployer { + mockDeployer := &deploymentMock.Deployer{} + mockDeployer.On("GetDeployedManifest", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 2, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(&deployment.Manifest{ + Deployment: &v1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "model-1-observability-publisher", + Namespace: "project-1", + Annotations: map[string]string{ + deployment.PublisherRevisionAnnotationKey: "1", + }, + }, + Status: v1.DeploymentStatus{ + UnavailableReplicas: 1, + AvailableReplicas: 1, + }, + }, + Secret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "model-1-config", + }, + }, + OnProgress: false, + }, nil) + mockDeployer.On("Deploy", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 2, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(nil) + return mockDeployer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("Get", mock.Anything, model.ID).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, nil) + mockStorage.On("Update", mock.Anything, &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: models.ID(1), + VersionID: model.ID, + Status: models.Running, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, false).Return(nil, fmt.Errorf("connection is lost")) + return mockStorage + }(), + jobData: &queue.Job{ + Name: event.ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 2, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + }, + }, + }, + expectedError: queue.RetryableError{Message: "connection is lost"}, + }, + { + desc: "deployment request revision somehow greater than from db", + deployer: func() *deploymentMock.Deployer { + mockDeployer := &deploymentMock.Deployer{} + return mockDeployer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("Get", mock.Anything, model.ID).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + jobData: &queue.Job{ + Name: event.ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 2, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + }, + }, + }, + expectedError: fmt.Errorf("actual publisher revision should not be lower than the one from submitted job"), + }, + { + desc: "deployment request already stale compare from db, hence skipped", + deployer: func() *deploymentMock.Deployer { + mockDeployer := &deploymentMock.Deployer{} + return mockDeployer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("Get", mock.Anything, model.ID).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Running, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + jobData: &queue.Job{ + Name: event.ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 1, + ModelSchemaSpec: schemaSpec, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 1, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + }, + }, + }, + }, + { + desc: "deployment request is requeue due to another deployment from previous revision is ongoing", + deployer: func() *deploymentMock.Deployer { + mockDeployer := &deploymentMock.Deployer{} + mockDeployer.On("GetDeployedManifest", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 2, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(&deployment.Manifest{ + Deployment: &v1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "model-1-observability-publisher", + Namespace: "project-1", + Annotations: map[string]string{ + deployment.PublisherRevisionAnnotationKey: "1", + }, + }, + Status: v1.DeploymentStatus{ + UnavailableReplicas: 1, + AvailableReplicas: 1, + }, + }, + Secret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "model-1-config", + }, + }, + OnProgress: true, + }, nil) + return mockDeployer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("Get", mock.Anything, model.ID).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + jobData: &queue.Job{ + Name: event.ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 2, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + }, + }, + }, + expectedError: queue.RetryableError{Message: "there is on going deployment for previous revision, this deployment request must wait until previous deployment success"}, + }, + { + desc: "deployment request is failed due to another deployment from greater revision is ongoing", + deployer: func() *deploymentMock.Deployer { + mockDeployer := &deploymentMock.Deployer{} + mockDeployer.On("GetDeployedManifest", mock.Anything, &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 2, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }).Return(&deployment.Manifest{ + Deployment: &v1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "model-1-observability-publisher", + Namespace: "project-1", + Annotations: map[string]string{ + deployment.PublisherRevisionAnnotationKey: "3", + }, + }, + Status: v1.DeploymentStatus{ + UnavailableReplicas: 1, + AvailableReplicas: 1, + }, + }, + Secret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "model-1-config", + }, + }, + OnProgress: true, + }, nil) + return mockDeployer + }(), + observabilityPublisherStorage: func() *storageMock.ObservabilityPublisherStorage { + mockStorage := &storageMock.ObservabilityPublisherStorage{} + mockStorage.On("Get", mock.Anything, model.ID).Return(&models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, nil) + return mockStorage + }(), + jobData: &queue.Job{ + Name: event.ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: models.ObservabilityPublisherJob{ + ActionType: models.DeployPublisher, + Publisher: &models.ObservabilityPublisher{ + ID: models.ID(1), + VersionModelID: model.ID, + VersionID: models.ID(1), + Status: models.Pending, + Revision: 2, + ModelSchemaSpec: schemaSpec, + }, + WorkerData: &models.WorkerData{ + Project: "project-1", + ModelSchemaSpec: schemaSpec, + ModelName: "model-1", + ModelVersion: "1", + Revision: 2, + TopicSource: "caraml-project-1-model-1-1-prediction-log", + }, + }, + }, + }, + expectedError: fmt.Errorf("latest deployment already being deployed"), + }, + { + desc: "deployment fail due to incorrect job data", + deployer: &deploymentMock.Deployer{}, + observabilityPublisherStorage: &storageMock.ObservabilityPublisherStorage{}, + jobData: &queue.Job{ + ID: 1, + Name: event.ObservabilityPublisherDeployment, + Arguments: queue.Arguments{ + dataArgKey: "randomString", + }, + }, + expectedError: fmt.Errorf("job data for ID: 1 is not in ObservabilityPublisherJob type"), + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + depl := &ObservabilityPublisherDeployment{ + Deployer: tC.deployer, + ObservabilityPublisherStorage: tC.observabilityPublisherStorage, + } + err := depl.Deploy(tC.jobData) + assert.Equal(t, tC.expectedError, err) + tC.deployer.AssertExpectations(t) + tC.observabilityPublisherStorage.AssertExpectations(t) + + }) + } +} diff --git a/api/service/model_endpoint_service.go b/api/service/model_endpoint_service.go index c0735ba0d..41dd8afdf 100644 --- a/api/service/model_endpoint_service.go +++ b/api/service/model_endpoint_service.go @@ -20,6 +20,7 @@ import ( "fmt" "strings" + "github.com/caraml-dev/merlin/pkg/observability/event" "github.com/caraml-dev/merlin/pkg/protocol" "github.com/caraml-dev/merlin/storage" "github.com/pkg/errors" @@ -65,23 +66,25 @@ type ModelEndpointsService interface { } // NewModelEndpointsService returns an initialized ModelEndpointsService. -func NewModelEndpointsService(istioClients map[string]istio.Client, modelEndpointStorage storage.ModelEndpointStorage, versionEndpointStorage storage.VersionEndpointStorage, environment string) ModelEndpointsService { - return newModelEndpointsService(istioClients, modelEndpointStorage, versionEndpointStorage, environment) +func NewModelEndpointsService(istioClients map[string]istio.Client, modelEndpointStorage storage.ModelEndpointStorage, versionEndpointStorage storage.VersionEndpointStorage, environment string, observabilityEventProducer event.EventProducer) ModelEndpointsService { + return newModelEndpointsService(istioClients, modelEndpointStorage, versionEndpointStorage, environment, observabilityEventProducer) } type modelEndpointsService struct { - istioClients map[string]istio.Client - modelEndpointStorage storage.ModelEndpointStorage - versionEndpointStorage storage.VersionEndpointStorage - environment string + istioClients map[string]istio.Client + modelEndpointStorage storage.ModelEndpointStorage + versionEndpointStorage storage.VersionEndpointStorage + environment string + observabilityEventProducer event.EventProducer } -func newModelEndpointsService(istioClients map[string]istio.Client, modelEndpointStorage storage.ModelEndpointStorage, versionEndpointStorage storage.VersionEndpointStorage, environment string) *modelEndpointsService { +func newModelEndpointsService(istioClients map[string]istio.Client, modelEndpointStorage storage.ModelEndpointStorage, versionEndpointStorage storage.VersionEndpointStorage, environment string, observabilityEventProducer event.EventProducer) *modelEndpointsService { return &modelEndpointsService{ - istioClients: istioClients, - modelEndpointStorage: modelEndpointStorage, - versionEndpointStorage: versionEndpointStorage, - environment: environment, + istioClients: istioClients, + modelEndpointStorage: modelEndpointStorage, + versionEndpointStorage: versionEndpointStorage, + environment: environment, + observabilityEventProducer: observabilityEventProducer, } } @@ -140,6 +143,13 @@ func (s *modelEndpointsService) DeployEndpoint(ctx context.Context, model *model return nil, err } + // publish model endpoint change event to trigger consumer deployment + if model.ObservabilitySupported { + if err := s.observabilityEventProducer.ModelEndpointChangeEvent(endpoint, model); err != nil { + return nil, err + } + } + return endpoint, nil } @@ -184,6 +194,13 @@ func (s *modelEndpointsService) UpdateEndpoint(ctx context.Context, model *model return nil, err } + // publish model endpoint change event to trigger consumer deployment + if model.ObservabilitySupported { + if err := s.observabilityEventProducer.ModelEndpointChangeEvent(newEndpoint, model); err != nil { + return nil, err + } + } + return newEndpoint, nil } @@ -208,6 +225,13 @@ func (s *modelEndpointsService) UndeployEndpoint(ctx context.Context, model *mod return nil, err } + // publish model endpoint change event to trigger consumer undeployment + if model.ObservabilitySupported { + if err := s.observabilityEventProducer.ModelEndpointChangeEvent(nil, model); err != nil { + return nil, err + } + } + return endpoint, nil } diff --git a/api/service/model_endpoint_service_test.go b/api/service/model_endpoint_service_test.go index b6960a4f1..d6647331d 100644 --- a/api/service/model_endpoint_service_test.go +++ b/api/service/model_endpoint_service_test.go @@ -21,6 +21,9 @@ import ( "github.com/caraml-dev/merlin/istio" istioCliMock "github.com/caraml-dev/merlin/istio/mocks" "github.com/caraml-dev/merlin/models" + "github.com/caraml-dev/merlin/pkg/observability/event" + eventMock "github.com/caraml-dev/merlin/pkg/observability/event/mocks" + "github.com/caraml-dev/merlin/pkg/protocol" "github.com/caraml-dev/merlin/storage" storageMock "github.com/caraml-dev/merlin/storage/mocks" @@ -43,6 +46,7 @@ func Test_modelEndpointsService_DeployEndpoint(t *testing.T) { modelEndpointStorage storage.ModelEndpointStorage versionEndpointStorage storage.VersionEndpointStorage environment string + eventProducer event.EventProducer } type args struct { @@ -66,6 +70,11 @@ func Test_modelEndpointsService_DeployEndpoint(t *testing.T) { modelEndpointStorage: &storageMock.ModelEndpointStorage{}, versionEndpointStorage: &storageMock.VersionEndpointStorage{}, environment: "staging", + eventProducer: func() event.EventProducer { + eProducer := &eventMock.EventProducer{} + eProducer.On("ModelEndpointChangeEvent", mock.Anything, mock.Anything).Return(nil) + return eProducer + }(), }, mockFunc: func(s *modelEndpointsService) { vs, _ := s.createVirtualService(model1, modelEndpointRequest1) @@ -146,7 +155,7 @@ func Test_modelEndpointsService_DeployEndpoint(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := newModelEndpointsService(tt.fields.istioClients, tt.fields.modelEndpointStorage, tt.fields.versionEndpointStorage, tt.fields.environment) + s := newModelEndpointsService(tt.fields.istioClients, tt.fields.modelEndpointStorage, tt.fields.versionEndpointStorage, tt.fields.environment, tt.fields.eventProducer) tt.mockFunc(s) @@ -240,10 +249,11 @@ func Test_modelEndpointsService_UpdateEndpoint(t *testing.T) { } type fields struct { - istioClients map[string]istio.Client - modelEndpointStorage storage.ModelEndpointStorage - versionEndpointStorage storage.VersionEndpointStorage - environment string + istioClients map[string]istio.Client + modelEndpointStorage storage.ModelEndpointStorage + versionEndpointStorage storage.VersionEndpointStorage + environment string + observabilityEventProducer event.EventProducer } type args struct { ctx context.Context @@ -349,7 +359,7 @@ func Test_modelEndpointsService_UpdateEndpoint(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := newModelEndpointsService(tt.fields.istioClients, tt.fields.modelEndpointStorage, tt.fields.versionEndpointStorage, tt.fields.environment) + s := newModelEndpointsService(tt.fields.istioClients, tt.fields.modelEndpointStorage, tt.fields.versionEndpointStorage, tt.fields.environment, tt.fields.observabilityEventProducer) tt.mockFunc(s) @@ -368,10 +378,11 @@ func Test_modelEndpointsService_UndeployEndpoint(t *testing.T) { modelEndpointResponseTerminated.Status = models.EndpointTerminated type fields struct { - istioClients map[string]istio.Client - modelEndpointStorage storage.ModelEndpointStorage - versionEndpointStorage storage.VersionEndpointStorage - environment string + istioClients map[string]istio.Client + modelEndpointStorage storage.ModelEndpointStorage + versionEndpointStorage storage.VersionEndpointStorage + environment string + observabilityEventProducer event.EventProducer } type args struct { ctx context.Context @@ -432,7 +443,7 @@ func Test_modelEndpointsService_UndeployEndpoint(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := newModelEndpointsService(tt.fields.istioClients, tt.fields.modelEndpointStorage, tt.fields.versionEndpointStorage, tt.fields.environment) + s := newModelEndpointsService(tt.fields.istioClients, tt.fields.modelEndpointStorage, tt.fields.versionEndpointStorage, tt.fields.environment, tt.fields.observabilityEventProducer) tt.mockFunc(s) diff --git a/api/service/version_endpoint_service.go b/api/service/version_endpoint_service.go index 8152dbc6c..b4ef2b347 100644 --- a/api/service/version_endpoint_service.go +++ b/api/service/version_endpoint_service.go @@ -124,7 +124,7 @@ func (k *endpointService) DeployEndpoint(ctx context.Context, environment *model } // override existing endpoint configuration with the user request - err = k.override(endpoint, newEndpoint, environment) + err = k.override(endpoint, newEndpoint, environment, model) if err != nil { return nil, err } @@ -150,7 +150,7 @@ func (k *endpointService) DeployEndpoint(ctx context.Context, environment *model } // override left version endpoint with values on the right version endpoint -func (k *endpointService) override(left *models.VersionEndpoint, right *models.VersionEndpoint, environment *models.Environment) error { +func (k *endpointService) override(left *models.VersionEndpoint, right *models.VersionEndpoint, environment *models.Environment, model *models.Model) error { // override deployment mode if right.DeploymentMode != deployment.EmptyDeploymentMode { left.DeploymentMode = right.DeploymentMode @@ -242,7 +242,7 @@ func (k *endpointService) override(left *models.VersionEndpoint, right *models.V left.Protocol = protocol.HttpJson } - left.EnableModelObservability = right.EnableModelObservability + left.EnableModelObservability = right.EnableModelObservability && model.ObservabilitySupported return nil } diff --git a/api/storage/mocks/observability_publisher_storage.go b/api/storage/mocks/observability_publisher_storage.go new file mode 100644 index 000000000..7cc874082 --- /dev/null +++ b/api/storage/mocks/observability_publisher_storage.go @@ -0,0 +1,149 @@ +// Code generated by mockery v2.39.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + models "github.com/caraml-dev/merlin/models" + mock "github.com/stretchr/testify/mock" +) + +// ObservabilityPublisherStorage is an autogenerated mock type for the ObservabilityPublisherStorage type +type ObservabilityPublisherStorage struct { + mock.Mock +} + +// Create provides a mock function with given fields: ctx, publisher +func (_m *ObservabilityPublisherStorage) Create(ctx context.Context, publisher *models.ObservabilityPublisher) (*models.ObservabilityPublisher, error) { + ret := _m.Called(ctx, publisher) + + if len(ret) == 0 { + panic("no return value specified for Create") + } + + var r0 *models.ObservabilityPublisher + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *models.ObservabilityPublisher) (*models.ObservabilityPublisher, error)); ok { + return rf(ctx, publisher) + } + if rf, ok := ret.Get(0).(func(context.Context, *models.ObservabilityPublisher) *models.ObservabilityPublisher); ok { + r0 = rf(ctx, publisher) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.ObservabilityPublisher) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *models.ObservabilityPublisher) error); ok { + r1 = rf(ctx, publisher) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Get provides a mock function with given fields: ctx, publisherID +func (_m *ObservabilityPublisherStorage) Get(ctx context.Context, publisherID models.ID) (*models.ObservabilityPublisher, error) { + ret := _m.Called(ctx, publisherID) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 *models.ObservabilityPublisher + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, models.ID) (*models.ObservabilityPublisher, error)); ok { + return rf(ctx, publisherID) + } + if rf, ok := ret.Get(0).(func(context.Context, models.ID) *models.ObservabilityPublisher); ok { + r0 = rf(ctx, publisherID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.ObservabilityPublisher) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, models.ID) error); ok { + r1 = rf(ctx, publisherID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetByModelID provides a mock function with given fields: ctx, modelID +func (_m *ObservabilityPublisherStorage) GetByModelID(ctx context.Context, modelID models.ID) (*models.ObservabilityPublisher, error) { + ret := _m.Called(ctx, modelID) + + if len(ret) == 0 { + panic("no return value specified for GetByModelID") + } + + var r0 *models.ObservabilityPublisher + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, models.ID) (*models.ObservabilityPublisher, error)); ok { + return rf(ctx, modelID) + } + if rf, ok := ret.Get(0).(func(context.Context, models.ID) *models.ObservabilityPublisher); ok { + r0 = rf(ctx, modelID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.ObservabilityPublisher) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, models.ID) error); ok { + r1 = rf(ctx, modelID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Update provides a mock function with given fields: ctx, publisher, increseRevision +func (_m *ObservabilityPublisherStorage) Update(ctx context.Context, publisher *models.ObservabilityPublisher, increseRevision bool) (*models.ObservabilityPublisher, error) { + ret := _m.Called(ctx, publisher, increseRevision) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 *models.ObservabilityPublisher + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *models.ObservabilityPublisher, bool) (*models.ObservabilityPublisher, error)); ok { + return rf(ctx, publisher, increseRevision) + } + if rf, ok := ret.Get(0).(func(context.Context, *models.ObservabilityPublisher, bool) *models.ObservabilityPublisher); ok { + r0 = rf(ctx, publisher, increseRevision) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.ObservabilityPublisher) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *models.ObservabilityPublisher, bool) error); ok { + r1 = rf(ctx, publisher, increseRevision) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewObservabilityPublisherStorage creates a new instance of ObservabilityPublisherStorage. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewObservabilityPublisherStorage(t interface { + mock.TestingT + Cleanup(func()) +}) *ObservabilityPublisherStorage { + mock := &ObservabilityPublisherStorage{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/api/storage/mocks/version_storage.go b/api/storage/mocks/version_storage.go new file mode 100644 index 000000000..cb528a880 --- /dev/null +++ b/api/storage/mocks/version_storage.go @@ -0,0 +1,59 @@ +// Code generated by mockery v2.39.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + models "github.com/caraml-dev/merlin/models" + mock "github.com/stretchr/testify/mock" +) + +// VersionStorage is an autogenerated mock type for the VersionStorage type +type VersionStorage struct { + mock.Mock +} + +// FindByID provides a mock function with given fields: ctx, versionID, modelID +func (_m *VersionStorage) FindByID(ctx context.Context, versionID models.ID, modelID models.ID) (*models.Version, error) { + ret := _m.Called(ctx, versionID, modelID) + + if len(ret) == 0 { + panic("no return value specified for FindByID") + } + + var r0 *models.Version + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, models.ID, models.ID) (*models.Version, error)); ok { + return rf(ctx, versionID, modelID) + } + if rf, ok := ret.Get(0).(func(context.Context, models.ID, models.ID) *models.Version); ok { + r0 = rf(ctx, versionID, modelID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Version) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, models.ID, models.ID) error); ok { + r1 = rf(ctx, versionID, modelID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewVersionStorage creates a new instance of VersionStorage. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewVersionStorage(t interface { + mock.TestingT + Cleanup(func()) +}) *VersionStorage { + mock := &VersionStorage{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/api/storage/observability_publisher_storage.go b/api/storage/observability_publisher_storage.go new file mode 100644 index 000000000..cf5875e5a --- /dev/null +++ b/api/storage/observability_publisher_storage.go @@ -0,0 +1,71 @@ +package storage + +import ( + "context" + "errors" + + "github.com/caraml-dev/merlin/models" + "gorm.io/gorm" +) + +// ObservabilityPublisherStorage +type ObservabilityPublisherStorage interface { + GetByModelID(ctx context.Context, modelID models.ID) (*models.ObservabilityPublisher, error) + Get(ctx context.Context, publisherID models.ID) (*models.ObservabilityPublisher, error) + Create(ctx context.Context, publisher *models.ObservabilityPublisher) (*models.ObservabilityPublisher, error) + Update(ctx context.Context, publisher *models.ObservabilityPublisher, increseRevision bool) (*models.ObservabilityPublisher, error) +} + +type obsPublisherStorage struct { + db *gorm.DB +} + +func NewObservabilityPublisherStorage(db *gorm.DB) *obsPublisherStorage { + return &obsPublisherStorage{ + db: db, + } +} + +func (op *obsPublisherStorage) GetByModelID(ctx context.Context, modelID models.ID) (*models.ObservabilityPublisher, error) { + var publisher models.ObservabilityPublisher + if err := op.db.Where("version_model_id = ?", modelID).First(&publisher).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + return &publisher, nil +} + +func (op *obsPublisherStorage) Get(ctx context.Context, publisherID models.ID) (*models.ObservabilityPublisher, error) { + var publisher *models.ObservabilityPublisher + if err := op.db.Where("id = ?", publisherID).First(&publisher).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + return publisher, nil +} +func (op *obsPublisherStorage) Create(ctx context.Context, publisher *models.ObservabilityPublisher) (*models.ObservabilityPublisher, error) { + publisher.Revision = 1 + if err := op.db.Create(publisher).Error; err != nil { + return nil, err + } + return publisher, nil +} + +func (op *obsPublisherStorage) Update(ctx context.Context, publisher *models.ObservabilityPublisher, increseRevision bool) (*models.ObservabilityPublisher, error) { + currentRevision := publisher.Revision + if increseRevision { + publisher.Revision++ + } + result := op.db.Model(&models.ObservabilityPublisher{}).Where("id = ? AND revision = ?", publisher.ID, currentRevision).Updates(publisher) + if result.Error != nil { + return nil, result.Error + } + if result.RowsAffected == 0 { + return nil, gorm.ErrRecordNotFound + } + return publisher, nil +} diff --git a/api/storage/observability_publisher_storage_test.go b/api/storage/observability_publisher_storage_test.go new file mode 100644 index 000000000..b905618d0 --- /dev/null +++ b/api/storage/observability_publisher_storage_test.go @@ -0,0 +1,127 @@ +//go:build integration_local || integration +// +build integration_local integration + +package storage + +import ( + "context" + "testing" + + "github.com/caraml-dev/merlin/database" + "github.com/caraml-dev/merlin/mlp" + "github.com/caraml-dev/merlin/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func Test_obsPublisherStorage(t *testing.T) { + database.WithTestDatabase(t, func(t *testing.T, db *gorm.DB) { + env1 := models.Environment{ + Name: "env1", + Cluster: "k8s", + IsPredictionJobEnabled: true, + } + db.Create(&env1) + + p := mlp.Project{ + ID: 1, + Name: "project", + MLFlowTrackingURL: "http://mlflow:5000", + } + + m := models.Model{ + ID: 1, + ProjectID: models.ID(p.ID), + ExperimentID: 1, + Name: "model", + Type: models.ModelTypeSkLearn, + } + db.Create(&m) + + modelSchema := &models.ModelSchema{ + ModelID: m.ID, + Spec: &models.SchemaSpec{ + SessionIDColumn: "session", + RowIDColumn: "row", + TagColumns: []string{"tag"}, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Float64, + "featureC": models.Int64, + "featureD": models.Boolean, + }, + ModelPredictionOutput: &models.ModelPredictionOutput{ + BinaryClassificationOutput: &models.BinaryClassificationOutput{ + ActualScoreColumn: "actual_score", + NegativeClassLabel: "negative", + PositiveClassLabel: "positive", + PredictionLabelColumn: "prediction_label", + PredictionScoreColumn: "prediction_score", + OutputClass: models.BinaryClassification, + }, + }, + }, + } + version := models.Version{ + ID: models.ID(1), + ModelID: m.ID, + ModelSchema: modelSchema, + } + db.Create(&version) + + observabilityPublisherStorage := NewObservabilityPublisherStorage(db) + observabilityPublisher := &models.ObservabilityPublisher{ + VersionModelID: m.ID, + VersionID: version.ID, + Status: models.Pending, + ModelSchemaSpec: version.ModelSchema.Spec, + } + + ctx := context.Background() + publisher, err := observabilityPublisherStorage.Create(ctx, observabilityPublisher) + require.NoError(t, err) + assert.Equal(t, models.ID(1), publisher.ID) + assert.Equal(t, 1, publisher.Revision) + + publisherFromDB, fetchErr := observabilityPublisherStorage.GetByModelID(ctx, m.ID) + require.NoError(t, fetchErr) + assert.Equal(t, publisher.ID, publisherFromDB.ID) + assert.Equal(t, publisher.Revision, publisherFromDB.Revision) + assert.Equal(t, publisher.ModelSchemaSpec, publisherFromDB.ModelSchemaSpec) + + noPublisher, fetchErr := observabilityPublisherStorage.GetByModelID(ctx, models.ID(2)) + require.NoError(t, fetchErr) + assert.Nil(t, noPublisher) + + publisherFromDB, fetchErr = observabilityPublisherStorage.Get(ctx, publisherFromDB.ID) + require.NoError(t, fetchErr) + assert.Equal(t, publisher.ID, publisherFromDB.ID) + assert.Equal(t, publisher.Revision, publisherFromDB.Revision) + assert.Equal(t, publisher.ModelSchemaSpec, publisherFromDB.ModelSchemaSpec) + + signal := make(chan bool, 1) + go func(p *models.ObservabilityPublisher) { + _p := *p + updatedPublisher, err := observabilityPublisherStorage.Update(ctx, &_p, true) + signal <- true + require.NoError(t, err) + assert.Equal(t, p.ID, updatedPublisher.ID) + assert.Equal(t, p.Revision+1, updatedPublisher.Revision) + }(publisherFromDB) + + <-signal + + // case when update when revision is not matched anymore + // update with revision 1 but the record already revision 2 + publisherFromDB.Status = models.Running + conflictedPublisher, err := observabilityPublisherStorage.Update(ctx, publisherFromDB, false) + assert.Equal(t, gorm.ErrRecordNotFound, err) + assert.Nil(t, conflictedPublisher) + + publisherFromDB, fetchErr = observabilityPublisherStorage.Get(ctx, publisherFromDB.ID) + require.NoError(t, fetchErr) + assert.Equal(t, 2, publisherFromDB.Revision) + + }) +} diff --git a/api/storage/version_storage.go b/api/storage/version_storage.go new file mode 100644 index 000000000..b4b3abd08 --- /dev/null +++ b/api/storage/version_storage.go @@ -0,0 +1,32 @@ +package storage + +import ( + "context" + "errors" + + "github.com/caraml-dev/merlin/models" + "gorm.io/gorm" +) + +type VersionStorage interface { + FindByID(ctx context.Context, versionID models.ID, modelID models.ID) (*models.Version, error) +} + +type versionStorage struct { + db *gorm.DB +} + +func NewVersionStorage(db *gorm.DB) VersionStorage { + return &versionStorage{db: db} +} + +func (v *versionStorage) FindByID(ctx context.Context, versionID models.ID, modelID models.ID) (*models.Version, error) { + var version models.Version + if err := v.db.Preload("ModelSchema").Where("versions.id = ? AND versions.model_id = ?", versionID, modelID).First(&version).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + return &version, nil +} diff --git a/api/storage/version_storage_test.go b/api/storage/version_storage_test.go new file mode 100644 index 000000000..9c36fc0c1 --- /dev/null +++ b/api/storage/version_storage_test.go @@ -0,0 +1,86 @@ +//go:build integration_local || integration +// +build integration_local integration + +package storage + +import ( + "context" + "testing" + + "github.com/caraml-dev/merlin/database" + "github.com/caraml-dev/merlin/mlp" + "github.com/caraml-dev/merlin/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func Test_versionStorage_FindByID(t *testing.T) { + database.WithTestDatabase(t, func(t *testing.T, db *gorm.DB) { + env1 := models.Environment{ + Name: "env1", + Cluster: "k8s", + IsPredictionJobEnabled: true, + } + db.Create(&env1) + + p := mlp.Project{ + ID: 1, + Name: "project", + MLFlowTrackingURL: "http://mlflow:5000", + } + + m := models.Model{ + ID: 1, + ProjectID: models.ID(p.ID), + ExperimentID: 1, + Name: "model", + Type: models.ModelTypeSkLearn, + } + db.Create(&m) + + modelSchema := &models.ModelSchema{ + ID: models.ID(1), + ModelID: m.ID, + Spec: &models.SchemaSpec{ + SessionIDColumn: "prediction_id", + RowIDColumn: "row", + TagColumns: []string{"tag"}, + FeatureTypes: map[string]models.ValueType{ + "featureA": models.Float64, + "featureB": models.Float64, + "featureC": models.Int64, + "featureD": models.Boolean, + }, + ModelPredictionOutput: &models.ModelPredictionOutput{ + BinaryClassificationOutput: &models.BinaryClassificationOutput{ + ActualScoreColumn: "actual_score", + NegativeClassLabel: "negative", + PositiveClassLabel: "positive", + PredictionLabelColumn: "prediction_label", + PredictionScoreColumn: "prediction_score", + OutputClass: models.BinaryClassification, + }, + }, + }, + } + version := models.Version{ + ID: models.ID(1), + ModelID: m.ID, + ModelSchema: modelSchema, + } + db.Create(&version) + + storage := NewVersionStorage(db) + v, err := storage.FindByID(context.Background(), version.ID, m.ID) + require.NoError(t, err) + assert.Equal(t, v.ID, version.ID) + assert.Equal(t, &modelSchema.ID, version.ModelSchemaID) + + v, err = storage.FindByID(context.Background(), 2, m.ID) + + require.NoError(t, err) + assert.Nil(t, v) + + }) +} diff --git a/db-migrations/38_supported_observability_on_models.down.sql b/db-migrations/38_supported_observability_on_models.down.sql new file mode 100644 index 000000000..e5ef5582d --- /dev/null +++ b/db-migrations/38_supported_observability_on_models.down.sql @@ -0,0 +1 @@ +ALTER TABLE models DROP COLUMN observability_supported; \ No newline at end of file diff --git a/db-migrations/38_supported_observability_on_models.up.sql b/db-migrations/38_supported_observability_on_models.up.sql new file mode 100644 index 000000000..101d9b872 --- /dev/null +++ b/db-migrations/38_supported_observability_on_models.up.sql @@ -0,0 +1 @@ +ALTER TABLE models ADD COLUMN observability_supported BOOLEAN NOT NULL DEFAULT false; \ No newline at end of file diff --git a/db-migrations/39_observability_publisher.down.sql b/db-migrations/39_observability_publisher.down.sql new file mode 100644 index 000000000..0acc2bd78 --- /dev/null +++ b/db-migrations/39_observability_publisher.down.sql @@ -0,0 +1,2 @@ +DROP TABLE observability_publishers; +DROP TYPE publisher_status; \ No newline at end of file diff --git a/db-migrations/39_observability_publisher.up.sql b/db-migrations/39_observability_publisher.up.sql new file mode 100644 index 000000000..6b2a4fce6 --- /dev/null +++ b/db-migrations/39_observability_publisher.up.sql @@ -0,0 +1,14 @@ +CREATE TYPE publisher_status as ENUM ('pending', 'running', 'failed', 'terminated'); + +CREATE TABLE IF NOT EXISTS observability_publishers +( + id serial PRIMARY KEY, + version_model_id integer, + version_id integer, + revision integer, + status publisher_status, + model_schema_spec JSONB, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(version_model_id) +); \ No newline at end of file diff --git a/python/sdk/test/pyfunc_integration_test.py b/python/sdk/test/pyfunc_integration_test.py index 69d2d0adf..fe14a8d45 100644 --- a/python/sdk/test/pyfunc_integration_test.py +++ b/python/sdk/test/pyfunc_integration_test.py @@ -14,6 +14,7 @@ import os import warnings +from functools import reduce from test.utils import undeploy_all_version import joblib @@ -21,6 +22,8 @@ import pytest import xgboost as xgb from merlin.model import ModelType, PyFuncModel, PyFuncV3Model +from merlin.model_schema import ModelSchema +from merlin.observability.inference import InferenceSchema, RegressionOutput, ValueType from merlin.pyfunc import ModelInput, ModelOutput, Values from merlin.resource_request import ResourceRequest from sklearn import svm @@ -63,12 +66,12 @@ def infer(self, model_input): class ModelObservabilityModel(PyFuncV3Model): def initialize(self, artifacts): self._feature_names = [ - "sepal length (cm)", - "sepal width (cm)", - "petal length (cm)", - "petal width (cm)", + "sepal_length", + "sepal_width", + "petal_length", + "petal_width", ] - self._target_names = ["setosa", "versicolor", "virginica"] + self._target_names = ["prediction_score"] self._model = xgb.Booster(model_file=artifacts["xgb_model"]) def preprocess(self, request: dict, **kwargs) -> ModelInput: @@ -81,9 +84,17 @@ def preprocess(self, request: dict, **kwargs) -> ModelInput: def infer(self, model_input: ModelInput) -> ModelOutput: dmatrix = xgb.DMatrix(model_input.features.data) outputs = self._model.predict(dmatrix).tolist() + + def max(first, second): + if first > second: + return first + return second + + prediction_outputs = [[reduce(max, row)] for row in outputs] + return ModelOutput( prediction_ids=model_input.prediction_ids, - predictions=Values(columns=self._target_names, data=outputs), + predictions=Values(columns=self._target_names, data=prediction_outputs), ) def postprocess(self, model_output: ModelOutput, request: dict) -> dict: @@ -234,13 +245,26 @@ def test_pyfunc_model_observability( merlin.set_project(project_name) merlin.set_model("pyfunc-mlobs", ModelType.PYFUNC_V3) + model_schema = ModelSchema( + spec=InferenceSchema( + feature_types={ + "sepal_length": ValueType.FLOAT64, + "sepal_width": ValueType.FLOAT64, + "petal_length": ValueType.FLOAT64, + "petal_width": ValueType.FLOAT64, + }, + model_prediction_output=RegressionOutput( + prediction_score_column="prediction_score", actual_score_column="actual" + ), + ) + ) undeploy_all_version() - with merlin.new_model_version() as v: + with merlin.new_model_version(model_schema=model_schema) as v: iris = load_iris() y = iris["target"] X = iris["data"] - xgb_path = train_xgboost_model(X, y) + xgb_path = train_xgboost_model(X, y) v.log_pyfunc_model( model_instance=ModelObservabilityModel(), conda_env="test/pyfunc/env.yaml", diff --git a/scripts/e2e/values-e2e.yaml b/scripts/e2e/values-e2e.yaml index 74cac258e..8826bc1d3 100644 --- a/scripts/e2e/values-e2e.yaml +++ b/scripts/e2e/values-e2e.yaml @@ -25,6 +25,18 @@ config: node-workload-type: "batch" AuthorizationConfig: AuthorizationEnabled: false + ObservabilityPublisher: + TargetNamespace: caraml-mlobs + EnvironmentName: dev + KafkaConsumer: + Brokers: broker-sample:6666 + DefaultResources: + Requests: + CPU: "1" + Memory: 512Mi + Limits: + CPU: "1" + Memory: 1Gi imageBuilder: serviceAccount: create: true