Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stream_refactor #362

Open
wants to merge 7 commits into
base: gui
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/plot_0_first_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def generate_random_walk(NUM_CHANNELS, TIME_DATA_SAMPLES):
# DataFrame. There are some helper functions that let you create the
# nm_channels without much effort:

nm_channels = nm.utils.get_default_channels_from_data(data, car_rereferencing=True)
nm_channels = nm.utils.create_default_channels_from_data(data, car_rereferencing=True)

nm_channels

Expand Down Expand Up @@ -135,14 +135,15 @@ def generate_random_walk(NUM_CHANNELS, TIME_DATA_SAMPLES):
# We are now ready to go to instantiate the *Stream* and call the *run* method for feature estimation:

stream = nm.Stream(
data=data,
timonmerk marked this conversation as resolved.
Show resolved Hide resolved
settings=settings,
channels=nm_channels,
verbose=True,
sfreq=sfreq,
line_noise=50,
)

features = stream.run(data, save_csv=True)
features = stream.run(save_csv=True)

# %%
# Feature Analysis
Expand Down
6 changes: 3 additions & 3 deletions examples/plot_1_example_BIDS.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
coord_names,
) = nm.io.read_BIDS_data(PATH_RUN=PATH_RUN)

channels = nm.utils.set_channels(
channels = nm.utils.create_channels(
timonmerk marked this conversation as resolved.
Show resolved Hide resolved
ch_names=raw.ch_names,
ch_types=raw.get_channel_types(),
reference="default",
Expand Down Expand Up @@ -94,6 +94,8 @@

# %%
stream = nm.Stream(
data=data,
experiment_name=RUN_NAME,
sfreq=sfreq,
channels=channels,
settings=settings,
Expand All @@ -105,9 +107,7 @@

# %%
features = stream.run(
data=data,
out_dir=PATH_OUT,
experiment_name=RUN_NAME,
save_csv=True,
)

Expand Down
19 changes: 11 additions & 8 deletions examples/plot_3_example_sharpwave_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
coord_names,
) = nm.io.read_BIDS_data(PATH_RUN=PATH_RUN)

print(data.shape)

# %%
settings = NMSettings.get_fast_compute()

Expand All @@ -69,7 +71,7 @@
for sw_feature in settings.sharpwave_analysis_settings.sharpwave_features.list_all():
settings.sharpwave_analysis_settings.estimator["mean"].append(sw_feature)

channels = nm.utils.set_channels(
channels = nm.utils.create_channels(
ch_names=raw.ch_names,
ch_types=raw.get_channel_types(),
reference="default",
Expand All @@ -79,7 +81,9 @@
target_keywords=["MOV_RIGHT"],
)

stream = nm.Stream(
data_plt = data[5, 1000:4000]

data_processor = nm.DataProcessor(
sfreq=sfreq,
channels=channels,
settings=settings,
Expand All @@ -88,14 +92,12 @@
coord_names=coord_names,
verbose=False,
)
sw_analyzer = cast(
SharpwaveAnalyzer, stream.data_processor.features.get_feature("sharpwave_analysis")
)

sw_analyzer = data_processor.features.get_feature("sharpwave_analysis")


# %%
# The plotted example time series, visualized on a short time scale, shows the relation of identified peaks, troughs, and estimated features:
data_plt = data[5, 1000:4000]

filtered_dat = fftconvolve(data_plt, sw_analyzer.list_filter[0][1], mode="same")

troughs = signal.find_peaks(-filtered_dat, distance=10)[0]
Expand Down Expand Up @@ -297,6 +299,7 @@
channels.loc[[3, 8], "used"] = 1

stream = nm.Stream(
data=data[:, :30000],
sfreq=sfreq,
channels=channels,
settings=settings,
Expand All @@ -306,7 +309,7 @@
verbose=True,
)

df_features = stream.run(data=data[:, :30000], save_csv=True)
df_features = stream.run(save_csv=True)

# %%
# We can then plot two exemplary features, prominence and interval, and see that the movement amplitude can be clustered with those two features alone:
Expand Down
6 changes: 3 additions & 3 deletions examples/plot_4_example_gridPointProjection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

settings.postprocessing.project_cortex = True

channels = nm.utils.set_channels(
channels = nm.utils.create_channels(
ch_names=raw.ch_names,
ch_types=raw.get_channel_types(),
reference="default",
Expand All @@ -65,6 +65,8 @@
)

stream = nm.Stream(
data=data[:, : int(sfreq * 5)],
experiment_name=RUN_NAME,
sfreq=sfreq,
channels=channels,
settings=settings,
Expand All @@ -75,9 +77,7 @@
)

features = stream.run(
data=data[:, : int(sfreq * 5)],
out_dir=PATH_OUT,
experiment_name=RUN_NAME,
save_csv=True,
)

Expand Down
1 change: 1 addition & 0 deletions examples/plot_6_real_time_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def get_fast_compute_settings():
print("Computation time for single ECoG channel: ")
data = np.random.random([1, 1000])
stream = nm.Stream(sfreq=1000, data=data, sampling_rate_features_hz=10, verbose=False)

print(
f"{np.round(timeit.timeit(lambda: stream.data_processor.process(data), number=10)/10, 3)} s"
)
Expand Down
17 changes: 9 additions & 8 deletions examples/plot_7_lsl_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# %%
from matplotlib import pyplot as plt
import py_neuromodulation as nm
import time

# %%
# Let’s get the example data from the provided BIDS dataset and create the channels DataFrame.
Expand All @@ -32,7 +33,7 @@
coord_names,
) = nm.io.read_BIDS_data(PATH_RUN=PATH_RUN)

channels = nm.utils.set_channels(
channels = nm.utils.create_channels(
ch_names=raw.ch_names,
ch_types=raw.get_channel_types(),
reference="default",
Expand Down Expand Up @@ -61,6 +62,9 @@
player = nm.stream.LSLOfflinePlayer(raw=raw, stream_name="example_stream")

player.start_player(chunk_size=30)

time.sleep(2) # Wait for stream to start

# %%
# Creating the LSLStream object
# -----------------------------
Expand All @@ -78,6 +82,9 @@
# %%
stream = nm.Stream(
sfreq=sfreq,
experiment_name=RUN_NAME,
is_stream_lsl=True,
stream_lsl_name="example_stream",
channels=channels,
settings=settings,
coord_list=coord_list,
Expand All @@ -87,13 +94,7 @@
# %%
# We then simply have to set the `stream_lsl` parameter to be `True` and specify the `stream_lsl_name`.

features = stream.run(
is_stream_lsl=True,
plot_lsl=False,
stream_lsl_name="example_stream",
out_dir=PATH_OUT,
experiment_name=RUN_NAME,
)
features = stream.run(out_dir=PATH_OUT)

# %%
# We can then look at the computed features and check if the streamed data was processed correctly.
Expand Down
7 changes: 0 additions & 7 deletions py_neuromodulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import platform
from pathlib import PurePath
from importlib.metadata import version
from py_neuromodulation.utils.logging import NMLogger

#####################################
# Globals and environment variables #
Expand Down Expand Up @@ -57,12 +56,6 @@
user_features = {}


######################################
# Logger initialization and settings #
######################################

logger = NMLogger(__name__) # logger initialization first to prevent circular import

####################################
# API: Exposed classes and methods #
####################################
Expand Down
2 changes: 1 addition & 1 deletion py_neuromodulation/analysis/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pathlib import PurePath
import pickle

from py_neuromodulation import logger
from py_neuromodulation.utils import logger

from typing import Callable

Expand Down
2 changes: 1 addition & 1 deletion py_neuromodulation/analysis/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from matplotlib import gridspec
import seaborn as sb
from pathlib import PurePath
from py_neuromodulation import logger, PYNM_DIR
from py_neuromodulation.utils import logger, PYNM_DIR
from py_neuromodulation.utils.types import _PathLike


Expand Down
1 change: 1 addition & 0 deletions py_neuromodulation/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@
FeatureProcessors,
add_custom_feature,
remove_custom_feature,
USE_FREQ_RANGES,
)
2 changes: 1 addition & 1 deletion py_neuromodulation/features/coherence.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
FrequencyRange,
NMBaseModel,
)
from py_neuromodulation import logger
from py_neuromodulation.utils import logger

if TYPE_CHECKING:
from py_neuromodulation import NMSettings
Expand Down
10 changes: 10 additions & 0 deletions py_neuromodulation/features/feature_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
import numpy as np
from py_neuromodulation import NMSettings

USE_FREQ_RANGES: list[FeatureName] = [
"bandpass_filter",
"stft",
"fft",
"welch",
"bursts",
"coherence",
"nolds",
"bispectrum",
]

FEATURE_DICT: dict[FeatureName | str, str] = {
"raw_hjorth": "Hjorth",
Expand Down
2 changes: 1 addition & 1 deletion py_neuromodulation/filter/notch_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import cast

from py_neuromodulation.utils.types import NMPreprocessor
from py_neuromodulation import logger
from py_neuromodulation.utils import logger


class NotchFilter(NMPreprocessor):
Expand Down
29 changes: 4 additions & 25 deletions py_neuromodulation/gui/backend/app_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Query,
WebSocket,
)
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware

Expand Down Expand Up @@ -71,7 +70,9 @@ def __init__(

def push_features_to_frontend(self, feature_queue: Queue) -> None:
while True:
time.sleep(0.002) # NOTE: should be adapted depending on feature sampling rate
time.sleep(
0.002
) # NOTE: should be adapted depending on feature sampling rate
if feature_queue.empty() is False:
self.logger.info("data in feature queue")
features = feature_queue.get()
Expand Down Expand Up @@ -231,7 +232,6 @@ async def setup_offline_stream(data: dict):
#######################

@self.get("/api/app-info")
# TODO: fix this function
async def get_app_info():
metadata = importlib.metadata.metadata("py_neuromodulation")
url_list = metadata.get_all("Project-URL")
Expand Down Expand Up @@ -353,25 +353,4 @@ def quick_access():
###########################
@self.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
# if self.websocket_manager.is_connected:
# self.logger.info(
# "WebSocket connection attempted while already connected"
# )
# await websocket.close(
# code=1008, reason="Another client is already connected"
# )
# return

await self.websocket_manager.connect(websocket)
# # #######################
# # ### SPA ENTRY POINT ###
# # #######################
# if not self.dev:

# @self.get("/app/{full_path:path}")
# async def serve_spa(request, full_path: str):
# # Serve the index.html for any path that doesn't match an API route
# print(Path.cwd())
# return FileResponse("frontend/index.html")


await self.websocket_manager.connect(websocket)
Loading
Loading