-
Notifications
You must be signed in to change notification settings - Fork 11
/
distribute.py
63 lines (57 loc) · 1.99 KB
/
distribute.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import ray
from tqdm import tqdm
class Pool:
    """Distribute work items round-robin over a fixed set of Ray actors.

    For each item, ``exec_fn(actor, item)`` is called to launch the work. Its
    return values are passed to ``ray.get`` / ``ray.wait``, so ``exec_fn``
    must return something those accept — presumably an object ref from
    ``actor.<method>.remote(item)`` (confirm at call sites).
    """

    def __init__(self, actors):
        """
        actors: non-empty sized iterable (e.g. list) of ray actor handles.

        Raises:
            ValueError: if ``actors`` is empty.
        """
        self.actors = actors
        # Raise explicitly rather than ``assert`` so the check survives
        # ``python -O``.
        if len(self.actors) == 0:
            raise ValueError("Pool requires at least one actor")

    def _submit(self, exec_fn, iterable):
        """Launch ``exec_fn(actor, item)`` for each item, cycling over actors.

        Returns the list of values returned by ``exec_fn``, in item order.
        """
        from itertools import cycle

        # Iterate ``iterable`` directly instead of ``next(it, None)`` so that
        # ``None`` items are dispatched like any other value rather than
        # silently ending submission. ``iterable`` is zipped first so no extra
        # actor is pulled from the cycle once the items run out.
        return [exec_fn(actor, item)
                for item, actor in zip(iterable, cycle(self.actors))]

    def map(self, exec_fn, iterable):
        """Run ``exec_fn`` over ``iterable`` and return results in input order.

        Blocks until every launched task has finished.

        exec_fn: callable ``(actor, item)`` that launches the work for ``item``
            on ``actor`` and returns a ref accepted by ``ray.get``.
        iterable: items to process (may contain ``None``).
        Returns: list of results, one per item, in the order of ``iterable``.
        """
        return ray.get(self._submit(exec_fn, iterable))

    def map_unordered(self, exec_fn, iterable,
                      callback_fn=None, desc=None, use_tqdm: bool = True):
        """Run ``exec_fn`` over ``iterable``, collecting results as they finish.

        exec_fn: callable ``(actor, item)`` that launches the work for ``item``
            on ``actor`` and returns a ref accepted by ``ray.wait``/``ray.get``.
        iterable: items to process (may contain ``None``).
        callback_fn: optional callable invoked with each result as soon as it
            has been fetched.
        desc: label shown on the progress bar.
        use_tqdm: show a tqdm progress bar when True.
        Returns: list of results in completion order (not input order).
        """
        pending_tasks = self._submit(exec_fn, iterable)
        results = []
        # ``disable`` makes the bar a no-op when progress output is off; the
        # ``with`` block closes the bar even if a task or callback raises.
        with tqdm(total=len(pending_tasks), desc=desc, dynamic_ncols=True,
                  smoothing=0.01, disable=not use_tqdm) as pbar:
            while pending_tasks:
                # ray.wait splits the refs into (ready, still_pending).
                finished_tasks, pending_tasks = ray.wait(pending_tasks)
                for finished_task in finished_tasks:
                    pbar.update()
                    result = ray.get(finished_task)
                    results.append(result)
                    if callback_fn is not None:
                        callback_fn(result)
        return results