
Support overriding task pod_template via with_overrides #2981

Open
wants to merge 7 commits into
base: master

Conversation

arbaobao
Contributor

@arbaobao arbaobao commented Dec 5, 2024

Tracking issue

Related to flyteorg/flyte#5683

Why are the changes needed?

If we can support pod_template in with_overrides, this would reduce a lot of toil since we can supply pod templates in a central location and override downstream tasks, similar to how we can do so for resources.

What changes were proposed in this pull request?

We can use with_overrides() to override the pod_template of a task, just as we already can for resources.
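As a rough illustration of the mechanism, the sketch below shows how a node might record a pod template override next to a resource override. `NodeSketch` and its attribute names are hypothetical stand-ins, not flytekit's actual `Node` class, and the real implementation also validates and converts the values.

```python
from typing import Any, Optional


class NodeSketch:
    """Hypothetical stand-in for a workflow node that accepts overrides."""

    def __init__(self) -> None:
        self._resources: Optional[Any] = None
        self._pod_template: Optional[Any] = None

    def with_overrides(self, **kwargs: Any) -> "NodeSketch":
        # Only keys that are explicitly passed replace the stored values.
        if "limits" in kwargs:
            self._resources = kwargs["limits"]
        if "pod_template" in kwargs:
            self._pod_template = kwargs["pod_template"]
        # Returning self lets the call be chained directly after the task call.
        return self


node = NodeSketch().with_overrides(pod_template={"primary_container_name": "primary"})
```

Because the override is stored on the node rather than on the task, the same task can run with different pod templates in different workflows.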

How was this patch tested?

Executed a workflow that uses with_overrides(pod_template=PodTemplate(xxx)) to override the default pod template.

Setup process

I ran flyte on my local machine and tested my code with this workflow and task:

python

from flytekit import PodTemplate, Resources, task, workflow
from kubernetes.client import V1Container, V1EnvVar, V1PodSpec


@task
def say_hello() -> str:
    return "Hello, World!"


@workflow
def hello_world_wf() -> str:
    # Override both the resource limits and the pod template at the call site.
    res = say_hello().with_overrides(
        limits=Resources(cpu="2", mem="600Mi"),
        pod_template=PodTemplate(
            primary_container_name="primary-nelson",
            labels={"lKeyA": "lValA", "lKeyB": "lValB"},
            annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
            pod_spec=V1PodSpec(
                containers=[
                    V1Container(
                        name="primary-nelson",
                        image="arbaobao/flyte-test-images:pythonpath5",
                        env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
                    ),
                    V1Container(
                        name="primary-nelson2",
                        image="arbaobao/flyte-test-images:pythonpath5",
                        env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
                    ),
                ],
            ),
        ),
    )
    return res

Screenshots

Check all the applicable boxes

  • I updated the documentation accordingly.
  • All new and existing tests passed.
  • All commits are signed-off.

Related PRs

Docs link

Summary by Bito

This PR implements pod template override functionality using with_overrides() method, enabling customization of Kubernetes pod configurations. The implementation includes pod template support in Node class, workflow models, and K8sPod with primary container name support. The changes enhance task configuration flexibility while maintaining compatibility with existing resource override patterns.

Unit tests added: True

Estimated effort to review (1-5, lower is better): 3

Signed-off-by: Nelson Chen <[email protected]>
@arbaobao
Contributor Author

#take

@@ -612,10 +614,12 @@ def __init__(
        resources: typing.Optional[Resources],
        extended_resources: typing.Optional[tasks_pb2.ExtendedResources],
        container_image: typing.Optional[str] = None,
        pod_template: typing.Optional[tasks_pb2.K8sPod] = None,
Member

The return type should be PodTemplate, right?

Contributor Author

@pingsutw I didn't create a new proto named PodTemplate in tasks_pb2, so I used the existing protobuf K8sPod.

@@ -629,11 +633,22 @@ def extended_resources(self) -> tasks_pb2.ExtendedResources:
    def container_image(self) -> typing.Optional[str]:
        return self._container_image

    @property
    def pod_template(self) -> typing.Optional[tasks_pb2.K8sPod]:
Member

ditto

@@ -1005,13 +1005,15 @@ def __init__(
        metadata: K8sObjectMetadata = None,
        pod_spec: typing.Dict[str, typing.Any] = None,
        data_config: typing.Optional[DataLoadingConfig] = None,
        primarycontainername: typing.Optional[str] = None,
Member

Suggested change
-         primarycontainername: typing.Optional[str] = None,
+         primary_container_name: typing.Optional[str] = None,

@@ -1025,6 +1027,10 @@ def pod_spec(self) -> typing.Dict[str, typing.Any]:
    def data_config(self) -> typing.Optional[DataLoadingConfig]:
        return self._data_config

    @property
    def primarycontainername(self) -> typing.Optional[str]:
Member

Suggested change
-     def primarycontainername(self) -> typing.Optional[str]:
+     def primary_container_name(self) -> typing.Optional[str]:

@@ -453,6 +453,12 @@ def get_serializable_node(
# if entity._aliases:
# node_model._output_aliases = entity._aliases
elif isinstance(entity.flyte_entity, PythonTask):
override_pod = {}
Member

Suggested change
- override_pod = {}
+ overrided_pod_spec = {}

@@ -797,6 +813,7 @@ def get_serializable(
)
entity.docs.source_code = SourceCode(link=settings.git_repo)
# This needs to be at the bottom not the top - i.e. dependent tasks get added before the workflow containing it
# if not any(entity_mapping.get("id") == cp_entity.get("id") for entity in entity_mapping)
Member

remove?

Signed-off-by: Nelson Chen <[email protected]>
@flyte-bot
Contributor

flyte-bot commented Dec 29, 2024

Code Review Agent Run #0770cb

Actionable Suggestions - 5
  • flytekit/models/core/workflow.py - 1
    • Consider handling pod_spec JSON serialization errors · Line 647-649
  • tests/flytekit/unit/core/test_map_task.py - 1
    • Consider improving map task call readability · Line 365-383
  • flytekit/tools/translator.py - 2
    • Consider command serialization location · Line 457-461
    • Consider extracting pod template creation logic · Line 475-484
  • flytekit/models/task.py - 1
Additional Suggestions - 1
  • flytekit/core/node.py - 1
    • Consider pod template as constructor param · Line 71-71
Review Details
  • Files reviewed - 7 · Commit Range: 76fbbd4..fce7ccd
    • flytekit/core/node.py
    • flytekit/models/core/workflow.py
    • flytekit/models/task.py
    • flytekit/tools/translator.py
    • tests/flytekit/unit/core/test_array_node_map_task.py
    • tests/flytekit/unit/core/test_map_task.py
    • tests/flytekit/unit/core/test_node_creation.py
  • Files skipped - 0
  • Tools
    • Whispers (Secret Scanner) - ✔︎ Successful
    • Detect-secrets (Secret Scanner) - ✔︎ Successful
    • MyPy (Static Code Analysis) - ✔︎ Successful
    • Astral Ruff (Static Code Analysis) - ✔︎ Successful

AI Code Review powered by Bito

@flyte-bot
Contributor

Changelist by Bito

This pull request implements the following key changes.

Key changes and files impacted:

  • Feature Improvement - Pod Template Override Support
    • node.py - Added pod template support in Node class
    • workflow.py - Implemented pod template handling in workflow models
    • task.py - Added primary container name support for K8sPod
    • translator.py - Enhanced serialization for pod template overrides
  • Testing - Pod Template Override Tests
    • test_array_node_map_task.py - Added tests for map task pod template overrides
    • test_map_task.py - Added tests for map task pod template functionality
    • test_node_creation.py - Added tests for basic pod template override functionality

Comment on lines +647 to +649
pod_spec=json_format.Parse(json.dumps(self.pod_template.pod_spec), struct_pb2.Struct())
if self.pod_template
else None,
Contributor

Consider handling pod_spec JSON serialization errors

Consider handling potential JSON serialization errors when converting pod_spec to protobuf struct. The current implementation may raise exceptions if pod_spec contains non-serializable objects.

Code suggestion
Check the AI-generated fix before applying
Suggested change
- pod_spec=json_format.Parse(json.dumps(self.pod_template.pod_spec), struct_pb2.Struct())
- if self.pod_template
- else None,
+ pod_spec=(
+     try:
+         json_format.Parse(json.dumps(self.pod_template.pod_spec), struct_pb2.Struct())
+     except (TypeError, ValueError, json.JSONDecodeError) as e:
+         logger.error(f"Failed to serialize pod_spec: {e}")
+         None
+ ) if self.pod_template else None,

Code Review Run #0770cb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged
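
For reference only, since the suggestion was marked as incorrectly flagged: a `try` statement cannot appear inside an expression in Python, so error handling of this kind would have to live in a helper function. The sketch below is a minimal, standard-library version of that idea. `serialize_pod_spec` and its `parse` parameter are hypothetical names; in the PR's context `parse` would presumably be something like `lambda s: json_format.Parse(s, struct_pb2.Struct())`, in which case `json_format.ParseError` would also need to be added to the `except` clause.

```python
import json
import logging
from typing import Any, Callable, Dict, Optional

logger = logging.getLogger(__name__)


def serialize_pod_spec(
    pod_spec: Optional[Dict[str, Any]],
    parse: Callable[[str], Any] = json.loads,
) -> Optional[Any]:
    """Hypothetical helper: return None if pod_spec is missing or cannot be serialized."""
    if pod_spec is None:
        return None
    try:
        # json.dumps raises TypeError for non-serializable values and
        # ValueError for circular references.
        return parse(json.dumps(pod_spec))
    except (TypeError, ValueError) as exc:
        logger.error("Failed to serialize pod_spec: %s", exc)
        return None
```

Swallowing the error and returning None would hide a misconfigured pod template until runtime, which may be one reason to let the exception propagate instead.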

Comment on lines +365 to +383
map_task(my_mappable_task)(a=x).with_overrides(pod_template=PodTemplate(
    primary_container_name="primary1",
    labels={"lKeyA": "lValA", "lKeyB": "lValB"},
    annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
    pod_spec=V1PodSpec(
        containers=[
            V1Container(
                name="primary1",
                image="random:image",
                env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
            ),
            V1Container(
                name="primary2",
                image="random:image2",
                env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
            ),
        ],
    )
))
Contributor

Consider improving map task call readability

Consider breaking down the nested map_task call with with_overrides into multiple lines for better readability. The current structure makes it difficult to understand the configuration hierarchy.

Code suggestion
Check the AI-generated fix before applying
Suggested change
- map_task(my_mappable_task)(a=x).with_overrides(pod_template=PodTemplate(
-     primary_container_name="primary1",
-     labels={"lKeyA": "lValA", "lKeyB": "lValB"},
-     annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
-     pod_spec=V1PodSpec(
-         containers=[
-             V1Container(
-                 name="primary1",
-                 image="random:image",
-                 env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
-             ),
-             V1Container(
-                 name="primary2",
-                 image="random:image2",
-                 env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
-             ),
-         ],
-     )
- ))
+ pod_template = PodTemplate(
+     primary_container_name="primary1",
+     labels={"lKeyA": "lValA", "lKeyB": "lValB"},
+     annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
+     pod_spec=V1PodSpec(
+         containers=[
+             V1Container(
+                 name="primary1",
+                 image="random:image",
+                 env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
+             ),
+             V1Container(
+                 name="primary2",
+                 image="random:image2",
+                 env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
+             ),
+         ],

Code Review Run #0770cb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

Comment on lines +457 to +461
if entity._pod_template is not None:
    entity.flyte_entity.set_command_fn(_fast_serialize_command_fn(settings, entity.flyte_entity))
    override_pod_spec = _serialize_pod_spec(
        entity._pod_template, entity.flyte_entity._get_container(settings), settings
    )
Contributor

Consider command serialization location

Consider moving the command function serialization logic outside the if block since it's needed regardless of whether _pod_template exists or not.

Code suggestion
Check the AI-generated fix before applying
Suggested change
- if entity._pod_template is not None:
-     entity.flyte_entity.set_command_fn(_fast_serialize_command_fn(settings, entity.flyte_entity))
-     override_pod_spec = _serialize_pod_spec(
-         entity._pod_template, entity.flyte_entity._get_container(settings), settings
-     )
+ entity.flyte_entity.set_command_fn(_fast_serialize_command_fn(settings, entity.flyte_entity))
+ if entity._pod_template is not None:
+     override_pod_spec = _serialize_pod_spec(
+         entity._pod_template, entity.flyte_entity._get_container(settings), settings
+     )

Code Review Run #0770cb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

Comment on lines +475 to +484
pod_template=K8sPod(
    pod_spec=override_pod_spec if override_pod_spec is not None else None,
    metadata=K8sObjectMetadata(
        labels=entity._pod_template.labels if entity._pod_template else None,
        annotations=entity._pod_template.annotations if entity._pod_template else None,
    ),
    primary_container_name=entity._pod_template.primary_container_name
    if entity._pod_template
    else None,
),
Contributor

Consider extracting pod template creation logic

Consider extracting the K8sPod template creation logic into a separate helper function to improve code readability and maintainability. The nested structure with multiple conditional checks makes the code harder to follow. A similar issue was also found in tests/flytekit/unit/core/test_array_node_map_task.py (line 371-389).

Code suggestion
Check the AI-generated fix before applying
Suggested change
- pod_template=K8sPod(
-     pod_spec=override_pod_spec if override_pod_spec is not None else None,
-     metadata=K8sObjectMetadata(
-         labels=entity._pod_template.labels if entity._pod_template else None,
-         annotations=entity._pod_template.annotations if entity._pod_template else None,
-     ),
-     primary_container_name=entity._pod_template.primary_container_name
-     if entity._pod_template
-     else None,
- ),
+ pod_template=_create_pod_template(entity, override_pod_spec),
+ def _create_pod_template(entity, override_pod_spec):
+     return K8sPod(
+         pod_spec=override_pod_spec if override_pod_spec is not None else None,
+         metadata=K8sObjectMetadata(
+             labels=entity._pod_template.labels if entity._pod_template else None,
+             annotations=entity._pod_template.annotations if entity._pod_template else None,
+         ),
+         primary_container_name=entity._pod_template.primary_container_name
+         if entity._pod_template
+         else None,
+     )

Code Review Run #0770cb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged
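
Although this suggestion was also marked as incorrectly flagged, the extraction it describes can be sketched in a self-contained way. The dataclasses below are simplified stand-ins for flytekit's `PodTemplate`, `K8sObjectMetadata` and `K8sPod` (only the fields visible in the diff are modelled), and `create_pod_template` is a hypothetical helper name. The sketch also drops the redundant `x if x is not None else None` expression and checks the template once instead of three times.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


# Simplified stand-ins, not the real flytekit classes.
@dataclass
class PodTemplate:
    primary_container_name: Optional[str] = None
    labels: Optional[Dict[str, str]] = None
    annotations: Optional[Dict[str, str]] = None


@dataclass
class K8sObjectMetadata:
    labels: Optional[Dict[str, str]] = None
    annotations: Optional[Dict[str, str]] = None


@dataclass
class K8sPod:
    metadata: Optional[K8sObjectMetadata] = None
    pod_spec: Optional[Dict[str, Any]] = None
    primary_container_name: Optional[str] = None


def create_pod_template(
    template: Optional[PodTemplate],
    override_pod_spec: Optional[Dict[str, Any]],
) -> K8sPod:
    """Build the K8sPod override, with empty metadata when no template is given."""
    if template is None:
        return K8sPod(metadata=K8sObjectMetadata(), pod_spec=override_pod_spec)
    return K8sPod(
        metadata=K8sObjectMetadata(labels=template.labels, annotations=template.annotations),
        pod_spec=override_pod_spec,
        primary_container_name=template.primary_container_name,
    )
```

With no template this produces the same result as the inline version in the diff: a pod with empty metadata, no spec and no primary container name.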

@@ -1005,13 +1005,15 @@ def __init__(
        metadata: K8sObjectMetadata = None,
        pod_spec: typing.Dict[str, typing.Any] = None,
        data_config: typing.Optional[DataLoadingConfig] = None,
        primary_container_name: typing.Optional[str] = None,
Contributor

Consider adding validation for container name

Consider adding validation for primary_container_name to ensure it's not an empty string when provided.

Code suggestion
Check the AI-generated fix before applying
 @@ -1013,4 +1013,6 @@
          self._metadata = metadata
          self._pod_spec = pod_spec
          self._data_config = data_config
 +        if primary_container_name is not None and not primary_container_name.strip():
 +            raise ValueError("primary_container_name cannot be empty if provided")
          self._primary_container_name = primary_container_name

Code Review Run #0770cb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged
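
To make the behaviour of the proposed check concrete, here is a small standalone sketch. `K8sPodSketch` is a hypothetical class that models only the `primary_container_name` argument; it is not flytekit's `K8sPod`, and the suggestion itself was marked as incorrectly flagged in this review.

```python
from typing import Optional


class K8sPodSketch:
    """Hypothetical stand-in that only models primary_container_name."""

    def __init__(self, primary_container_name: Optional[str] = None) -> None:
        # None means "not set"; an empty or whitespace-only name is rejected.
        if primary_container_name is not None and not primary_container_name.strip():
            raise ValueError("primary_container_name cannot be empty if provided")
        self._primary_container_name = primary_container_name

    @property
    def primary_container_name(self) -> Optional[str]:
        return self._primary_container_name


try:
    K8sPodSketch(primary_container_name="   ")
    rejected = False
except ValueError:
    rejected = True
```

Omitting the name stays valid in this sketch, so existing callers that never pass `primary_container_name` would be unaffected by such a check.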

Labels
None yet
Projects
Status: In review
Development

Successfully merging this pull request may close these issues.

3 participants