Skip to content

Commit

Permalink
enable parameterization of container images
Browse files Browse the repository at this point in the history
This change allows component base images to be parameterized using runtime pipeline parameters. The container images can be specified within an @pipeline decorated function, and takes precedence over the @component(base_image=..) argument.

This change also adds logic to resolve these runtime parameters in the argo driver logic. It also includes resolution steps for resolving the accelerator type which functions the same way but was missing the resolution logic. The resolution logic is a generic workaround solution for any run time pod spec input parameters that cannot be resolved because they cannot be added dynamically in the argo pod spec container template.

Signed-off-by: Humair Khan <[email protected]>
  • Loading branch information
HumairAK committed Nov 27, 2024
1 parent 634aadf commit 2f493c0
Show file tree
Hide file tree
Showing 6 changed files with 342 additions and 6 deletions.
20 changes: 15 additions & 5 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package driver

import (
Expand Down Expand Up @@ -448,19 +449,28 @@ func initPodSpecPatch(
accelerator := container.GetResources().GetAccelerator()
if accelerator != nil {
if accelerator.GetType() != "" && accelerator.GetCount() > 0 {
q, err := k8sres.ParseQuantity(fmt.Sprintf("%v", accelerator.GetCount()))
if err != nil {
return nil, fmt.Errorf("failed to init podSpecPatch: %w", err)
acceleratorType, err1 := resolvePodSpecInputRuntimeParameter(accelerator.GetType(), executorInput)
if err1 != nil {
return nil, fmt.Errorf("failed to init podSpecPatch: %w", err1)
}
res.Limits[k8score.ResourceName(accelerator.GetType())] = q
q, err1 := k8sres.ParseQuantity(fmt.Sprintf("%v", accelerator.GetCount()))
if err1 != nil {
return nil, fmt.Errorf("failed to init podSpecPatch: %w", err1)
}
res.Limits[k8score.ResourceName(acceleratorType)] = q
}
}

containerImage, err := resolvePodSpecInputRuntimeParameter(container.Image, executorInput)
if err != nil {
return nil, fmt.Errorf("failed to init podSpecPatch: %w", err)
}
podSpec := &k8score.PodSpec{
Containers: []k8score.Container{{
Name: "main", // argo task user container is always called "main"
Command: launcherCmd,
Args: userCmdArgs,
Image: container.Image,
Image: containerImage,
Resources: res,
Env: userEnvVar,
}},
Expand Down
78 changes: 78 additions & 0 deletions backend/src/v2/driver/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright 2021-2024 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package driver

import (
"fmt"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"regexp"
)

// inputPipelineChannelPattern define a regex pattern to match the content within single quotes
// example input channel looks like "{{$.inputs.parameters['pipelinechannel--val']}}"
const inputPipelineChannelPattern = `\$.inputs.parameters\['(.+?)'\]`

func isInputParameterChannel(inputChannel string) bool {
re := regexp.MustCompile(inputPipelineChannelPattern)
match := re.FindStringSubmatch(inputChannel)
if len(match) == 2 {
return true
} else {
// if len(match) > 2, then this is still incorrect because
// inputChannel should contain only one parameter channel input
return false
}
}

// extractInputParameterFromChannel takes an inputChannel that adheres to
// inputPipelineChannelPattern and extracts the channel parameter name.
// For example given an input channel of the form "{{$.inputs.parameters['pipelinechannel--val']}}"
// the channel parameter name "pipelinechannel--val" is returned.
func extractInputParameterFromChannel(inputChannel string) (string, error) {
re := regexp.MustCompile(inputPipelineChannelPattern)
match := re.FindStringSubmatch(inputChannel)
if len(match) > 1 {
extractedValue := match[1]
return extractedValue, nil
} else {
return "", fmt.Errorf("failed to extract input parameter from channel: %s", inputChannel)
}
}

// resolvePodSpecInputRuntimeParameter resolves runtime value that is intended to be
// utilized within the Pod Spec. parameterValue takes the form of:
// "{{$.inputs.parameters['pipelinechannel--someParameterName']}}"
//
// parameterValue is a runtime parameter value that has been resolved and included within
// the executor input. Since the pod spec patch cannot dynamically update the underlying
// container template's inputs in an Argo Workflow, this is a workaround for resolving
// such parameters.
//
// If parameter value is not a parameter channel, then a constant value is assumed and
// returned as is.
func resolvePodSpecInputRuntimeParameter(parameterValue string, executorInput *pipelinespec.ExecutorInput) (string, error) {
if isInputParameterChannel(parameterValue) {
inputImage, err := extractInputParameterFromChannel(parameterValue)
if err != nil {
return "", err
}
if val, ok := executorInput.Inputs.ParameterValues[inputImage]; ok {
return val.GetStringValue(), nil
} else {
return "", fmt.Errorf("executorInput did not contain container Image input parameter")
}
}
return parameterValue, nil
}
159 changes: 159 additions & 0 deletions backend/src/v2/driver/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Copyright 2021-2024 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package driver

import (
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/stretchr/testify/assert"
structpb "google.golang.org/protobuf/types/known/structpb"
"testing"
)

func Test_isInputParameterChannel(t *testing.T) {
tests := []struct {
name string
input string
isValid bool
}{
{
name: "wellformed pipeline channel should produce no errors",
input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
isValid: true,
},
{
name: "pipeline channel index should have quotes",
input: "{{$.inputs.parameters[pipelinechannel--someParameterName]}}",
isValid: false,
},
{
name: "plain text as pipelinechannel of parameter type is invalid",
input: "randomtext",
isValid: false,
},
{
name: "inputs should be prefixed with $.",
input: "{{inputs.parameters['pipelinechannel--someParameterName']}}",
isValid: false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, isInputParameterChannel(test.input), test.isValid)
})
}
}

func Test_extractInputParameterFromChannel(t *testing.T) {
tests := []struct {
name string
input string
expected string
wantErr bool
}{
{
name: "standard parameter pipeline channel input",
input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
expected: "pipelinechannel--someParameterName",
wantErr: false,
},
{
name: "a more complex parameter pipeline channel input",
input: "{{$.inputs.parameters['pipelinechannel--somePara-me_terName']}}",
expected: "pipelinechannel--somePara-me_terName",
wantErr: false,
},
{
name: "invalid input should return err",
input: "invalidvalue",
wantErr: true,
},
{
name: "invalid input should return err 2",
input: "pipelinechannel--somePara-me_terName",
wantErr: true,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual, err := extractInputParameterFromChannel(test.input)
if test.wantErr {
assert.NotNil(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, actual, test.expected)
}
})
}
}

func Test_resolvePodSpecRuntimeParameter(t *testing.T) {
tests := []struct {
name string
input string
expected string
executorInput *pipelinespec.ExecutorInput
wantErr bool
}{
{
name: "should retrieve correct parameter value",
input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
expected: "test2",
executorInput: &pipelinespec.ExecutorInput{
Inputs: &pipelinespec.ExecutorInput_Inputs{
ParameterValues: map[string]*structpb.Value{
"pipelinechannel--": structpb.NewStringValue("test1"),
"pipelinechannel--someParameterName": structpb.NewStringValue("test2"),
"someParameterName": structpb.NewStringValue("test3"),
},
},
},
wantErr: false,
},
{
name: "return err when no match is found",
input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
expected: "test1",
executorInput: &pipelinespec.ExecutorInput{
Inputs: &pipelinespec.ExecutorInput_Inputs{
ParameterValues: map[string]*structpb.Value{
"doesNotMatch": structpb.NewStringValue("test2"),
},
},
},
wantErr: true,
},
{
name: "return const val when input is not a pipeline channel",
input: "not-pipeline-channel",
expected: "not-pipeline-channel",
executorInput: &pipelinespec.ExecutorInput{},
wantErr: false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual, err := resolvePodSpecInputRuntimeParameter(test.input, test.executorInput)
if test.wantErr {
assert.NotNil(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, actual, test.expected)
}
})
}
}
64 changes: 64 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,70 @@ def my_pipeline() -> NamedTuple('Outputs', [
]):
task = print_and_return(text='Hello')

