From 3fc51afb6b96859a7b625aff5071a55007e7c561 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 22 Oct 2024 15:44:02 -0700 Subject: [PATCH] Async/data persistence (#2829) Signed-off-by: Yee Hing Tong --- flytekit/core/data_persistence.py | 104 +++++++++++++----- flytekit/core/type_engine.py | 10 +- flytekit/extend/backend/base_agent.py | 2 +- flytekit/extras/tensorflow/model.py | 12 +- flytekit/types/directory/types.py | 12 +- flytekit/types/file/file.py | 9 +- flytekit/types/iterator/json_iterator.py | 11 +- flytekit/types/numpy/ndarray.py | 14 ++- flytekit/types/pickle/pickle.py | 22 ++-- flytekit/types/schema/types.py | 18 ++- flytekit/types/schema/types_pandas.py | 12 +- .../types/structured/structured_dataset.py | 10 +- flytekit/utils/asyn.py | 2 +- tests/flytekit/unit/core/test_data.py | 88 ++++++++++++++- .../unit/core/test_data_persistence.py | 22 +++- tests/flytekit/unit/core/test_flyte_file.py | 2 +- tests/flytekit/unit/core/test_type_engine.py | 6 +- tests/flytekit/unit/remote/test_fs_remote.py | 2 +- tests/flytekit/unit/utils/test_asyn.py | 6 + 19 files changed, 266 insertions(+), 98 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index cdd07afba7..751ffb8b27 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -18,6 +18,7 @@ """ +import asyncio import io import os import pathlib @@ -29,6 +30,7 @@ import fsspec from decorator import decorator +from fsspec.asyn import AsyncFileSystem from fsspec.utils import get_protocol from typing_extensions import Unpack @@ -40,6 +42,7 @@ from flytekit.exceptions.user import FlyteAssertion, FlyteDataNotFoundException from flytekit.interfaces.random import random from flytekit.loggers import logger +from flytekit.utils.asyn import loop_manager # Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198 # for key and secret @@ -208,8 +211,17 @@ def get_filesystem( storage_options = get_fsspec_storage_options( protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs ) + kwargs.update(storage_options) - return fsspec.filesystem(protocol, **storage_options) + return fsspec.filesystem(protocol, **kwargs) + + async def get_async_filesystem_for_path( + self, path: str = "", anonymous: bool = False, **kwargs + ) -> Union[AsyncFileSystem, fsspec.AbstractFileSystem]: + protocol = get_protocol(path) + loop = asyncio.get_running_loop() + + return self.get_filesystem(protocol, anonymous=anonymous, path=path, asynchronous=True, loop=loop, **kwargs) def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem: protocol = get_protocol(path) @@ -282,8 +294,8 @@ def exists(self, path: str) -> bool: raise oe @retry_request - def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): - file_system = self.get_filesystem_for_path(from_path) + async def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): + file_system = await self.get_async_filesystem_for_path(from_path) if recursive: from_path, to_path = self.recursive_paths(from_path, to_path) try: @@ -294,7 +306,10 @@ def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True ) logger.info(f"Getting {from_path} to {to_path}") - dst = file_system.get(from_path, to_path, recursive=recursive, **kwargs) + if isinstance(file_system, AsyncFileSystem): + dst = 
await file_system._get(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212 + else: + dst = file_system.get(from_path, to_path, recursive=recursive, **kwargs) if isinstance(dst, (str, pathlib.Path)): return dst return to_path @@ -302,15 +317,22 @@ def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}") if not file_system.exists(from_path): raise FlyteDataNotFoundException(from_path) - file_system = self.get_filesystem(get_protocol(from_path), anonymous=True) + file_system = self.get_filesystem(get_protocol(from_path), anonymous=True, asynchronous=True) if file_system is not None: logger.debug(f"Attempting anonymous get with {file_system}") - return file_system.get(from_path, to_path, recursive=recursive, **kwargs) + if isinstance(file_system, AsyncFileSystem): + return await file_system._get(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212 + else: + return file_system.get(from_path, to_path, recursive=recursive, **kwargs) raise oe @retry_request - def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): - file_system = self.get_filesystem_for_path(to_path) + async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): + """ + More of an internal function to be called by put_data and put_raw_data + This does not need a separate sync function. + """ + file_system = await self.get_async_filesystem_for_path(to_path) from_path = self.strip_file_header(from_path) if recursive: # Only check this for the local filesystem @@ -327,13 +349,16 @@ def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): if "metadata" not in kwargs: kwargs["metadata"] = {} kwargs["metadata"].update(self._execution_metadata) - dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs) + if isinstance(file_system, AsyncFileSystem): + dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212 + else: + dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs) if isinstance(dst, (str, pathlib.Path)): return dst else: return to_path - def put_raw_data( + async def async_put_raw_data( self, lpath: Uploadable, upload_prefix: Optional[str] = None, @@ -364,7 +389,7 @@ def put_raw_data( :param read_chunk_size_bytes: If lpath is a buffer, this is the chunk size to read from it :param encoding: If lpath is a io.StringIO, this is the encoding to use to encode it to binary. :param skip_raw_data_prefix: If True, the raw data prefix will not be prepended to the upload_prefix - :param kwargs: Additional kwargs are passed into the the fsspec put() call or the open() call + :param kwargs: Additional kwargs are passed into the fsspec put() call or the open() call :return: Returns the final path data was written to. """ # First figure out what the destination path should be, then call put. 
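# Standalone sketch of the dispatch pattern the new `get` relies on (the
# `fetch` helper below is illustrative and not part of this patch): fsspec's
# async filesystems expose coroutine methods prefixed with an underscore
# (`_get`, `_put`), while purely synchronous filesystems only offer the
# blocking calls.
import asyncio
import pathlib

import fsspec
from fsspec.asyn import AsyncFileSystem
from fsspec.utils import get_protocol


async def fetch(from_path: str, to_path: str, recursive: bool = False):
    # Instantiate whichever filesystem the path's protocol maps to. The patch
    # additionally passes asynchronous=True and the running loop when it builds
    # async instances (see get_async_filesystem_for_path); omitted for brevity.
    fs = fsspec.filesystem(get_protocol(from_path))
    if isinstance(fs, AsyncFileSystem):
        # Await the coroutine variant so concurrent transfers share one event loop.
        return await fs._get(from_path, to_path, recursive=recursive)
    # Synchronous filesystems (e.g. the local one) keep the blocking call.
    return fs.get(from_path, to_path, recursive=recursive)


if __name__ == "__main__":
    pathlib.Path("/tmp/fetch-src.txt").write_text("hello")
    asyncio.run(fetch("/tmp/fetch-src.txt", "/tmp/fetch-dst.txt"))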
@@ -388,42 +413,60 @@ def put_raw_data( raise FlyteAssertion(f"File {from_path} is a symlink, can't upload") if p.is_dir(): logger.debug(f"Detected directory {from_path}, using recursive put") - r = self.put(from_path, to_path, recursive=True, **kwargs) + r = await self._put(from_path, to_path, recursive=True, **kwargs) else: logger.debug(f"Detected file {from_path}, call put non-recursive") - r = self.put(from_path, to_path, **kwargs) + r = await self._put(from_path, to_path, **kwargs) return r or to_path # raw bytes if isinstance(lpath, bytes): - fs = self.get_filesystem_for_path(to_path) - with fs.open(to_path, "wb", **kwargs) as s: - s.write(lpath) + fs = await self.get_async_filesystem_for_path(to_path) + if isinstance(fs, AsyncFileSystem): + async with fs.open_async(to_path, "wb", **kwargs) as s: + s.write(lpath) + else: + with fs.open(to_path, "wb", **kwargs) as s: + s.write(lpath) + return to_path # If lpath is a buffered reader of some kind if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = self.get_filesystem_for_path(to_path) + fs = await self.get_async_filesystem_for_path(to_path) lpath.seek(0) - with fs.open(to_path, "wb", **kwargs) as s: - while data := lpath.read(read_chunk_size_bytes): - s.write(data) + if isinstance(fs, AsyncFileSystem): + async with fs.open_async(to_path, "wb", **kwargs) as s: + while data := lpath.read(read_chunk_size_bytes): + s.write(data) + else: + with fs.open(to_path, "wb", **kwargs) as s: + while data := lpath.read(read_chunk_size_bytes): + s.write(data) return to_path if isinstance(lpath, io.StringIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = self.get_filesystem_for_path(to_path) + fs = await self.get_async_filesystem_for_path(to_path) lpath.seek(0) - with fs.open(to_path, "wb", **kwargs) as s: - while data_str := lpath.read(read_chunk_size_bytes): - s.write(data_str.encode(encoding)) + if isinstance(fs, AsyncFileSystem): + async with fs.open_async(to_path, "wb", **kwargs) as s: + while data_str := lpath.read(read_chunk_size_bytes): + s.write(data_str.encode(encoding)) + else: + with fs.open(to_path, "wb", **kwargs) as s: + while data_str := lpath.read(read_chunk_size_bytes): + s.write(data_str.encode(encoding)) return to_path raise FlyteAssertion(f"Unsupported lpath type {type(lpath)}") + # Public synchronous version + put_raw_data = loop_manager.synced(async_put_raw_data) + @staticmethod def get_random_string() -> str: return UUID(int=random.getrandbits(128)).hex @@ -549,7 +592,7 @@ def upload_directory(self, local_path: str, remote_path: str, **kwargs): """ return self.put_data(local_path, remote_path, is_multipart=True, **kwargs) - def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs): + async def async_get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs): """ :param remote_path: :param local_path: @@ -558,7 +601,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False try: pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) with timeit(f"Download data to local from {remote_path}"): - self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs) + await self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs) except FlyteDataNotFoundException: raise except Exception as ex: @@ -567,7 +610,9 @@ def get_data(self, remote_path: str, 
local_path: str, is_multipart: bool = False f"Original exception: {str(ex)}" ) - def put_data( + get_data = loop_manager.synced(async_get_data) + + async def async_put_data( self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False, **kwargs ) -> str: """ @@ -581,7 +626,7 @@ def put_data( try: local_path = str(local_path) with timeit(f"Upload data to {remote_path}"): - put_result = self.put(cast(str, local_path), remote_path, recursive=is_multipart, **kwargs) + put_result = await self._put(cast(str, local_path), remote_path, recursive=is_multipart, **kwargs) # This is an unfortunate workaround to ensure that we return the correct path for the remote location # Callers of this put_data function in flytekit have been changed to assign the remote path to the # output @@ -595,6 +640,9 @@ def put_data( f"Original exception: {str(ex)}" ) from ex + # Public synchronous version + put_data = loop_manager.synced(async_put_data) + flyte_tmp_dir = tempfile.mkdtemp(prefix="flyte-") default_local_file_access_provider = FileAccessProvider( diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index ee1513267b..39b0d6f096 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1908,7 +1908,9 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple: return None, None @staticmethod - def dict_to_binary_literal(ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool) -> Literal: + async def dict_to_binary_literal( + ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool + ) -> Literal: """ Converts a Python dictionary to a Flyte-specific ``Literal`` using MessagePack encoding. Falls back to Pickle if encoding fails and `allow_pickle` is True. @@ -1922,7 +1924,7 @@ def dict_to_binary_literal(ctx: FlyteContext, v: dict, python_type: Type[dict], return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) except TypeError as e: if allow_pickle: - remote_path = FlytePickle.to_pickle(ctx, v) + remote_path = await FlytePickle.to_pickle(ctx, v) return Literal( scalar=Scalar( generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), _struct.Struct()) @@ -1980,7 +1982,7 @@ async def async_to_literal( allow_pickle, base_type = DictTransformer.is_pickle(python_type) if expected and expected.simple and expected.simple == SimpleType.STRUCT: - return self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle) + return await self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle) lit_map = {} for k, v in python_val.items(): @@ -2036,7 +2038,7 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p from flytekit.types.pickle import FlytePickle uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file") - return FlytePickle.from_pickle(uri) + return await FlytePickle.from_pickle(uri) try: return json.loads(_json_format.MessageToJson(lv.scalar.generic)) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 2f973e94f0..eb476bc983 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -368,7 +368,7 @@ async def _create( literal_map = await TypeEngine._dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) path = ctx.file_access.get_random_local_path() utils.write_proto_to_file(literal_map.to_flyte_idl(), path) - ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb") + await 
ctx.file_access.async_put_data(path, f"{output_prefix}/inputs.pb") task_template = render_task_template(task_template, output_prefix) else: literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) diff --git a/flytekit/extras/tensorflow/model.py b/flytekit/extras/tensorflow/model.py index 2978fe1d69..b9fbf24d4b 100644 --- a/flytekit/extras/tensorflow/model.py +++ b/flytekit/extras/tensorflow/model.py @@ -4,13 +4,13 @@ import tensorflow as tf from flytekit.core.context_manager import FlyteContext -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError from flytekit.models.core import types as _core_types from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType -class TensorFlowModelTransformer(TypeTransformer[tf.keras.Model]): +class TensorFlowModelTransformer(AsyncTypeTransformer[tf.keras.Model]): TENSORFLOW_FORMAT = "TensorFlowModel" def __init__(self): @@ -24,7 +24,7 @@ def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType: ) ) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: tf.keras.Model, @@ -44,10 +44,10 @@ def to_literal( # save model in SavedModel format tf.keras.models.save_model(python_val, local_path) - remote_path = ctx.file_access.put_raw_data(local_path) + remote_path = await ctx.file_access.async_put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) - def to_python_value( + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.Model] ) -> tf.keras.Model: try: @@ -56,7 +56,7 @@ def to_python_value( TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") local_path = ctx.file_access.get_random_local_path() - ctx.file_access.get_data(uri, local_path, is_multipart=True) + await ctx.file_access.async_get_data(uri, local_path, is_multipart=True) # load model return tf.keras.models.load_model(local_path) diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 518525914d..52249e2977 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -19,7 +19,7 @@ from flytekit import BlobType from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_batch_size +from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, get_batch_size from flytekit.exceptions.user import FlyteAssertion from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types @@ -407,7 +407,7 @@ def __str__(self): return str(self.path) -class FlyteDirToMultipartBlobTransformer(TypeTransformer[FlyteDirectory]): +class FlyteDirToMultipartBlobTransformer(AsyncTypeTransformer[FlyteDirectory]): """ This transformer handles conversion between the Python native FlyteDirectory class defined above, and the Flyte IDL literal/type of Multipart Blob. Please see the FlyteDirectory comments for additional information. 
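# The transformer migrations in flytekit/extras/tensorflow/model.py above and
# in the flytekit/types/* files below all follow one recipe: subclass
# AsyncTypeTransformer instead of TypeTransformer, rename to_literal /
# to_python_value to their async_ counterparts, and await the async data
# helpers. A condensed sketch of that recipe for a hypothetical user type
# (`MyType` and "MyFormat" are illustrative, not part of this patch):
from typing import Type

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import AsyncTypeTransformer
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType


class MyType:
    """Stand-in for a user-defined type persisted as a single blob."""

    def save(self, path: str) -> None: ...

    @classmethod
    def load(cls, path: str) -> "MyType": ...


class MyTypeTransformer(AsyncTypeTransformer[MyType]):
    MY_FORMAT = "MyFormat"

    def __init__(self):
        super().__init__(name="MyType", t=MyType)

    def get_literal_type(self, t: Type[MyType]) -> LiteralType:
        return LiteralType(
            blob=_core_types.BlobType(
                format=self.MY_FORMAT,
                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
            )
        )

    async def async_to_literal(
        self, ctx: FlyteContext, python_val: MyType, python_type: Type[MyType], expected: LiteralType
    ) -> Literal:
        local_path = ctx.file_access.get_random_local_path()
        python_val.save(local_path)
        # Awaiting the upload is the substantive change this patch applies everywhere.
        remote_path = await ctx.file_access.async_put_raw_data(local_path)
        meta = BlobMetadata(
            type=_core_types.BlobType(
                format=self.MY_FORMAT,
                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
            )
        )
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

    async def async_to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[MyType]
    ) -> MyType:
        local_path = ctx.file_access.get_random_local_path()
        await ctx.file_access.async_get_data(lv.scalar.blob.uri, local_path, is_multipart=False)
        return MyType.load(local_path)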
@@ -444,7 +444,7 @@ def assert_type(self, t: typing.Type[FlyteDirectory], v: typing.Union[FlyteDirec def get_literal_type(self, t: typing.Type[FlyteDirectory]) -> LiteralType: return _type_models.LiteralType(blob=self._blob_type(format=FlyteDirToMultipartBlobTransformer.get_format(t))) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: FlyteDirectory, @@ -499,7 +499,9 @@ def to_literal( remote_directory = ctx.file_access.get_random_remote_directory() if not pathlib.Path(source_path).is_dir(): raise FlyteAssertion("Expected a directory. {} is not a directory".format(source_path)) - ctx.file_access.put_data(source_path, remote_directory, is_multipart=True, batch_size=batch_size) + await ctx.file_access.async_put_data( + source_path, remote_directory, is_multipart=True, batch_size=batch_size + ) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_directory))) # If not uploading, then we can only take the original source path as the uri. @@ -535,7 +537,7 @@ def from_binary_idl( else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") - def to_python_value( + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory] ) -> FlyteDirectory: if lv.scalar.binary: diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 602f5bc12e..bf08a1f535 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -309,6 +309,9 @@ def remote_source(self) -> str: def download(self) -> str: return self.__fspath__() + async def _download(self) -> str: + return self.__fspath__() + @contextmanager def open( self, @@ -511,9 +514,11 @@ async def async_to_literal( if should_upload: headers = self.get_additional_headers(source_path) if remote_path is not None: - remote_path = ctx.file_access.put_data(source_path, remote_path, is_multipart=False, **headers) + remote_path = await ctx.file_access.async_put_data( + source_path, remote_path, is_multipart=False, **headers + ) else: - remote_path = ctx.file_access.put_raw_data(source_path, **headers) + remote_path = await ctx.file_access.async_put_raw_data(source_path, **headers) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=unquote(str(remote_path))))) # If not uploading, then we can only take the original source path as the uri. else: diff --git a/flytekit/types/iterator/json_iterator.py b/flytekit/types/iterator/json_iterator.py index 782b6de229..7db9361bac 100644 --- a/flytekit/types/iterator/json_iterator.py +++ b/flytekit/types/iterator/json_iterator.py @@ -6,8 +6,8 @@ from flytekit import FlyteContext, Literal, LiteralType from flytekit.core.type_engine import ( + AsyncTypeTransformer, TypeEngine, - TypeTransformer, TypeTransformerFailedError, ) from flytekit.models.core import types as _core_types @@ -34,7 +34,7 @@ def __next__(self): raise StopIteration("File handler is exhausted") -class JSONIteratorTransformer(TypeTransformer[Iterator[JSON]]): +class JSONIteratorTransformer(AsyncTypeTransformer[Iterator[JSON]]): """ A JSON iterator that handles conversion between an iterator/generator and a JSONL file. 
""" @@ -54,7 +54,7 @@ def get_literal_type(self, t: Type[Iterator[JSON]]) -> LiteralType: metadata={"format": self.JSON_ITERATOR_METADATA}, ) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: Iterator[JSON], @@ -83,9 +83,10 @@ def to_literal( ) ) - return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=ctx.file_access.put_raw_data(uri)))) + uri = await ctx.file_access.async_put_raw_data(uri) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=uri))) - def to_python_value( + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[Iterator[JSON]] ) -> JSONIterator: try: diff --git a/flytekit/types/numpy/ndarray.py b/flytekit/types/numpy/ndarray.py index 1ca25bde11..91c9dc3019 100644 --- a/flytekit/types/numpy/ndarray.py +++ b/flytekit/types/numpy/ndarray.py @@ -9,8 +9,8 @@ from flytekit.core.context_manager import FlyteContext from flytekit.core.hash import HashMethod from flytekit.core.type_engine import ( + AsyncTypeTransformer, TypeEngine, - TypeTransformer, TypeTransformerFailedError, ) from flytekit.models.core import types as _core_types @@ -41,7 +41,7 @@ def extract_metadata(t: Type[np.ndarray]) -> Tuple[Type[np.ndarray], Dict[str, b return t, metadata -class NumpyArrayTransformer(TypeTransformer[np.ndarray]): +class NumpyArrayTransformer(AsyncTypeTransformer[np.ndarray]): """ TypeTransformer that supports np.ndarray as a native type. """ @@ -59,7 +59,7 @@ def get_literal_type(self, t: Type[np.ndarray]) -> LiteralType: ) ) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: np.ndarray, @@ -84,10 +84,12 @@ def to_literal( arr=python_val, allow_pickle=metadata.get("allow_pickle", False), ) - remote_path = ctx.file_access.put_raw_data(local_path) + remote_path = await ctx.file_access.async_put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[np.ndarray]) -> np.ndarray: + async def async_to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[np.ndarray] + ) -> np.ndarray: try: uri = lv.scalar.blob.uri except AttributeError: @@ -96,7 +98,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: expected_python_type, metadata = extract_metadata(expected_python_type) local_path = ctx.file_access.get_random_local_path() - ctx.file_access.get_data(uri, local_path, is_multipart=False) + await ctx.file_access.async_get_data(uri, local_path, is_multipart=False) # load numpy array from a file return np.load( diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index d26ede7b1b..7b4c99cae6 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -5,7 +5,7 @@ import cloudpickle from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer +from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine from flytekit.models.core import types as _core_types from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType @@ -52,7 +52,7 @@ def python_type(cls) -> typing.Type: return _SpecificFormatClass @classmethod - def to_pickle(cls, ctx: FlyteContext, python_val: typing.Any) -> str: + async def to_pickle(cls, ctx: FlyteContext, python_val: typing.Any) -> str: local_dir = 
ctx.file_access.get_random_local_directory() os.makedirs(local_dir, exist_ok=True) local_path = ctx.file_access.get_random_local_path() @@ -60,23 +60,23 @@ def to_pickle(cls, ctx: FlyteContext, python_val: typing.Any) -> str: with open(uri, "w+b") as outfile: cloudpickle.dump(python_val, outfile) - return ctx.file_access.put_raw_data(uri) + return await ctx.file_access.async_put_raw_data(uri) @classmethod - def from_pickle(cls, uri: str) -> typing.Any: + async def from_pickle(cls, uri: str) -> typing.Any: ctx = FlyteContextManager.current_context() # Deserialize the pickle, and return data in the pickle, # and download pickle file to local first if file is not in the local file systems. if ctx.file_access.is_remote(uri): local_path = ctx.file_access.get_random_local_path() - ctx.file_access.get_data(uri, local_path, False) + await ctx.file_access.async_get_data(uri, local_path, False) uri = local_path with open(uri, "rb") as infile: data = cloudpickle.load(infile) return data -class FlytePickleTransformer(TypeTransformer[FlytePickle]): +class FlytePickleTransformer(AsyncTypeTransformer[FlytePickle]): PYTHON_PICKLE_FORMAT = "PythonPickle" def __init__(self): @@ -86,11 +86,13 @@ def assert_type(self, t: Type[T], v: T): # Every type can serialize to pickle, so we don't need to check the type here. ... - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: uri = lv.scalar.blob.uri - return FlytePickle.from_pickle(uri) + return await FlytePickle.from_pickle(uri) - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + async def async_to_literal( + self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType + ) -> Literal: if python_val is None: raise AssertionError("Cannot pickle None Value.") meta = BlobMetadata( @@ -98,7 +100,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE ) ) - remote_path = FlytePickle.to_pickle(ctx, python_val) + remote_path = await FlytePickle.to_pickle(ctx, python_val) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlytePickle[typing.Any]]: diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 5cf8308b03..28a2c542ef 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -17,7 +17,7 @@ from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError from flytekit.loggers import logger from flytekit.models.literals import Binary, Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType @@ -373,7 +373,7 @@ def _get_numpy_type_mappings() -> typing.Dict[Type, SchemaType.SchemaColumn.Sche return {} -class FlyteSchemaTransformer(TypeTransformer[FlyteSchema]): +class FlyteSchemaTransformer(AsyncTypeTransformer[FlyteSchema]): _SUPPORTED_TYPES: typing.Dict[Type, SchemaType.SchemaColumn.SchemaColumnType] = { float: SchemaType.SchemaColumn.SchemaColumnType.FLOAT, int: 
SchemaType.SchemaColumn.SchemaColumnType.INTEGER, @@ -406,7 +406,7 @@ def assert_type(self, t: Type[FlyteSchema], v: typing.Any): def get_literal_type(self, t: Type[FlyteSchema]) -> LiteralType: return LiteralType(schema=self._get_schema_type(t)) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: FlyteSchema, python_type: Type[FlyteSchema], expected: LiteralType ) -> Literal: if isinstance(python_val, FlyteSchema): @@ -421,7 +421,9 @@ def to_literal( # This means the local path is empty. Don't try to overwrite the remote data logger.debug(f"Skipping upload for {python_val} because it was never downloaded.") else: - remote_path = ctx.file_access.put_data(python_val.local_path, remote_path, is_multipart=True) + remote_path = await ctx.file_access.async_put_data( + python_val.local_path, remote_path, is_multipart=True + ) return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type(python_type)))) remote_path = ctx.file_access.join(ctx.file_access.raw_output_prefix, ctx.file_access.get_random_string()) @@ -438,7 +440,9 @@ def to_literal( writer = schema.open(type(python_val)) writer.write(python_val) if not h.handles_remote_io: - schema.remote_path = ctx.file_access.put_data(schema.local_path, schema.remote_path, is_multipart=True) + schema.remote_path = await ctx.file_access.async_put_data( + schema.local_path, schema.remote_path, is_multipart=True + ) return Literal(scalar=Scalar(schema=Schema(schema.remote_path, self._get_schema_type(python_type)))) def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[FlyteSchema]) -> FlyteSchema: @@ -458,7 +462,9 @@ def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[ else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[FlyteSchema]) -> FlyteSchema: + async def async_to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[FlyteSchema] + ) -> FlyteSchema: # Handle dataclass attribute access if lv.scalar and lv.scalar.binary: return self.from_binary_idl(lv.scalar.binary, expected_python_type) diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index a7ade2fe46..bff3572cfe 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -5,7 +5,7 @@ import pandas from flytekit import FlyteContext -from flytekit.core.type_engine import T, TypeEngine, TypeTransformer +from flytekit.core.type_engine import AsyncTypeTransformer, T, TypeEngine from flytekit.models.literals import Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType from flytekit.types.schema import LocalIOSchemaReader, LocalIOSchemaWriter, SchemaEngine, SchemaFormat, SchemaHandler @@ -75,7 +75,7 @@ def _write(self, df: T, path: os.PathLike, **kwargs): return self._parquet_engine.write(df, to_file=path, **kwargs) -class PandasDataFrameTransformer(TypeTransformer[pandas.DataFrame]): +class PandasDataFrameTransformer(AsyncTypeTransformer[pandas.DataFrame]): """ Transforms a pd.DataFrame to Schema without column types. 
""" @@ -91,7 +91,7 @@ def _get_schema_type() -> SchemaType: def get_literal_type(self, t: Type[pandas.DataFrame]) -> LiteralType: return LiteralType(schema=self._get_schema_type()) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: pandas.DataFrame, @@ -105,16 +105,16 @@ def to_literal( ctx.file_access.raw_output_prefix, ctx.file_access.get_random_string(), ) - remote_path = ctx.file_access.put_data(local_dir, remote_path, is_multipart=True) + remote_path = await ctx.file_access.async_put_data(local_dir, remote_path, is_multipart=True) return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type()))) - def to_python_value( + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[pandas.DataFrame] ) -> pandas.DataFrame: if not (lv and lv.scalar and lv.scalar.schema): return pandas.DataFrame() local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(lv.scalar.schema.uri, local_dir, is_multipart=True) + await ctx.file_access.async_get_data(lv.scalar.schema.uri, local_dir, is_multipart=True) r = PandasSchemaReader(local_dir=local_dir, cols=None, fmt=SchemaFormat.PARQUET) return r.all() diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 12a1b1ca28..57c028e71c 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -19,7 +19,7 @@ from flytekit import lazy_module from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError from flytekit.deck.renderer import Renderable from flytekit.loggers import developer_logger, logger from flytekit.models import literals @@ -399,7 +399,7 @@ def get_supported_types(): class DuplicateHandlerError(ValueError): ... -class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): +class StructuredDatasetTransformerEngine(AsyncTypeTransformer[StructuredDataset]): """ Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. If you are bringing a custom data frame type, or any data frame type, to flytekit, instead of @@ -594,7 +594,7 @@ def register_for_protocol( def assert_type(self, t: Type[StructuredDataset], v: typing.Any): return - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: Union[StructuredDataset, typing.Any], @@ -654,7 +654,7 @@ def to_literal( if not uri: raise ValueError(f"If dataframe is not specified, then the uri should be specified. 
{python_val}") if not ctx.file_access.is_remote(uri): - uri = ctx.file_access.put_raw_data(uri) + uri = await ctx.file_access.async_put_raw_data(uri) sd_model = literals.StructuredDataset( uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type=sdt), @@ -752,7 +752,7 @@ def from_binary_idl( else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") - def to_python_value( + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset ) -> T | StructuredDataset: """ diff --git a/flytekit/utils/asyn.py b/flytekit/utils/asyn.py index c447db052f..d1edb67436 100644 --- a/flytekit/utils/asyn.py +++ b/flytekit/utils/asyn.py @@ -82,7 +82,7 @@ def run_sync(self, coro_func: Callable[..., Awaitable[T]], *args, **kwargs) -> T """ This should be called from synchronous functions to run an async function. """ - name = threading.current_thread().name + name = threading.current_thread().name + f"PID:{os.getpid()}" coro = coro_func(*args, **kwargs) if name not in self._runner_map: if len(self._runner_map) > 500: diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index 2de6e8c196..42e74f453c 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -4,6 +4,7 @@ import tempfile from uuid import UUID import typing +import asyncio import fsspec import mock import pytest @@ -17,6 +18,7 @@ from flytekit.types.file import FlyteFile from flytekit.utils.asyn import loop_manager from flytekit.models.literals import Literal +from flytekit.utils.asyn import run_sync local = fsspec.filesystem("file") root = os.path.abspath(os.sep) @@ -111,7 +113,8 @@ def test_local_raw_fsspec(source_folder): assert len(files) == 2 -def test_local_provider(source_folder): +@pytest.mark.asyncio +async def test_local_provider(source_folder): # Test that behavior putting from a local dir to a local remote dir is the same whether or not the local # dest folder exists. 
dc = Config.for_sandbox().data_config @@ -119,19 +122,20 @@ def test_local_provider(source_folder): provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=dest_tmpdir, data_config=dc) r = provider.get_random_string() doesnotexist = provider.join(provider.raw_output_prefix, r) - provider.put_data(source_folder, doesnotexist, is_multipart=True) + await provider.async_put_data(source_folder, doesnotexist, is_multipart=True) files = provider.raw_output_fs.find(doesnotexist) assert len(files) == 2 r = provider.get_random_string() exists = provider.join(provider.raw_output_prefix, r) provider.raw_output_fs.mkdir(exists) - provider.put_data(source_folder, exists, is_multipart=True) + await provider.async_put_data(source_folder, exists, is_multipart=True) files = provider.raw_output_fs.find(exists) assert len(files) == 2 -def test_async_file_system(): +@pytest.mark.asyncio +async def test_async_file_system(): remote_path = "test:///tmp/test.py" local_path = "test.py" @@ -161,9 +165,9 @@ async def _lsdir( fsspec.register_implementation("test", MockAsyncFileSystem) ctx = FlyteContextManager.current_context() - dst = ctx.file_access.put(local_path, remote_path) + dst = await ctx.file_access._put(local_path, remote_path) assert dst == remote_path - dst = ctx.file_access.get(remote_path, local_path) + dst = await ctx.file_access.get(remote_path, local_path) assert dst == local_path @@ -492,3 +496,75 @@ def test_async_local_copy_to_s3(): print(f"Time taken: {end_time - start_time}") print(f"Wall time taken: {end_wall_time - start_wall_time}") print(f"Process time taken: {end_process_time - start_process_time}") + + +async def download_files(ffs: typing.List[FlyteFile]): + futures = [asyncio.create_task(ff._download()) for ff in ffs] + return await asyncio.gather(*futures, return_exceptions=True) + + +@pytest.mark.sandbox_test +def test_async_download_from_s3(): + import time + import datetime + + f1 = "/Users/ytong/go/src/github.com/unionai/debugyt/user/ytong/src/yt_dbg/fr/rand.file" + f2 = "/Users/ytong/go/src/github.com/unionai/debugyt/user/ytong/src/yt_dbg/fr/rand2.file" + f3 = "/Users/ytong/go/src/github.com/unionai/debugyt/user/ytong/src/yt_dbg/fr/rand3.file" + + ff1 = FlyteFile(path=f1) + ff2 = FlyteFile(path=f2) + ff3 = FlyteFile(path=f3) + ff = [ff1, ff2, ff3] + + ctx = FlyteContextManager.current_context() + dc = Config.for_sandbox().data_config + random_folder = UUID(int=random.getrandbits(64)).hex + raw_output = f"s3://my-s3-bucket/testing/upload_test/{random_folder}" + print(f"Uploading to {raw_output}") + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + + with FlyteContextManager.with_context(ctx.with_file_access(provider)) as ctx: + lit = TypeEngine.to_literal(ctx, ff, typing.List[FlyteFile], TypeEngine.to_literal_type(typing.List[FlyteFile])) + print(f"Literal is {lit}") + python_list = TypeEngine.to_python_value(ctx, lit, typing.List[FlyteFile]) + + print(f"Serial File list: {python_list}") + + start_time = datetime.datetime.now(datetime.timezone.utc) + start_wall_time = time.perf_counter() + start_process_time = time.process_time() + + for local_file in python_list: + print(f"Downloading {local_file.remote_source} to {local_file.path}") + local_file.download() + + end_time = datetime.datetime.now(datetime.timezone.utc) + end_wall_time = time.perf_counter() + end_process_time = time.process_time() + + print(f"Time taken (serial download): {end_time - start_time}") + print(f"Wall time taken (serial 
download): {end_wall_time - start_wall_time}") + print(f"Process time taken (serial download): {end_process_time - start_process_time}") + + with FlyteContextManager.with_context(ctx.with_file_access(provider)) as ctx: + lit = TypeEngine.to_literal(ctx, ff, typing.List[FlyteFile], TypeEngine.to_literal_type(typing.List[FlyteFile])) + print(f"Literal is {lit}") + python_list = TypeEngine.to_python_value(ctx, lit, typing.List[FlyteFile]) + + print(f"Async file list: {python_list}") + + start_time = datetime.datetime.now(datetime.timezone.utc) + start_wall_time = time.perf_counter() + start_process_time = time.process_time() + + res = run_sync(download_files, python_list) + print(f"Result is: {res}") + + end_time = datetime.datetime.now(datetime.timezone.utc) + end_wall_time = time.perf_counter() + end_process_time = time.process_time() + + print(f"Time taken (async): {end_time - start_time}") + print(f"Wall time taken (async): {end_wall_time - start_wall_time}") + print(f"Process time taken (async): {end_process_time - start_process_time}") diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index 5063e484d2..d992ed1fa5 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,5 +1,6 @@ import io import os +import fsspec import pathlib import random import string @@ -11,6 +12,7 @@ from azure.identity import ClientSecretCredential, DefaultAzureCredential from flytekit.core.data_persistence import FileAccessProvider +from flytekit.core.local_fsspec import FlyteLocalFileSystem def test_get_manual_random_remote_path(): @@ -92,7 +94,8 @@ def test_write_folder_put_raw(mock_uuid_class): assert sorted(paths) == sorted(expected) -def test_write_large_put_raw(): +@pytest.mark.asyncio +async def test_write_large_put_raw(): """ Test that writes a large'ish file setting block size and read size. 
""" @@ -107,7 +110,7 @@ def test_write_large_put_raw(): sio.seek(0) # Write foo/a.txt by specifying the upload prefix and a file name - fs.put_raw_data(sio, upload_prefix="foo", file_name="a.txt", block_size=5, read_chunk_size_bytes=1) + await fs.async_put_raw_data(sio, upload_prefix="foo", file_name="a.txt", block_size=5, read_chunk_size_bytes=1) output_file = os.path.join(raw, "foo", "a.txt") with open(output_file, "rb") as f: assert f.read() == arbitrary_text.encode("utf-8") @@ -189,3 +192,18 @@ def test_initialise_azure_file_provider_with_default_credential(): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") assert fp.get_filesystem().account_name == "accountname" assert isinstance(fp.get_filesystem().sync_credential, DefaultAzureCredential) + + +def test_get_file_system(): + # Test that custom args are not swallowed by get_filesystem + + class MockFileSystem(FlyteLocalFileSystem): + def __init__(self, *args, **kwargs): + assert "test_arg" in kwargs + del kwargs["test_arg"] + super().__init__(*args, **kwargs) + + fsspec.register_implementation("testgetfs", MockFileSystem) + + fp = FileAccessProvider("/tmp", "s3://my-bucket") + fp.get_filesystem("testgetfs", test_arg="test_arg") diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 7e09e918ae..352984ca37 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -650,7 +650,7 @@ def write_this_file_to_s3() -> FlyteFile: ctx = FlyteContextManager.current_context() r = ctx.file_access.get_random_string() dest = ctx.file_access.join(ctx.file_access.raw_output_prefix, r) - ctx.file_access.put(__file__, dest) + ctx.file_access._put(__file__, dest) return FlyteFile(path=dest) @task diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 2f8e5a8a3e..8721a8d4db 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -927,7 +927,7 @@ class TestStructD(DataClassJsonMixin): assert ot == o -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.async_put_data") def test_dataclass_with_postponed_annotation(mock_put_data): remote_path = "s3://tmp/file" mock_put_data.return_value = remote_path @@ -953,7 +953,7 @@ class Data: assert dict_obj["f"]["path"] == remote_path -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.async_put_data") def test_optional_flytefile_in_dataclass(mock_upload_dir): mock_upload_dir.return_value = True @@ -1040,7 +1040,7 @@ class TestFileStruct(DataClassJsonMixin): assert o.i_prime == A(a=99) -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.async_put_data") def test_optional_flytefile_in_dataclassjsonmixin(mock_upload_dir): @dataclass class A_optional_flytefile(DataClassJSONMixin): diff --git a/tests/flytekit/unit/remote/test_fs_remote.py b/tests/flytekit/unit/remote/test_fs_remote.py index efc6e94e8b..5c635376b4 100644 --- a/tests/flytekit/unit/remote/test_fs_remote.py +++ b/tests/flytekit/unit/remote/test_fs_remote.py @@ -109,7 +109,7 @@ def test_remote_upload_with_data_persistence(sandbox_remote): f.write("asdf") f.flush() # Test uploading a file and folder. 
-        res = fp.put(f.name, "flyte://data/", recursive=True)
+        res = fp._put(f.name, "flyte://data/", recursive=True)
         # Unlike using the RemoteFS directly, the trailing slash is automatically added by data persistence,
         # not sure why but preserving the behavior for now.
         only_file = pathlib.Path(f.name).name
diff --git a/tests/flytekit/unit/utils/test_asyn.py b/tests/flytekit/unit/utils/test_asyn.py
index db74ac6f53..b8ce75b2a7 100644
--- a/tests/flytekit/unit/utils/test_asyn.py
+++ b/tests/flytekit/unit/utils/test_asyn.py
@@ -1,3 +1,4 @@
+import os
 import threading
 import pytest
 import asyncio
@@ -116,3 +117,8 @@ def test_recursive_calling():
         main_ctx.vals["depth"] = 0
     assert res == "world"
     sync_function(6, 6)
+
+    # Check to make sure that the names of the runners have the PID in them. This makes the loop manager work with
+    # frameworks like PyTorch elastic.
+    for k in loop_manager._runner_map.keys():
+        assert str(os.getpid()) in k
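# The new assertion above pins down the PID-qualified runner names introduced
# in flytekit/utils/asyn.py. Together with the `loop_manager.synced(...)`
# assignments in data_persistence.py, it completes the pattern this patch
# establishes: write the async implementation once, then derive the public
# blocking name from it. A minimal sketch of consuming that wrapper
# (`Example` / `do_work` are illustrative names, not part of this patch):
from flytekit.utils.asyn import loop_manager


class Example:
    async def async_do_work(self, x: int) -> int:
        # Imagine awaited fsspec I/O here, as in async_put_data.
        return x * 2

    # Public synchronous version, mirroring `put_data = loop_manager.synced(async_put_data)`.
    do_work = loop_manager.synced(async_do_work)


e = Example()
# Blocks the caller while the coroutine runs on the manager's background loop;
# the runner is keyed by thread name plus PID, which is what the assertion
# above verifies.
assert e.do_work(3) == 6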