Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support overriding task pod_template via with_overrides #2981

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from flyteidl.core import tasks_pb2

from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, convert_resources_to_resource_model
from flytekit.core.utils import _dnsify
from flytekit.extras.accelerators import BaseAccelerator
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(
self._resources: typing.Optional[_resources_model] = None
self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None
self._container_image: typing.Optional[str] = None
self._pod_template: typing.Optional[PodTemplate] = None

def runs_before(self, other: Node):
"""
Expand Down Expand Up @@ -140,6 +142,7 @@ def with_overrides(
cache: Optional[bool] = None,
cache_version: Optional[str] = None,
cache_serialize: Optional[bool] = None,
pod_template: Optional[PodTemplate] = None,
*args,
**kwargs,
):
Expand Down Expand Up @@ -221,6 +224,10 @@ def with_overrides(
assert_not_promise(cache_serialize, "cache_serialize")
self._metadata._cache_serializable = cache_serialize

if pod_template is not None:
assert_not_promise(pod_template, "podtemplate")
self._pod_template = pod_template

return self


Expand Down
15 changes: 15 additions & 0 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import datetime
import json
import typing

from flyteidl.core import tasks_pb2
from flyteidl.core import workflow_pb2 as _core_workflow
from google.protobuf import json_format, struct_pb2
from google.protobuf.wrappers_pb2 import BoolValue

from flytekit.models import common as _common
Expand Down Expand Up @@ -612,10 +614,12 @@ def __init__(
resources: typing.Optional[Resources],
extended_resources: typing.Optional[tasks_pb2.ExtendedResources],
container_image: typing.Optional[str] = None,
pod_template: typing.Optional[tasks_pb2.K8sPod] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type should be PodTemplate, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pingsutw I didn't create a new proto named PodTemplate in tasks_pb2, so I used the existing protobuf K8sPod.

):
self._resources = resources
self._extended_resources = extended_resources
self._container_image = container_image
self._pod_template = pod_template

@property
def resources(self) -> Resources:
Expand All @@ -629,11 +633,22 @@ def extended_resources(self) -> tasks_pb2.ExtendedResources:
def container_image(self) -> typing.Optional[str]:
return self._container_image

@property
def pod_template(self) -> typing.Optional[tasks_pb2.K8sPod]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

return self._pod_template

def to_flyte_idl(self):
return _core_workflow.TaskNodeOverrides(
resources=self.resources.to_flyte_idl() if self.resources is not None else None,
extended_resources=self.extended_resources,
container_image=self.container_image,
pod_template=tasks_pb2.K8sPod(
metadata=self.pod_template.metadata.to_flyte_idl() if self.pod_template else None,
pod_spec=json_format.Parse(json.dumps(self.pod_template.pod_spec), struct_pb2.Struct())
if self.pod_template
else None,
Comment on lines +647 to +649
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider handling pod_spec JSON serialization errors

Consider handling potential JSON serialization errors when converting pod_spec to protobuf struct. The current implementation may raise exceptions if pod_spec contains non-serializable objects.

Code suggestion
Check the AI-generated fix before applying
Suggested change
pod_spec=json_format.Parse(json.dumps(self.pod_template.pod_spec), struct_pb2.Struct())
if self.pod_template
else None,
pod_spec=(
try:
json_format.Parse(json.dumps(self.pod_template.pod_spec), struct_pb2.Struct())
except (TypeError, ValueError, json.JSONDecodeError) as e:
logger.error(f"Failed to serialize pod_spec: {e}")
None
) if self.pod_template else None,

Code Review Run #0770cb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

primary_container_name=self.pod_template.primary_container_name if self.pod_template else None,
),
)

@classmethod
Expand Down
6 changes: 6 additions & 0 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,13 +1005,15 @@ def __init__(
metadata: K8sObjectMetadata = None,
pod_spec: typing.Dict[str, typing.Any] = None,
data_config: typing.Optional[DataLoadingConfig] = None,
primary_container_name: typing.Optional[str] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding validation for container name

Consider adding validation for primary_container_name to ensure it's not an empty string when provided.

Code suggestion
Check the AI-generated fix before applying
 @@ -1013,4 +1013,6 @@
          self._metadata = metadata
          self._pod_spec = pod_spec
          self._data_config = data_config
 +        if primary_container_name is not None and not primary_container_name.strip():
 +            raise ValueError("primary_container_name cannot be empty if provided")
          self._primary_container_name = primary_container_name

Code Review Run #0770cb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

):
"""
This defines a kubernetes pod target. It will build the pod target during task execution
"""
self._metadata = metadata
self._pod_spec = pod_spec
self._data_config = data_config
self._primary_container_name = primary_container_name

@property
def metadata(self) -> K8sObjectMetadata:
Expand All @@ -1025,6 +1027,10 @@ def pod_spec(self) -> typing.Dict[str, typing.Any]:
def data_config(self) -> typing.Optional[DataLoadingConfig]:
return self._data_config

# Read-only accessor for the ``primary_container_name`` value handed to
# ``__init__``; it is ``None`` when the caller did not supply one.
@property
def primary_container_name(self) -> typing.Optional[str]:
return self._primary_container_name

def to_flyte_idl(self) -> _core_task.K8sPod:
return _core_task.K8sPod(
metadata=self._metadata.to_flyte_idl() if self.metadata else None,
Expand Down
20 changes: 18 additions & 2 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from flytekit.core.python_function_task import EagerAsyncPythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate
from flytekit.core.task import ReferenceTask
from flytekit.core.utils import ClassDecorator, _dnsify
from flytekit.core.utils import ClassDecorator, _dnsify, _serialize_pod_spec
from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase
from flytekit.models import common as _common_models
from flytekit.models import interface as interface_models
Expand All @@ -38,7 +38,7 @@
from flytekit.models.core.workflow import ApproveCondition, GateNode, SignalCondition, SleepCondition, TaskNodeOverrides
from flytekit.models.core.workflow import ArrayNode as ArrayNodeModel
from flytekit.models.core.workflow import BranchNode as BranchNodeModel
from flytekit.models.task import TaskSpec, TaskTemplate
from flytekit.models.task import K8sObjectMetadata, K8sPod, TaskSpec, TaskTemplate

FlyteLocalEntity = Union[
PythonTask,
Expand Down Expand Up @@ -453,6 +453,12 @@ def get_serializable_node(
# if entity._aliases:
# node_model._output_aliases = entity._aliases
elif isinstance(entity.flyte_entity, PythonTask):
override_pod_spec = {}
if entity._pod_template is not None:
entity.flyte_entity.set_command_fn(_fast_serialize_command_fn(settings, entity.flyte_entity))
override_pod_spec = _serialize_pod_spec(
entity._pod_template, entity.flyte_entity._get_container(settings), settings
)
Comment on lines +457 to +461
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider command serialization location

Consider moving the command function serialization logic outside the if block since it's needed regardless of whether _pod_template exists or not.

Code suggestion
Check the AI-generated fix before applying
Suggested change
if entity._pod_template is not None:
entity.flyte_entity.set_command_fn(_fast_serialize_command_fn(settings, entity.flyte_entity))
override_pod_spec = _serialize_pod_spec(
entity._pod_template, entity.flyte_entity._get_container(settings), settings
)
entity.flyte_entity.set_command_fn(_fast_serialize_command_fn(settings, entity.flyte_entity))
if entity._pod_template is not None:
override_pod_spec = _serialize_pod_spec(
entity._pod_template, entity.flyte_entity._get_container(settings), settings
)

Code Review Run #0770cb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

task_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
node_model = workflow_model.Node(
id=_dnsify(entity.id),
Expand All @@ -466,6 +472,16 @@ def get_serializable_node(
resources=entity._resources,
extended_resources=entity._extended_resources,
container_image=entity._container_image,
pod_template=K8sPod(
pod_spec=override_pod_spec if override_pod_spec is not None else None,
metadata=K8sObjectMetadata(
labels=entity._pod_template.labels if entity._pod_template else None,
annotations=entity._pod_template.annotations if entity._pod_template else None,
),
primary_container_name=entity._pod_template.primary_container_name
if entity._pod_template
else None,
),
Comment on lines +475 to +484
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider extracting pod template creation logic

Consider extracting the K8sPod template creation logic into a separate helper function to improve code readability and maintainability. The nested structure with multiple conditional checks makes the code harder to follow. A similar issue was also found in tests/flytekit/unit/core/test_array_node_map_task.py (lines 371-389).

Code suggestion
Check the AI-generated fix before applying
Suggested change
pod_template=K8sPod(
pod_spec=override_pod_spec if override_pod_spec is not None else None,
metadata=K8sObjectMetadata(
labels=entity._pod_template.labels if entity._pod_template else None,
annotations=entity._pod_template.annotations if entity._pod_template else None,
),
primary_container_name=entity._pod_template.primary_container_name
if entity._pod_template
else None,
),
pod_template=_create_pod_template(entity, override_pod_spec),
def _create_pod_template(entity, override_pod_spec):
return K8sPod(
pod_spec=override_pod_spec if override_pod_spec is not None else None,
metadata=K8sObjectMetadata(
labels=entity._pod_template.labels if entity._pod_template else None,
annotations=entity._pod_template.annotations if entity._pod_template else None,
),
primary_container_name=entity._pod_template.primary_container_name
if entity._pod_template
else None,
)

Code Review Run #0770cb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

),
),
)
Expand Down
35 changes: 34 additions & 1 deletion tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

import pytest
from flyteidl.core import workflow_pb2 as _core_workflow
from kubernetes.client import V1PodSpec, V1Container, V1EnvVar

from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask
from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask, PodTemplate
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.core import context_manager
from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
Expand Down Expand Up @@ -360,6 +361,38 @@ def wf(x: typing.List[int]):

assert wf.nodes[0]._container_image == "random:image"

def test_map_task_pod_template_override(serialization_settings):
    """A pod template given to ``with_overrides`` on a map task node is stored on that node."""

    @task
    def my_mappable_task(a: int) -> typing.Optional[str]:
        return str(a)

    def _container(name, image):
        # Build a new env list on every call so the containers share no objects.
        return V1Container(
            name=name,
            image=image,
            env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
        )

    template = PodTemplate(
        primary_container_name="primary1",
        labels={"lKeyA": "lValA", "lKeyB": "lValB"},
        annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
        pod_spec=V1PodSpec(
            containers=[
                _container("primary1", "random:image"),
                _container("primary2", "random:image2"),
            ],
        ),
    )

    @workflow
    def wf(x: typing.List[int]):
        map_task(my_mappable_task)(a=x).with_overrides(pod_template=template)

    # The override must be readable back from the first (only) node of the workflow.
    stored = wf.nodes[0]._pod_template
    assert stored.primary_container_name == "primary1"
    assert stored.pod_spec.containers[0].image == "random:image"
    assert stored.labels == {"lKeyA": "lValA", "lKeyB": "lValB"}
    assert stored.annotations["aKeyA"] == "aValA"


def test_serialization_metadata(serialization_settings):
@task(interruptible=True)
Expand Down
36 changes: 35 additions & 1 deletion tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from collections import OrderedDict

import pytest
from kubernetes.client import V1PodSpec, V1Container, V1EnvVar

import flytekit.configuration
from flytekit import LaunchPlan, Resources
from flytekit import LaunchPlan, Resources, PodTemplate
from flytekit.configuration import Image, ImageConfig
from flytekit.core.legacy_map_task import MapPythonTask, MapTaskResolver, map_task
from flytekit.core.task import TaskMetadata, task
Expand Down Expand Up @@ -354,6 +355,39 @@ def wf(x: typing.List[int]):

assert wf.nodes[0]._container_image == "random:image"

def test_map_task_pod_template_override(serialization_settings):
@task
def my_mappable_task(a: int) -> typing.Optional[str]:
return str(a)

@workflow
def wf(x: typing.List[int]):
map_task(my_mappable_task)(a=x).with_overrides(pod_template=PodTemplate(
primary_container_name="primary1",
labels={"lKeyA": "lValA", "lKeyB": "lValB"},
annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
pod_spec=V1PodSpec(
containers=[
V1Container(
name="primary1",
image="random:image",
env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
),
V1Container(
name="primary2",
image="random:image2",
env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
),
],
)
))
Comment on lines +365 to +383
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider improving map task call readability

Consider breaking down the nested map_task call with with_overrides into multiple lines for better readability. The current structure makes it difficult to understand the configuration hierarchy.

Code suggestion
Check the AI-generated fix before applying
Suggested change
map_task(my_mappable_task)(a=x).with_overrides(pod_template=PodTemplate(
primary_container_name="primary1",
labels={"lKeyA": "lValA", "lKeyB": "lValB"},
annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
pod_spec=V1PodSpec(
containers=[
V1Container(
name="primary1",
image="random:image",
env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
),
V1Container(
name="primary2",
image="random:image2",
env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
),
],
)
))
pod_template = PodTemplate(
primary_container_name="primary1",
labels={"lKeyA": "lValA", "lKeyB": "lValB"},
annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
pod_spec=V1PodSpec(
containers=[
V1Container(
name="primary1",
image="random:image",
env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
),
V1Container(
name="primary2",
image="random:image2",
env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
),
],

Code Review Run #0770cb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged



assert wf.nodes[0]._pod_template.primary_container_name == "primary1"
assert wf.nodes[0]._pod_template.pod_spec.containers[0].image == "random:image"
assert wf.nodes[0]._pod_template.labels == {"lKeyA": "lValA", "lKeyB": "lValB"}
assert wf.nodes[0]._pod_template.annotations["aKeyA"] == "aValA"


def test_bounded_inputs_vars_order(serialization_settings):
mt = map_task(functools.partial(t3, c=1.0, b="hello", a=1))
Expand Down
36 changes: 35 additions & 1 deletion tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from dataclasses import dataclass

import pytest
from kubernetes.client import V1PodSpec, V1Container, V1EnvVar

import flytekit.configuration
from flytekit import Resources, map_task
from flytekit import Resources, map_task, PodTemplate
from flytekit.configuration import Image, ImageConfig
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.node_creation import create_node
Expand Down Expand Up @@ -470,6 +471,39 @@ def wf() -> str:

assert wf.nodes[0]._container_image == "hello/world"

def test_pod_template_override():
    """A pod template given to ``with_overrides`` on a plain task node is stored on that node."""

    @task
    def bar():
        print("hello")

    def _container(name, image):
        # Build a new env list on every call so the containers share no objects.
        return V1Container(
            name=name,
            image=image,
            env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
        )

    template = PodTemplate(
        primary_container_name="primary1",
        labels={"lKeyA": "lValA", "lKeyB": "lValB"},
        annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
        pod_spec=V1PodSpec(
            containers=[
                _container("primary1", "random:image"),
                _container("primary2", "random:image2"),
            ],
        ),
    )

    @workflow
    def wf() -> str:
        bar().with_overrides(pod_template=template)
        return "hi"

    # The override must be readable back from the first (only) node of the workflow.
    stored = wf.nodes[0]._pod_template
    assert stored.primary_container_name == "primary1"
    assert stored.pod_spec.containers[0].image == "random:image"
    assert stored.labels == {"lKeyA": "lValA", "lKeyB": "lValB"}
    assert stored.annotations["aKeyA"] == "aValA"


def test_override_accelerator():
@task(accelerator=T4)
Expand Down
Loading