def test_pipeline_with_parameterized_container_image(self):
with tempfile.TemporaryDirectory() as tmpdir:

@dsl.component(base_image='docker.io/python:3.9.17')
def empty_component():
pass

@dsl.pipeline()
def simple_pipeline(img: str):
task = empty_component()
# overwrite base_image="docker.io/python:3.9.17"
task.set_container_image(img)

output_yaml = os.path.join(tmpdir, 'result.yaml')
compiler.Compiler().compile(
pipeline_func=simple_pipeline,
package_path=output_yaml,
pipeline_parameters={'img': 'someimage'})
self.assertTrue(os.path.exists(output_yaml))

with open(output_yaml, 'r') as f:
pipeline_spec = yaml.safe_load(f)
container = pipeline_spec['deploymentSpec']['executors'][
'exec-empty-component']['container']
self.assertEqual(
container['image'],
"{{$.inputs.parameters['pipelinechannel--img']}}")
# A parameter value should result in 2 input parameters
# One for storing pipeline channel template to be resolved during runtime.
# Two for holding the key to the resolved input.
input_parameters = pipeline_spec['root']['dag']['tasks'][
'empty-component']['inputs']['parameters']
self.assertTrue('base_image' in input_parameters)
self.assertTrue('pipelinechannel--img' in input_parameters)

