This repository has been archived by the owner on Jan 23, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
generator.py
66 lines (51 loc) · 1.75 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import print_function
import numpy as np
import threading
try:
import Queue as queue
except:
import queue
import sys
class Generator(object):
    """Iterate over batches of ``dataset`` outputs, prefetched by threads.

    Each batch is produced by a background thread that calls
    ``dataset.get_outputs(batch_ids)``, which must return a two-item
    ``(images_batch, labels_batch)`` pair.  ``buffer_size`` batch requests
    are issued up front (the constructor waits for all of them to finish),
    and every call to ``next()`` consumes one finished batch and issues one
    new request, so the buffer stays full.

    Batches are requested in order over ``ids`` (the internal copy is
    reshuffled at the start of every pass when ``shuffle`` is true),
    wrapping around forever; the last batch of a pass may be shorter than
    ``batch_size``.  Because workers run concurrently, batches are returned
    in the order their threads finish, which is not necessarily the order
    in which they were requested.

    Args:
        dataset: object exposing ``get_outputs(batch_ids)``.
        ids: sample identifiers; stored as a copy via ``np.array``.
        batch_size: maximum number of ids per batch.
        shuffle: reshuffle the stored ids at the start of each pass.
        buffer_size: number of batch requests kept outstanding.
        verbose: if > 0, report progress while the initial buffer fills.
    """

    def __init__(self, dataset, ids, batch_size=16, shuffle=False,
                 buffer_size=32, verbose=0):
        self.dataset = dataset
        self.ids = np.array(ids)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.verbose = verbose
        self.n_samples = self.ids.size
        self.n_batches = int(np.ceil(float(self.n_samples) / self.batch_size))
        self.buffer_size = buffer_size
        # Start offset of the next batch in self.ids.  Only the consumer's
        # thread (this constructor and next()) reads or writes it.
        self._i = 0
        # Finished batches (or worker errors), in completion order.
        self._buffer = queue.Queue()
        procs = [self._buffer_next() for _ in range(self.buffer_size)]
        if self.verbose > 0:
            sys.stdout.write("Filling generator buffer. ")
            sys.stdout.flush()
        for proc in procs:
            proc.join()
        if self.verbose > 0:
            print("Done")

    def _buffer_next(self):
        """Start a worker thread for the next batch of ids and return it."""
        if self.shuffle and self._i == 0:
            np.random.shuffle(self.ids)
        # Copy: a plain slice is a view into self.ids, which the in-place
        # shuffle above rewrites at the start of the next pass -- possibly
        # while an earlier worker is still reading its batch.
        batch_ids = self.ids[self._i:self._i + self.batch_size].copy()
        self._i += self.batch_size
        if self._i >= self.n_samples:
            self._i = 0
        proc = threading.Thread(target=self._buffer_next_worker,
                                args=(batch_ids,))
        proc.start()
        return proc

    def _buffer_next_worker(self, batch_ids):
        """Load one batch and enqueue ``(ok, payload)`` for the consumer.

        On success the payload is ``(images_batch, labels_batch)``; on
        failure it is the raised exception, so that ``next()`` can re-raise
        it instead of blocking forever on a result that never arrives.
        """
        try:
            images_batch, labels_batch = self.dataset.get_outputs(batch_ids)
        except Exception as exc:
            self._buffer.put((False, exc))
        else:
            self._buffer.put((True, (images_batch, labels_batch)))

    def __iter__(self):
        return self

    def next(self):
        """Return the next finished ``(images_batch, labels_batch)`` pair.

        Blocks until a worker has finished, then schedules one replacement
        request.  An exception raised by ``dataset.get_outputs`` in a worker
        is re-raised here.
        """
        ok, payload = self._buffer.get(block=True)
        # Refill first so the buffer stays at strength even if this batch
        # failed and the caller chooses to keep iterating.
        self._buffer_next()
        if not ok:
            raise payload
        images_batch, labels_batch = payload
        return images_batch, labels_batch

    # Python 3 iterator protocol; ``next`` is kept for Python 2 callers.
    __next__ = next