diff --git a/Pilot2/P2B1/p2b1.py b/Pilot2/P2B1/p2b1.py index ad11088b..4b1bb177 100644 --- a/Pilot2/P2B1/p2b1.py +++ b/Pilot2/P2B1/p2b1.py @@ -263,8 +263,8 @@ def datagen(self, epoch=0, print_out=1, test=0): num_frames = X.shape[0] - xt_all = np.array([]) - yt_all = np.array([]) + xt_all = [] + yt_all = [] num_active_frames = random.sample(range(num_frames), int(self.sampling_density*num_frames)) @@ -298,12 +298,11 @@ def datagen(self, epoch=0, print_out=1, test=0): exit() yt = xt.copy() - if not len(xt_all): - xt_all = np.expand_dims(xt, axis=0) - yt_all = np.expand_dims(yt, axis=0) - else: - xt_all = np.append(xt_all, np.expand_dims(xt, axis=0), axis=0) - yt_all = np.append(yt_all, np.expand_dims(yt, axis=0), axis=0) + xt_all.append(np.expand_dims(xt, axis=0)) + yt_all.append(np.expand_dims(yt, axis=0)) + + xt_all = np.concatenate(xt_all) + yt_all = np.concatenate(yt_all) yield files[f_ind], xt_all, yt_all