diff --git a/pkg/apis/types/update_serving.go b/pkg/apis/types/update_serving.go index 4dbcc6a80..f0fce5a81 100644 --- a/pkg/apis/types/update_serving.go +++ b/pkg/apis/types/update_serving.go @@ -19,6 +19,7 @@ type CommonUpdateServingArgs struct { Tolerations []TolerationArgs `yaml:"tolerations"` // --toleration Shell string `yaml:"shell"` // --shell Command string `yaml:"command"` // --command + ModelDirs map[string]string `yaml:"modelDirs"` // --data } type UpdateTensorFlowServingArgs struct { diff --git a/pkg/argsbuilder/update_serving.go b/pkg/argsbuilder/update_serving.go index f77bba15c..8a4073984 100644 --- a/pkg/argsbuilder/update_serving.go +++ b/pkg/argsbuilder/update_serving.go @@ -16,6 +16,7 @@ package argsbuilder import ( "fmt" "github.com/kubeflow/arena/pkg/apis/types" + "github.com/kubeflow/arena/pkg/util" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "k8s.io/apimachinery/pkg/api/resource" @@ -68,6 +69,7 @@ func (s *UpdateServingArgsBuilder) AddCommandFlags(command *cobra.Command) { envs []string selectors []string tolerations []string + dataset []string ) command.Flags().StringVar(&s.args.Name, "name", "", "the serving name") @@ -85,11 +87,14 @@ func (s *UpdateServingArgsBuilder) AddCommandFlags(command *cobra.Command) { command.Flags().StringVar(&s.args.Command, "command", "", "the command will inject to container's command.") command.Flags().StringArrayVarP(&selectors, "selector", "", []string{}, `assigning jobs to some k8s particular nodes, usage: "--selector=key=value" or "--selector key=value" `) command.Flags().StringArrayVarP(&tolerations, "toleration", "", []string{}, `tolerate some k8s nodes with taints,usage: "--toleration key=value:effect,operator" or "--toleration all" `) + command.Flags().StringArrayVarP(&dataset, "data", "d", []string{}, "specify the trained models datasource to mount for serving, like :") + s.AddArgValue("env", &envs). AddArgValue("annotation", &annotations). AddArgValue("selector", &selectors). AddArgValue("label", &labels). - AddArgValue("toleration", &tolerations) + AddArgValue("toleration", &tolerations). + AddArgValue("data", &dataset) } func (s *UpdateServingArgsBuilder) PreBuild() error { @@ -130,6 +135,10 @@ func (s *UpdateServingArgsBuilder) PreBuild() error { return err } + if err := s.setDataSet(); err != nil { + return err + } + if err := s.check(); err != nil { return err } @@ -260,6 +269,28 @@ func (s *UpdateServingArgsBuilder) setLabels() error { return nil } +// setDataSets is used to handle option --data +func (s *UpdateServingArgsBuilder) setDataSet() error { + s.args.ModelDirs = map[string]string{} + argKey := "data" + var dataSet *[]string + value, ok := s.argValues[argKey] + if !ok { + return nil + } + dataSet = value.(*[]string) + log.Debugf("dataset: %v", *dataSet) + if len(*dataSet) <= 0 { + return nil + } + err := util.ValidateDatasets(*dataSet) + if err != nil { + return err + } + s.args.ModelDirs = transformSliceToMap(*dataSet, ":") + return nil +} + func (s *UpdateServingArgsBuilder) checkNamespace() error { if s.args.Namespace == "" { return fmt.Errorf("namespace not set, please set it") diff --git a/pkg/serving/update.go b/pkg/serving/update.go index 0f2b36f83..b80df5e55 100644 --- a/pkg/serving/update.go +++ b/pkg/serving/update.go @@ -563,6 +563,30 @@ func setInferenceServiceForCustomModel(args *types.UpdateKServeArgs, inferenceSe inferenceService.Spec.Predictor.Containers[0].Image = args.Image } + //set volume + if len(args.ModelDirs) != 0 { + log.Debugf("update modelDirs: [%+v]", args.ModelDirs) + var volumes []v1.Volume + var volumeMounts []v1.VolumeMount + + for pvName, mountPath := range args.ModelDirs { + volumes = append(volumes, v1.Volume{ + Name: pvName, + VolumeSource: v1.VolumeSource{ + PersistentVolumeClaim: &v1.PersistentVolumeClaimVolumeSource{ + ClaimName: pvName, + }, + }, + }) + volumeMounts = append(volumeMounts, v1.VolumeMount{ + Name: pvName, + MountPath: mountPath, + }) + } + inferenceService.Spec.Predictor.Containers[0].VolumeMounts = volumeMounts + inferenceService.Spec.Predictor.Volumes = volumes + } + // set resources limits resourceLimits := inferenceService.Spec.Predictor.Containers[0].Resources.Limits if resourceLimits == nil {