Skip to content

Commit

Permalink
Async/Batching of coroutines (#2855)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Oct 28, 2024
1 parent 6bf6f8e commit 8021cb5
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import msgpack
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from fsspec.asyn import _run_coros_in_chunks # pylint: disable=W0212
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
from google.protobuf.json_format import MessageToDict as _MessageToDict
Expand Down Expand Up @@ -1539,8 +1540,10 @@ async def async_to_literal(
raise TypeTransformerFailedError("Expected a list")

t = self.get_sub_type(python_type)
lit_list = [TypeEngine.async_to_literal(ctx, x, t, expected.collection_type) for x in python_val]
lit_list = await asyncio.gather(*lit_list)
lit_list = [
asyncio.create_task(TypeEngine.async_to_literal(ctx, x, t, expected.collection_type)) for x in python_val
]
lit_list = await _run_coros_in_chunks(lit_list)

return Literal(collection=LiteralCollection(literals=lit_list))

Expand All @@ -1562,7 +1565,7 @@ async def async_to_python_value( # type: ignore

st = self.get_sub_type(expected_python_type)
result = [TypeEngine.async_to_python_value(ctx, x, st) for x in lits]
result = await asyncio.gather(*result)
result = await _run_coros_in_chunks(result)
return result # type: ignore # should be a list, thinks its a tuple

def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore
Expand Down Expand Up @@ -1968,7 +1971,7 @@ async def async_to_literal(
TypeEngine.async_to_literal(ctx, v, cast(type, v_type), expected.map_value_type)
)

await asyncio.gather(*lit_map.values())
await _run_coros_in_chunks([c for c in lit_map.values()])
for k, v in lit_map.items():
lit_map[k] = v.result()

Expand All @@ -1994,7 +1997,7 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p
fut = asyncio.create_task(TypeEngine.async_to_python_value(ctx, v, cast(Type, tp[1])))
py_map[k] = fut

await asyncio.gather(*py_map.values())
await _run_coros_in_chunks([c for c in py_map.values()])
for k, v in py_map.items():
py_map[k] = v.result()

Expand Down

0 comments on commit 8021cb5

Please sign in to comment.