def test_pipeline_with_constant_container_image(self):
with tempfile.TemporaryDirectory() as tmpdir:

@dsl.component(base_image='docker.io/python:3.9.17')
def empty_component():
pass

@dsl.pipeline()
def simple_pipeline():
task = empty_component()
# overwrite base_image="docker.io/python:3.9.17"
task.set_container_image('constant-value')

output_yaml = os.path.join(tmpdir, 'result.yaml')
compiler.Compiler().compile(
pipeline_func=simple_pipeline, package_path=output_yaml)

self.assertTrue(os.path.exists(output_yaml))

with open(output_yaml, 'r') as f:
pipeline_spec = yaml.safe_load(f)
container = pipeline_spec['deploymentSpec']['executors'][
'exec-empty-component']['container']
self.assertEqual(container['image'], 'constant-value')
# A constant value should yield no parameters
dag_task = pipeline_spec['root']['dag']['tasks'][
'empty-component']
self.assertTrue('inputs' not in dag_task)


class TestCompilePipelineCaching(unittest.TestCase):

Expand Down
7 changes: 6 additions & 1 deletion sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ def build_task_spec_for_task(
if val and pipeline_channel.extract_pipeline_channels_from_any(val):
task.inputs[key] = val

if task.container_spec and task.container_spec.image:
val = task.container_spec.image
if val and pipeline_channel.extract_pipeline_channels_from_any(val):
task.inputs['base_image'] = val

for input_name, input_value in task.inputs.items():
# Since LoopParameterArgument and LoopArtifactArgument and LoopArgumentVariable are narrower
# types than PipelineParameterChannel, start with them.
Expand Down Expand Up @@ -634,7 +639,7 @@ def convert_to_placeholder(input_value: str) -> str:

container_spec = (
pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec(
image=task.container_spec.image,
image=convert_to_placeholder(task.container_spec.image),
command=task.container_spec.command,
args=task.container_spec.args,
env=[
Expand Down
20 changes: 20 additions & 0 deletions sdk/python/kfp/dsl/pipeline_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,26 @@ def set_env_variable(self, name: str, value: str) -> 'PipelineTask':
self.container_spec.env = {name: value}
return self

@block_if_final()
def set_container_image(
self,
name: Union[str,
pipeline_channel.PipelineChannel]) -> 'PipelineTask':
"""Sets container type to use when executing this task. Takes
precedence over @component(base_image=...)
Args:
name: The name of the image, e.g. "python:3.9-alpine".
Returns:
Self return to allow chained setting calls.
"""
self._ensure_container_spec_exists()
if isinstance(name, pipeline_channel.PipelineChannel):
name = str(name)
self.container_spec.image = name
return self

@block_if_final()
def after(self, *tasks) -> 'PipelineTask':
"""Specifies an explicit dependency on other tasks by requiring this
Expand Down

0 comments on commit 2f493c0

Please sign in to comment.