From 2f493c01e0c28ca597b921c306859740b03784ff Mon Sep 17 00:00:00 2001
From: Humair Khan
Date: Tue, 19 Nov 2024 10:26:51 -0500
Subject: [PATCH] enable parameterization of container images

This change allows component base images to be parameterized using
runtime pipeline parameters. The container images can be specified
within an @pipeline-decorated function and take precedence over the
@component(base_image=..) argument. This change also adds logic to the
Argo driver to resolve these runtime parameters. It additionally adds
the missing resolution step for the accelerator type, which is
parameterized the same way. The resolution logic is a generic
workaround for any runtime pod spec input parameters that cannot be
resolved dynamically within the Argo pod spec container template.

Signed-off-by: Humair Khan
---
 backend/src/v2/driver/driver.go               |  20 ++-
 backend/src/v2/driver/util.go                 |  78 +++++++++
 backend/src/v2/driver/util_test.go            | 159 ++++++++++++++++++
 sdk/python/kfp/compiler/compiler_test.py      |  64 +++++++
 .../kfp/compiler/pipeline_spec_builder.py     |   7 +-
 sdk/python/kfp/dsl/pipeline_task.py           |  20 +++
 6 files changed, 342 insertions(+), 6 deletions(-)
 create mode 100644 backend/src/v2/driver/util.go
 create mode 100644 backend/src/v2/driver/util_test.go
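For context, this is how a pipeline author would use the feature from the SDK once this patch is applied. A minimal sketch, mirroring the compiler tests added below; the component and pipeline names and the image value are illustrative:

    from kfp import compiler, dsl

    @dsl.component(base_image='docker.io/python:3.9.17')
    def train():
        pass

    @dsl.pipeline()
    def my_pipeline(img: str):
        task = train()
        # The runtime value of `img` takes precedence over base_image above and
        # compiles to "{{$.inputs.parameters['pipelinechannel--img']}}".
        task.set_container_image(img)

    compiler.Compiler().compile(
        pipeline_func=my_pipeline,
        package_path='pipeline.yaml',
        pipeline_parameters={'img': 'docker.io/python:3.10'})
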
diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go
index 960ec6148e86..f17454382b5e 100644
--- a/backend/src/v2/driver/driver.go
+++ b/backend/src/v2/driver/driver.go
@@ -11,6 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
 package driver
 
 import (
@@ -448,19 +449,28 @@ func initPodSpecPatch(
 	accelerator := container.GetResources().GetAccelerator()
 	if accelerator != nil {
 		if accelerator.GetType() != "" && accelerator.GetCount() > 0 {
-			q, err := k8sres.ParseQuantity(fmt.Sprintf("%v", accelerator.GetCount()))
-			if err != nil {
-				return nil, fmt.Errorf("failed to init podSpecPatch: %w", err)
+			acceleratorType, err1 := resolvePodSpecInputRuntimeParameter(accelerator.GetType(), executorInput)
+			if err1 != nil {
+				return nil, fmt.Errorf("failed to init podSpecPatch: %w", err1)
 			}
-			res.Limits[k8score.ResourceName(accelerator.GetType())] = q
+			q, err1 := k8sres.ParseQuantity(fmt.Sprintf("%v", accelerator.GetCount()))
+			if err1 != nil {
+				return nil, fmt.Errorf("failed to init podSpecPatch: %w", err1)
+			}
+			res.Limits[k8score.ResourceName(acceleratorType)] = q
 		}
 	}
+
+	containerImage, err := resolvePodSpecInputRuntimeParameter(container.Image, executorInput)
+	if err != nil {
+		return nil, fmt.Errorf("failed to init podSpecPatch: %w", err)
+	}
 	podSpec := &k8score.PodSpec{
 		Containers: []k8score.Container{{
 			Name:      "main", // argo task user container is always called "main"
 			Command:   launcherCmd,
 			Args:      userCmdArgs,
-			Image:     container.Image,
+			Image:     containerImage,
 			Resources: res,
 			Env:       userEnvVar,
 		}},
diff --git a/backend/src/v2/driver/util.go b/backend/src/v2/driver/util.go
new file mode 100644
index 000000000000..b85e08ffe105
--- /dev/null
+++ b/backend/src/v2/driver/util.go
@@ -0,0 +1,78 @@
+// Copyright 2021-2024 The Kubeflow Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package driver
+
+import (
+	"fmt"
+	"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
+	"regexp"
+)
+
+// inputPipelineChannelPattern defines a regex pattern that matches the parameter name within single quotes.
+// An example input channel looks like "{{$.inputs.parameters['pipelinechannel--val']}}".
+const inputPipelineChannelPattern = `\$.inputs.parameters\['(.+?)'\]`
+
+func isInputParameterChannel(inputChannel string) bool {
+	re := regexp.MustCompile(inputPipelineChannelPattern)
+	match := re.FindStringSubmatch(inputChannel)
+	if len(match) == 2 {
+		return true
+	} else {
+		// if len(match) > 2, then this is still incorrect because
+		// inputChannel should contain only one parameter channel input
+		return false
+	}
+}
+
+// extractInputParameterFromChannel takes an inputChannel that adheres to
+// inputPipelineChannelPattern and extracts the channel parameter name.
+// For example, given an input channel of the form "{{$.inputs.parameters['pipelinechannel--val']}}"
+// the channel parameter name "pipelinechannel--val" is returned.
+func extractInputParameterFromChannel(inputChannel string) (string, error) {
+	re := regexp.MustCompile(inputPipelineChannelPattern)
+	match := re.FindStringSubmatch(inputChannel)
+	if len(match) > 1 {
+		extractedValue := match[1]
+		return extractedValue, nil
+	} else {
+		return "", fmt.Errorf("failed to extract input parameter from channel: %s", inputChannel)
+	}
+}
+
+// resolvePodSpecInputRuntimeParameter resolves a runtime value that is intended to be
+// used within the Pod Spec. parameterValue takes the form of:
+// "{{$.inputs.parameters['pipelinechannel--someParameterName']}}"
+//
+// parameterValue is a runtime parameter value that has been resolved and included within
+// the executor input. Since the pod spec patch cannot dynamically update the underlying
+// container template's inputs in an Argo Workflow, this is a workaround for resolving
+// such parameters.
+//
+// If the parameter value is not a parameter channel, then a constant value is assumed and
+// returned as is.
+func resolvePodSpecInputRuntimeParameter(parameterValue string, executorInput *pipelinespec.ExecutorInput) (string, error) {
+	if isInputParameterChannel(parameterValue) {
+		inputImage, err := extractInputParameterFromChannel(parameterValue)
+		if err != nil {
+			return "", err
+		}
+		if val, ok := executorInput.Inputs.ParameterValues[inputImage]; ok {
+			return val.GetStringValue(), nil
+		} else {
+			return "", fmt.Errorf("executorInput did not contain container Image input parameter")
+		}
+	}
+	return parameterValue, nil
+}
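For intuition, the resolution performed by resolvePodSpecInputRuntimeParameter above can be sketched in a few lines of Python. This is illustrative only; the parameter name and image value are made up, and the Go helper above is the authoritative logic:

    import re

    # The compiled image field arrives in the pod spec as a pipeline channel placeholder.
    image_field = "{{$.inputs.parameters['pipelinechannel--img']}}"

    # Mirrors inputPipelineChannelPattern: capture the parameter name inside the quotes.
    match = re.search(r"\$\.inputs\.parameters\['(.+?)'\]", image_field)

    # The driver looks the captured key up in the already-resolved executor input;
    # a value that is not a pipeline channel is treated as a constant and used as-is.
    executor_input_parameters = {'pipelinechannel--img': 'docker.io/python:3.10'}
    resolved = executor_input_parameters[match.group(1)] if match else image_field
    assert resolved == 'docker.io/python:3.10'
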
diff --git a/backend/src/v2/driver/util_test.go b/backend/src/v2/driver/util_test.go
new file mode 100644
index 000000000000..15d0ffc7e82a
--- /dev/null
+++ b/backend/src/v2/driver/util_test.go
@@ -0,0 +1,159 @@
+// Copyright 2021-2024 The Kubeflow Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package driver
+
+import (
+	"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
+	"github.com/stretchr/testify/assert"
+	structpb "google.golang.org/protobuf/types/known/structpb"
+	"testing"
+)
+
+func Test_isInputParameterChannel(t *testing.T) {
+	tests := []struct {
+		name    string
+		input   string
+		isValid bool
+	}{
+		{
+			name:    "well-formed pipeline channel should produce no errors",
+			input:   "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
+			isValid: true,
+		},
+		{
+			name:    "pipeline channel index should have quotes",
+			input:   "{{$.inputs.parameters[pipelinechannel--someParameterName]}}",
+			isValid: false,
+		},
+		{
+			name:    "plain text as pipelinechannel of parameter type is invalid",
+			input:   "randomtext",
+			isValid: false,
+		},
+		{
+			name:    "inputs should be prefixed with $.",
+			input:   "{{inputs.parameters['pipelinechannel--someParameterName']}}",
+			isValid: false,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			assert.Equal(t, isInputParameterChannel(test.input), test.isValid)
+		})
+	}
+}
+
+func Test_extractInputParameterFromChannel(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    string
+		expected string
+		wantErr  bool
+	}{
+		{
+			name:     "standard parameter pipeline channel input",
+			input:    "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
+			expected: "pipelinechannel--someParameterName",
+			wantErr:  false,
+		},
+		{
+			name:     "a more complex parameter pipeline channel input",
+			input:    "{{$.inputs.parameters['pipelinechannel--somePara-me_terName']}}",
+			expected: "pipelinechannel--somePara-me_terName",
+			wantErr:  false,
+		},
+		{
+			name:    "invalid input should return err",
+			input:   "invalidvalue",
+			wantErr: true,
+		},
+		{
+			name:    "invalid input should return err 2",
+			input:   "pipelinechannel--somePara-me_terName",
+			wantErr: true,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			actual, err := extractInputParameterFromChannel(test.input)
+			if test.wantErr {
+				assert.NotNil(t, err)
+			} else {
+				assert.NoError(t, err)
+				assert.Equal(t, actual, test.expected)
+			}
+		})
+	}
+}
+
+func Test_resolvePodSpecRuntimeParameter(t *testing.T) {
+	tests := []struct {
+		name          string
+		input         string
+		expected      string
+		executorInput *pipelinespec.ExecutorInput
+		wantErr       bool
+	}{
+		{
+			name:     "should retrieve correct parameter value",
+			input:    "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
+			expected: "test2",
+			executorInput: &pipelinespec.ExecutorInput{
+				Inputs: &pipelinespec.ExecutorInput_Inputs{
+					ParameterValues: map[string]*structpb.Value{
+						"pipelinechannel--":                  structpb.NewStringValue("test1"),
+						"pipelinechannel--someParameterName": structpb.NewStringValue("test2"),
+						"someParameterName":                  structpb.NewStringValue("test3"),
+					},
+				},
+			},
+			wantErr: false,
+		},
+		{
+			name:     "return err when no match is found",
+			input:    "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
+			expected: "test1",
+			executorInput: &pipelinespec.ExecutorInput{
+				Inputs: &pipelinespec.ExecutorInput_Inputs{
+					ParameterValues: map[string]*structpb.Value{
+						"doesNotMatch": structpb.NewStringValue("test2"),
+					},
+				},
+			},
+			wantErr: true,
+		},
+		{
+			name:          "return const val when input is not a pipeline channel",
+			input:         "not-pipeline-channel",
+			expected:      "not-pipeline-channel",
+			executorInput: &pipelinespec.ExecutorInput{},
+			wantErr:       false,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			actual, err := resolvePodSpecInputRuntimeParameter(test.input, test.executorInput)
+			if test.wantErr {
+				assert.NotNil(t, err)
+			} else {
+				assert.NoError(t, err)
+				assert.Equal(t, actual, test.expected)
+			}
+		})
+	}
+}
diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py
index 2433f09bc6d1..b09b1b52c5df 100644
--- a/sdk/python/kfp/compiler/compiler_test.py
+++ b/sdk/python/kfp/compiler/compiler_test.py
@@ -909,6 +909,70 @@ def my_pipeline() -> NamedTuple('Outputs', [
         ]):
             task = print_and_return(text='Hello')
 
+    def test_pipeline_with_parameterized_container_image(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+
+            @dsl.component(base_image='docker.io/python:3.9.17')
+            def empty_component():
+                pass
+
+            @dsl.pipeline()
+            def simple_pipeline(img: str):
+                task = empty_component()
+                # overwrite base_image="docker.io/python:3.9.17"
+                task.set_container_image(img)
+
+            output_yaml = os.path.join(tmpdir, 'result.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=simple_pipeline,
+                package_path=output_yaml,
+                pipeline_parameters={'img': 'someimage'})
+            self.assertTrue(os.path.exists(output_yaml))
+
+            with open(output_yaml, 'r') as f:
+                pipeline_spec = yaml.safe_load(f)
+                container = pipeline_spec['deploymentSpec']['executors'][
+                    'exec-empty-component']['container']
+                self.assertEqual(
+                    container['image'],
+                    "{{$.inputs.parameters['pipelinechannel--img']}}")
+                # A parameterized image should result in 2 task input parameters:
+                # one for storing the pipeline channel template to be resolved at
+                # runtime, and one for holding the key to the resolved input.
+                input_parameters = pipeline_spec['root']['dag']['tasks'][
+                    'empty-component']['inputs']['parameters']
+                self.assertTrue('base_image' in input_parameters)
+                self.assertTrue('pipelinechannel--img' in input_parameters)
+
+    def test_pipeline_with_constant_container_image(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+
+            @dsl.component(base_image='docker.io/python:3.9.17')
+            def empty_component():
+                pass
+
+            @dsl.pipeline()
+            def simple_pipeline():
+                task = empty_component()
+                # overwrite base_image="docker.io/python:3.9.17"
+                task.set_container_image('constant-value')
+
+            output_yaml = os.path.join(tmpdir, 'result.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=simple_pipeline, package_path=output_yaml)
+
+            self.assertTrue(os.path.exists(output_yaml))
+
+            with open(output_yaml, 'r') as f:
+                pipeline_spec = yaml.safe_load(f)
+                container = pipeline_spec['deploymentSpec']['executors'][
+                    'exec-empty-component']['container']
+                self.assertEqual(container['image'], 'constant-value')
+                # A constant value should yield no parameters
+                dag_task = pipeline_spec['root']['dag']['tasks'][
+                    'empty-component']
+                self.assertTrue('inputs' not in dag_task)
+
 
 
 class TestCompilePipelineCaching(unittest.TestCase):
diff --git a/sdk/python/kfp/compiler/pipeline_spec_builder.py b/sdk/python/kfp/compiler/pipeline_spec_builder.py
index afc014530fa2..ffd1871bc2a7 100644
--- a/sdk/python/kfp/compiler/pipeline_spec_builder.py
+++ b/sdk/python/kfp/compiler/pipeline_spec_builder.py
@@ -135,6 +135,11 @@ def build_task_spec_for_task(
         if val and pipeline_channel.extract_pipeline_channels_from_any(val):
             task.inputs[key] = val
 
+    if task.container_spec and task.container_spec.image:
+        val = task.container_spec.image
+        if val and pipeline_channel.extract_pipeline_channels_from_any(val):
+            task.inputs['base_image'] = val
+
     for input_name, input_value in task.inputs.items():
         # Since LoopParameterArgument and LoopArtifactArgument and LoopArgumentVariable are narrower
         # types than PipelineParameterChannel, start with them.
@@ -634,7 +639,7 @@ def convert_to_placeholder(input_value: str) -> str:
 
     container_spec = (
         pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec(
-            image=task.container_spec.image,
+            image=convert_to_placeholder(task.container_spec.image),
             command=task.container_spec.command,
             args=task.container_spec.args,
             env=[
diff --git a/sdk/python/kfp/dsl/pipeline_task.py b/sdk/python/kfp/dsl/pipeline_task.py
index 822f55207889..b41a14ef82db 100644
--- a/sdk/python/kfp/dsl/pipeline_task.py
+++ b/sdk/python/kfp/dsl/pipeline_task.py
@@ -631,6 +631,26 @@ def set_env_variable(self, name: str, value: str) -> 'PipelineTask':
         self.container_spec.env = {name: value}
         return self
 
+    @block_if_final()
+    def set_container_image(
+        self,
+        name: Union[str,
+                    pipeline_channel.PipelineChannel]) -> 'PipelineTask':
+        """Sets the container image to use when executing this task. Takes
+        precedence over the @component(base_image=...) argument.
+
+        Args:
+            name: The name of the image, e.g. "python:3.9-alpine".
+
+        Returns:
+            Self return to allow chained setting calls.
+        """
+        self._ensure_container_spec_exists()
+        if isinstance(name, pipeline_channel.PipelineChannel):
+            name = str(name)
+        self.container_spec.image = name
+        return self
+
     @block_if_final()
     def after(self, *tasks) -> 'PipelineTask':
         """Specifies an explicit dependency on other tasks by requiring this