From e6ad98ca08c97b814b5d44b0e4117ae5b388e28f Mon Sep 17 00:00:00 2001 From: ChungYujoyce Date: Wed, 3 Apr 2024 20:19:53 -0700 Subject: [PATCH] fix #4854 Signed-off-by: ChungYujoyce --- flytekit/core/python_customized_container_task.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/flytekit/core/python_customized_container_task.py b/flytekit/core/python_customized_container_task.py index e47fe77bf0..27b54fe7d0 100644 --- a/flytekit/core/python_customized_container_task.py +++ b/flytekit/core/python_customized_container_task.py @@ -13,6 +13,7 @@ from flytekit.core.shim_task import ExecutableTemplateShimTask, ShimTaskExecutor from flytekit.core.tracker import TrackedInstance from flytekit.core.utils import _get_container_definition, load_proto_from_file +from flytekit.image_spec.image_spec import ImageSpec from flytekit.loggers import logger from flytekit.models import task as _task_model from flytekit.models.core import identifier as identifier_models @@ -157,11 +158,18 @@ def get_command(self, settings: SerializationSettings) -> List[str]: ] return container_args - + + def get_image(self, settings: SerializationSettings) -> str: + if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled: + if isinstance(self.container_image, ImageSpec): + # Set the source root for the image spec if it's non-fast registration + self.container_image.source_root = settings.source_root + return get_registerable_container_image(self.container_image, settings.image_config) + def get_container(self, settings: SerializationSettings) -> _task_model.Container: env = {**settings.env, **self.environment} if self.environment else settings.env return _get_container_definition( - image=get_registerable_container_image(self.container_image, settings.image_config), + image=self.get_image(settings), command=[], args=self.get_command(settings=settings), data_loading_config=None,