From 8021cb5e5776388135fe15ead19cfd989d5611bd Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 28 Oct 2024 09:17:29 -0700 Subject: [PATCH] Async/Batching of coroutines (#2855) Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 30bc43a106..72a59aa82e 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -22,6 +22,7 @@ import msgpack from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import literals_pb2 +from fsspec.asyn import _run_coros_in_chunks # pylint: disable=W0212 from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct from google.protobuf.json_format import MessageToDict as _MessageToDict @@ -1539,8 +1540,10 @@ async def async_to_literal( raise TypeTransformerFailedError("Expected a list") t = self.get_sub_type(python_type) - lit_list = [TypeEngine.async_to_literal(ctx, x, t, expected.collection_type) for x in python_val] - lit_list = await asyncio.gather(*lit_list) + lit_list = [ + asyncio.create_task(TypeEngine.async_to_literal(ctx, x, t, expected.collection_type)) for x in python_val + ] + lit_list = await _run_coros_in_chunks(lit_list) return Literal(collection=LiteralCollection(literals=lit_list)) @@ -1562,7 +1565,7 @@ async def async_to_python_value( # type: ignore st = self.get_sub_type(expected_python_type) result = [TypeEngine.async_to_python_value(ctx, x, st) for x in lits] - result = await asyncio.gather(*result) + result = await _run_coros_in_chunks(result) return result # type: ignore # should be a list, thinks its a tuple def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore @@ -1968,7 +1971,7 @@ async def async_to_literal( TypeEngine.async_to_literal(ctx, v, cast(type, v_type), expected.map_value_type) ) - await asyncio.gather(*lit_map.values()) + await _run_coros_in_chunks([c for c in lit_map.values()]) for k, v in lit_map.items(): lit_map[k] = v.result() @@ -1994,7 +1997,7 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p fut = asyncio.create_task(TypeEngine.async_to_python_value(ctx, v, cast(Type, tp[1]))) py_map[k] = fut - await asyncio.gather(*py_map.values()) + await _run_coros_in_chunks([c for c in py_map.values()]) for k, v in py_map.items(): py_map[k] = v.result()