diff --git a/flytekit/core/python_customized_container_task.py b/flytekit/core/python_customized_container_task.py index e47fe77bf0..27b54fe7d0 100644 --- a/flytekit/core/python_customized_container_task.py +++ b/flytekit/core/python_customized_container_task.py @@ -13,6 +13,7 @@ from flytekit.core.shim_task import ExecutableTemplateShimTask, ShimTaskExecutor from flytekit.core.tracker import TrackedInstance from flytekit.core.utils import _get_container_definition, load_proto_from_file +from flytekit.image_spec.image_spec import ImageSpec from flytekit.loggers import logger from flytekit.models import task as _task_model from flytekit.models.core import identifier as identifier_models @@ -157,11 +158,18 @@ def get_command(self, settings: SerializationSettings) -> List[str]: ] return container_args - + + def get_image(self, settings: SerializationSettings) -> str: + if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled: + if isinstance(self.container_image, ImageSpec): + # Set the source root for the image spec if it's non-fast registration + self.container_image.source_root = settings.source_root + return get_registerable_container_image(self.container_image, settings.image_config) + def get_container(self, settings: SerializationSettings) -> _task_model.Container: env = {**settings.env, **self.environment} if self.environment else settings.env return _get_container_definition( - image=get_registerable_container_image(self.container_image, settings.image_config), + image=self.get_image(settings), command=[], args=self.get_command(settings=settings), data_loading_config=None,