Skip to content

Commit

Permalink
change stream generator
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Oct 8, 2024
1 parent 8c8a4ed commit 60b25cc
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
25 changes: 12 additions & 13 deletions py_neuromodulation/stream/rawdata_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .data_generator_abc import DataGeneratorABC
import numpy as np
import pandas as pd
from typing import Tuple
from typing import Generator

class RawDataGenerator(DataGeneratorABC):
"""
Expand Down Expand Up @@ -141,19 +141,18 @@ def add_target(self, feature_dict: "pd.DataFrame", data_batch: np.array) -> None

return feature_dict

def __next__(self) -> Generator[tuple[np.ndarray, np.ndarray], None, None]:
    """Generate successive ``(timestamps, data_batch)`` windows from the data.

    Each call returns a fresh generator that walks over ``self.data`` in
    windows of ``self.segment_length`` samples, advancing ``self.stride``
    samples per step. ``self.batch_counter`` tracks the step index on the
    instance, so a later generator resumes where a previous one stopped.

    Yields
    ------
    tuple[np.ndarray, np.ndarray]
        ``timestamps``: per-sample times in seconds (sample index / sfreq);
        ``data_batch``: slice of ``self.data`` with shape
        (n_channels, segment_length).

    Notes
    -----
    NOTE(review): naming a generator function ``__next__`` is unconventional;
    callers must invoke it as ``next(instance)`` (as stream.py does) instead
    of iterating the instance directly with ``for ... in instance``. Renaming
    it to ``__iter__`` would restore standard iteration — confirm with the
    call sites before changing.
    """
    while True:
        start = self.stride * self.batch_counter
        end = start + self.segment_length

        self.batch_counter += 1

        start_idx = int(start)
        end_idx = int(end)

        # Stop once the window would run past the end of the recording.
        if end_idx > self.data.shape[1]:
            break

        # Timestamps are expressed in seconds relative to recording start.
        yield np.arange(start, end) / self.sfreq, self.data[:, start_idx:end_idx]
7 changes: 3 additions & 4 deletions py_neuromodulation/stream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from collections.abc import Iterator
import numpy as np

import multiprocessing as mp
from contextlib import suppress

from py_neuromodulation.features import USE_FREQ_RANGES
Expand Down Expand Up @@ -35,7 +34,6 @@ def __init__(
) -> None:
self.verbose = verbose
self.is_running = False


async def run(
self,
Expand Down Expand Up @@ -72,7 +70,7 @@ async def run(
self.is_lslstream = type(data_generator) != RawDataGenerator

prev_batch_end = 0
for timestamps, data_batch in data_generator:
for timestamps, data_batch in next(data_generator):
self.is_running = True
if self.stream_handling_queue is not None:
await asyncio.sleep(0.001)
Expand Down Expand Up @@ -102,7 +100,8 @@ async def run(
if self.verbose:
logger.info("Time: %.2f", feature_dict["time"] / 1000)

feature_dict = data_generator.add_target(feature_dict, data_batch)
if not self.is_lslstream:
feature_dict = data_generator.add_target(feature_dict, data_batch)

with suppress(TypeError): # Need this because some features output None
for key, value in feature_dict.items():
Expand Down

0 comments on commit 60b25cc

Please sign in to comment.