diff --git a/flytekit/core/node.py b/flytekit/core/node.py index ea089c6fd3..96e5c38560 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -6,6 +6,7 @@ from flyteidl.core import tasks_pb2 +from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import Resources, convert_resources_to_resource_model from flytekit.core.utils import _dnsify from flytekit.extras.accelerators import BaseAccelerator @@ -67,6 +68,7 @@ def __init__( self._resources: typing.Optional[_resources_model] = None self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None self._container_image: typing.Optional[str] = None + self._pod_template: typing.Optional[PodTemplate] = None def runs_before(self, other: Node): """ @@ -140,6 +142,7 @@ def with_overrides( cache: Optional[bool] = None, cache_version: Optional[str] = None, cache_serialize: Optional[bool] = None, + pod_template: Optional[PodTemplate] = None, *args, **kwargs, ): @@ -221,6 +224,10 @@ def with_overrides( assert_not_promise(cache_serialize, "cache_serialize") self._metadata._cache_serializable = cache_serialize + if pod_template is not None: + assert_not_promise(pod_template, "podtemplate") + self._pod_template = pod_template + return self diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 8d8bf9c9ef..57dc90e96e 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -1,10 +1,13 @@ import datetime +import json import typing from flyteidl.core import tasks_pb2 from flyteidl.core import workflow_pb2 as _core_workflow +from google.protobuf import json_format, struct_pb2 from google.protobuf.wrappers_pb2 import BoolValue +from flytekit.core.pod_template import PodTemplate from flytekit.models import common as _common from flytekit.models import interface as _interface from flytekit.models import types as type_models @@ -12,7 +15,7 @@ from flytekit.models.core import identifier as _identifier from flytekit.models.literals import Binding as _Binding from flytekit.models.literals import RetryStrategy as _RetryStrategy -from flytekit.models.task import Resources +from flytekit.models.task import K8sObjectMetadata, Resources class IfBlock(_common.FlyteIdlEntity): @@ -612,10 +615,12 @@ def __init__( resources: typing.Optional[Resources], extended_resources: typing.Optional[tasks_pb2.ExtendedResources], container_image: typing.Optional[str] = None, + pod_template: typing.Optional[PodTemplate] = None, ): self._resources = resources self._extended_resources = extended_resources self._container_image = container_image + self._pod_template = pod_template @property def resources(self) -> Resources: @@ -629,11 +634,27 @@ def extended_resources(self) -> tasks_pb2.ExtendedResources: def container_image(self) -> typing.Optional[str]: return self._container_image + @property + def pod_template(self) -> typing.Optional[PodTemplate]: + return self._pod_template + def to_flyte_idl(self): return _core_workflow.TaskNodeOverrides( resources=self.resources.to_flyte_idl() if self.resources is not None else None, extended_resources=self.extended_resources, container_image=self.container_image, + pod_template=tasks_pb2.K8sPod( + metadata=K8sObjectMetadata( + labels=self.pod_template.labels if self.pod_template else None, + annotations=self.pod_template.annotations if self.pod_template else None, + ).to_flyte_idl() + if self.pod_template is not None + else None, + pod_spec=json_format.Parse(json.dumps(self.pod_template.pod_spec), struct_pb2.Struct()) + if self.pod_template + else None, + primary_container_name=self.pod_template.primary_container_name if self.pod_template else None, + ), ) @classmethod diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 960555fd9b..4810958d9a 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -1005,6 +1005,7 @@ def __init__( metadata: K8sObjectMetadata = None, pod_spec: typing.Dict[str, typing.Any] = None, data_config: typing.Optional[DataLoadingConfig] = None, + primary_container_name: typing.Optional[str] = None, ): """ This defines a kubernetes pod target. It will build the pod target during task execution @@ -1012,6 +1013,7 @@ def __init__( self._metadata = metadata self._pod_spec = pod_spec self._data_config = data_config + self._primary_container_name = primary_container_name @property def metadata(self) -> K8sObjectMetadata: @@ -1025,6 +1027,10 @@ def pod_spec(self) -> typing.Dict[str, typing.Any]: def data_config(self) -> typing.Optional[DataLoadingConfig]: return self._data_config + @property + def primary_container_name(self) -> typing.Optional[str]: + return self._primary_container_name + def to_flyte_idl(self) -> _core_task.K8sPod: return _core_task.K8sPod( metadata=self._metadata.to_flyte_idl() if self.metadata else None, diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 5c7a6d5eb4..2b8255a732 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -5,7 +5,7 @@ from flyteidl.admin import schedule_pb2 -from flytekit import ImageSpec, PythonFunctionTask, SourceCode +from flytekit import ImageSpec, PodTemplate, PythonFunctionTask, SourceCode from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import constants as _common_constants from flytekit.core import context_manager @@ -25,7 +25,7 @@ from flytekit.core.python_function_task import EagerAsyncPythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate from flytekit.core.task import ReferenceTask -from flytekit.core.utils import ClassDecorator, _dnsify +from flytekit.core.utils import ClassDecorator, _dnsify, _serialize_pod_spec from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase from flytekit.models import common as _common_models from flytekit.models import interface as interface_models @@ -453,6 +453,12 @@ def get_serializable_node( # if entity._aliases: # node_model._output_aliases = entity._aliases elif isinstance(entity.flyte_entity, PythonTask): + override_pod_spec = {} + if entity._pod_template is not None: + entity.flyte_entity.set_command_fn(_fast_serialize_command_fn(settings, entity.flyte_entity)) + override_pod_spec = _serialize_pod_spec( + entity._pod_template, entity.flyte_entity._get_container(settings), settings + ) task_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options) node_model = workflow_model.Node( id=_dnsify(entity.id), @@ -466,6 +472,16 @@ def get_serializable_node( resources=entity._resources, extended_resources=entity._extended_resources, container_image=entity._container_image, + pod_template=PodTemplate( + pod_spec=override_pod_spec, + labels=entity._pod_template.labels if entity._pod_template.labels else None, + annotations=entity._pod_template.annotations if entity._pod_template.annotations else None, + primary_container_name=entity._pod_template.primary_container_name + if entity._pod_template.primary_container_name + else None, + ) + if entity._pod_template + else None, ), ), ) diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index ed1fc7fdd0..c86f8f757a 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -8,8 +8,9 @@ import pytest from flyteidl.core import workflow_pb2 as _core_workflow +from kubernetes.client import V1PodSpec, V1Container, V1EnvVar -from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask +from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask, PodTemplate from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver @@ -360,6 +361,38 @@ def wf(x: typing.List[int]): assert wf.nodes[0]._container_image == "random:image" +def test_map_task_pod_template_override(serialization_settings): + @task + def my_mappable_task(a: int) -> typing.Optional[str]: + return str(a) + + @workflow + def wf(x: typing.List[int]): + map_task(my_mappable_task)(a=x).with_overrides(pod_template=PodTemplate( + primary_container_name="primary1", + labels={"lKeyA": "lValA", "lKeyB": "lValB"}, + annotations={"aKeyA": "aValA", "aKeyB": "aValB"}, + pod_spec=V1PodSpec( + containers=[ + V1Container( + name="primary1", + image="random:image", + env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")], + ), + V1Container( + name="primary2", + image="random:image2", + env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")], + ), + ], + ) + )) + + assert wf.nodes[0]._pod_template.primary_container_name == "primary1" + assert wf.nodes[0]._pod_template.pod_spec.containers[0].image == "random:image" + assert wf.nodes[0]._pod_template.labels == {"lKeyA": "lValA", "lKeyB": "lValB"} + assert wf.nodes[0]._pod_template.annotations["aKeyA"] == "aValA" + def test_serialization_metadata(serialization_settings): @task(interruptible=True) diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 3775c8e12d..13f690bf73 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -3,9 +3,10 @@ from collections import OrderedDict import pytest +from kubernetes.client import V1PodSpec, V1Container, V1EnvVar import flytekit.configuration -from flytekit import LaunchPlan, Resources +from flytekit import LaunchPlan, Resources, PodTemplate from flytekit.configuration import Image, ImageConfig from flytekit.core.legacy_map_task import MapPythonTask, MapTaskResolver, map_task from flytekit.core.task import TaskMetadata, task @@ -354,6 +355,39 @@ def wf(x: typing.List[int]): assert wf.nodes[0]._container_image == "random:image" +def test_map_task_pod_template_override(serialization_settings): + @task + def my_mappable_task(a: int) -> typing.Optional[str]: + return str(a) + + @workflow + def wf(x: typing.List[int]): + map_task(my_mappable_task)(a=x).with_overrides(pod_template=PodTemplate( + primary_container_name="primary1", + labels={"lKeyA": "lValA", "lKeyB": "lValB"}, + annotations={"aKeyA": "aValA", "aKeyB": "aValB"}, + pod_spec=V1PodSpec( + containers=[ + V1Container( + name="primary1", + image="random:image", + env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")], + ), + V1Container( + name="primary2", + image="random:image2", + env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")], + ), + ], + ) + )) + + + assert wf.nodes[0]._pod_template.primary_container_name == "primary1" + assert wf.nodes[0]._pod_template.pod_spec.containers[0].image == "random:image" + assert wf.nodes[0]._pod_template.labels == {"lKeyA": "lValA", "lKeyB": "lValB"} + assert wf.nodes[0]._pod_template.annotations["aKeyA"] == "aValA" + def test_bounded_inputs_vars_order(serialization_settings): mt = map_task(functools.partial(t3, c=1.0, b="hello", a=1)) diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 381f456bdb..5b911c0052 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -4,9 +4,10 @@ from dataclasses import dataclass import pytest +from kubernetes.client import V1PodSpec, V1Container, V1EnvVar import flytekit.configuration -from flytekit import Resources, map_task +from flytekit import Resources, map_task, PodTemplate from flytekit.configuration import Image, ImageConfig from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.node_creation import create_node @@ -470,6 +471,39 @@ def wf() -> str: assert wf.nodes[0]._container_image == "hello/world" +def test_pod_template_override(): + @task + def bar(): + print("hello") + + @workflow + def wf() -> str: + bar().with_overrides(pod_template=PodTemplate( + primary_container_name="primary1", + labels={"lKeyA": "lValA", "lKeyB": "lValB"}, + annotations={"aKeyA": "aValA", "aKeyB": "aValB"}, + pod_spec=V1PodSpec( + containers=[ + V1Container( + name="primary1", + image="random:image", + env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")], + ), + V1Container( + name="primary2", + image="random:image2", + env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")], + ), + ], + ) + )) + return "hi" + + assert wf.nodes[0]._pod_template.primary_container_name == "primary1" + assert wf.nodes[0]._pod_template.pod_spec.containers[0].image == "random:image" + assert wf.nodes[0]._pod_template.labels == {"lKeyA": "lValA", "lKeyB": "lValB"} + assert wf.nodes[0]._pod_template.annotations["aKeyA"] == "aValA" + def test_override_accelerator(): @task(accelerator=T4)