From 576a9ad65e98feefe0056a7b2b814f4e09b07241 Mon Sep 17 00:00:00 2001 From: lauraporta Date: Tue, 2 Jul 2024 17:37:49 +0100 Subject: [PATCH] wip --- derotation/analysis/full_rotation_pipeline.py | 41 +- .../analysis/incremental_rotation_pipeline.py | 2 +- derotation/config/full_rotation.yml | 47 ++- derotation/config/incremental_rotation.yml | 16 +- docs/source/conf.py | 18 +- .../improve_video_using_suite2p_metrics.py | 71 +++- examples/read_suite2p_output.py | 394 +++++++++++++++++- 7 files changed, 518 insertions(+), 71 deletions(-) diff --git a/derotation/analysis/full_rotation_pipeline.py b/derotation/analysis/full_rotation_pipeline.py index dd9895d..daf99bb 100644 --- a/derotation/analysis/full_rotation_pipeline.py +++ b/derotation/analysis/full_rotation_pipeline.py @@ -15,7 +15,7 @@ from sklearn.mixture import GaussianMixture from tifffile import imsave -from derotation.derotate_by_line import rotate_an_image_array_line_by_line +from derotation.derotate_by_line_plot import rotate_an_image_array_line_by_line from derotation.load_data.custom_data_loaders import ( get_analog_signals, read_randomized_stim_table, @@ -55,7 +55,7 @@ def __call__(self): - adding a circular mask to the rotated image stack - saving the masked image stack """ - self.contrast_enhancement() + # self.contrast_enhancement() self.process_analog_signals() rotated_images = self.rotate_frames_line_by_line() masked = self.add_circle_mask(rotated_images, self.mask_diameter) @@ -236,18 +236,29 @@ def process_analog_signals(self): ) self.rotation_on = self.create_signed_rotation_array() - self.drop_ticks_outside_of_rotation() + if self.adjust_increment: + self.drop_ticks_outside_of_rotation() + self.check_number_of_rotations() - self.check_number_of_rotations() - if not self.is_number_of_ticks_correct() and self.adjust_increment: - ( - self.corrected_increments, - self.ticks_per_rotation, - ) = self.adjust_rotation_increment() + if not self.is_number_of_ticks_correct(): + ( + self.corrected_increments, + self.ticks_per_rotation, + ) = self.adjust_rotation_increment() + else: + self.corrected_increments = [ + self.rotation_increment + ] * self.number_of_rotations + self.ticks_per_rotation = ( + self.rot_deg + * self.rotation_increment + * self.number_of_rotations + ) self.interpolated_angles = self.get_interpolated_angles() - self.remove_artifacts_from_interpolated_angles() + if self.adjust_increment: + self.remove_artifacts_from_interpolated_angles() ( self.line_start, @@ -419,7 +430,7 @@ def drop_ticks_outside_of_rotation(self) -> np.ndarray: inter_roatation_interval = [ idx - for i in range(self.number_of_rotations + 1) + for i in range(len(edited_ends)) for idx in range( edited_ends[i], rolled_starts[i], @@ -457,8 +468,8 @@ def check_number_of_rotations(self): raise ValueError( "Start and end of rotations have different lengths" ) - if self.rot_blocks_idx["start"].shape[0] != self.number_of_rotations: - raise ValueError("Number of rotations is not as expected") + # if self.rot_blocks_idx["start"].shape[0] != self.number_of_rotations: + # raise ValueError("Number of rotations is not as expected") logging.info("Number of rotations is as expected") @@ -550,7 +561,7 @@ def get_interpolated_angles(self) -> np.ndarray: ticks_with_increment = [ item - for i in range(self.number_of_rotations) + for i in range(len(self.corrected_increments)) for item in [self.corrected_increments[i]] * self.ticks_per_rotation[i] ] @@ -599,7 +610,7 @@ def remove_artifacts_from_interpolated_angles(self): rotation_end = np.where(np.diff(thresholded) < 0)[0] assert len(rotation_start) == len(rotation_end) - assert len(rotation_start) == self.number_of_rotations + # assert len(rotation_start) == self.number_of_rotations for i, (start, end) in enumerate( zip(rotation_start[1:], rotation_end[:-1]) diff --git a/derotation/analysis/incremental_rotation_pipeline.py b/derotation/analysis/incremental_rotation_pipeline.py index ebe4f7a..ce4e290 100644 --- a/derotation/analysis/incremental_rotation_pipeline.py +++ b/derotation/analysis/incremental_rotation_pipeline.py @@ -36,7 +36,7 @@ def __call__(self): After processing the analog signals, the image stack is rotated by frame and then registered using phase cross correlation. """ - self.contrast_enhancement() + # self.contrast_enhancement() super().process_analog_signals() rotated_images = self.roatate_by_frame() masked_unregistered = self.add_circle_mask(rotated_images) diff --git a/derotation/config/full_rotation.yml b/derotation/config/full_rotation.yml index 56f4bb5..8b6a46c 100644 --- a/derotation/config/full_rotation.yml +++ b/derotation/config/full_rotation.yml @@ -1,28 +1,35 @@ paths_read: - path_to_randperm: "your_path_to/stimlus_randperm.mat" - path_to_aux: "your_path_to/rotation.bin" - path_to_tif: "your_path_to/rotation.tif" + path_to_randperm: /Users/lauraporta/local_data/rotation/stimlus_randperm.mat + path_to_aux: /Users/lauraporta/local_data/rotation/230731_25_micron_grid/aux_stim/230731_grid_1_001.bin + path_to_tif: /Users/lauraporta/local_data/rotation/230731_25_micron_grid/imaging/rotation_zf2_all_speeds_00002_enhanced.tif paths_write: - debug_plots_folder: "your_path_to/debug_plots/" - logs_folder: "your_path_to/logs/" - derotated_tiff_folder: "your_path_to/data_folder/" - saving_name: "derotated_image_stack" + debug_plots_folder: /Users/lauraporta/local_data/rotation/230731_25_micron_grid/full/debug_plots/ + logs_folder: /Users/lauraporta/local_data/rotation/230731_25_micron_grid/full/logs/ + derotated_tiff_folder: /Users/lauraporta/local_data/rotation/230731_25_micron_grid/full/derotated/ + saving_name: derotated_image_stack_CE +# paths_read: +# path_to_randperm: /Users/lauraporta/local_data/rotation/stimlus_randperm.mat +# path_to_aux: /Users/lauraporta/local_data/rotation/230802_CAA_1120182/aux_stim/230802_CAA_1120182_rotation_1_001.bin +# path_to_tif: /Users/lauraporta/local_data/rotation/230802_CAA_1120182/imaging/rotation_00001.tif +# paths_write: +# debug_plots_folder: /Users/lauraporta/local_data/rotation/230802_CAA_1120182/full/debug_plots/ +# logs_folder: /Users/lauraporta/local_data/rotation/230802_CAA_1120182/full/logs/ +# derotated_tiff_folder: /Users/lauraporta/local_data/rotation/230802_CAA_1120182/full/derotated/ +# saving_name: derotated_image_stack_NO_ce - -channel_names: [ - "camera", - "scanimage_frameclock", - "scanimage_lineclock", - "photodiode2", - "PI_rotON", - "PI_rotticks", -] +channel_names: + - camera + - scanimage_frameclock + - scanimage_lineclock + - photodiode2 + - PI_rotON + - PI_rotticks rotation_increment: 0.2 -adjust_increment: True +adjust_increment: false rot_deg: 360 -debugging_plots: True +debugging_plots: false contrast_enhancement: 0.35 @@ -35,5 +42,5 @@ analog_signals_processing: angle_interpolation_artifact_threshold: 0.15 interpolation: - line_use_start: True - frame_use_start: True + line_use_start: true + frame_use_start: true \ No newline at end of file diff --git a/derotation/config/incremental_rotation.yml b/derotation/config/incremental_rotation.yml index 56f4bb5..15fc89d 100644 --- a/derotation/config/incremental_rotation.yml +++ b/derotation/config/incremental_rotation.yml @@ -1,12 +1,12 @@ paths_read: - path_to_randperm: "your_path_to/stimlus_randperm.mat" - path_to_aux: "your_path_to/rotation.bin" - path_to_tif: "your_path_to/rotation.tif" + path_to_randperm: "/Users/lauraporta/local_data/rotation/stimlus_randperm.mat" + path_to_aux: "/Users/lauraporta/local_data/rotation/230802_CAA_1120182/aux_stim/230802_CAA_1120182_rotationincrement_1_001.bin" + path_to_tif: "/Users/lauraporta/local_data/rotation/230802_CAA_1120182/imaging/rotation_increment_00001.tif" paths_write: - debug_plots_folder: "your_path_to/debug_plots/" - logs_folder: "your_path_to/logs/" - derotated_tiff_folder: "your_path_to/data_folder/" - saving_name: "derotated_image_stack" + debug_plots_folder: "/Users/lauraporta/local_data/rotation/230802_CAA_1120182/incremental/debug_plots/" + logs_folder: "/Users/lauraporta/local_data/rotation/230802_CAA_1120182/incremental/logs/" + derotated_tiff_folder: "/Users/lauraporta/local_data/rotation/230802_CAA_1120182/incremental/derotated/" + saving_name: "derotated_incremental_image_stack_NO_ce" channel_names: [ @@ -22,7 +22,7 @@ rotation_increment: 0.2 adjust_increment: True rot_deg: 360 -debugging_plots: True +debugging_plots: False contrast_enhancement: 0.35 diff --git a/docs/source/conf.py b/docs/source/conf.py index ef73a02..3c70ebd 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -75,22 +75,22 @@ # Automatically generate stub pages for API autosummary_generate = True numpydoc_class_members_toctree = False # stops stubs warning -#toc_object_entries_show_parents = "all" +# toc_object_entries_show_parents = "all" html_show_sourcelink = False -#html_sidebars = { this is not working... +# html_sidebars = { this is not working... # "index": [], # "**": [], -#} +# } autodoc_default_options = { - 'members': True, + "members": True, "member-order": "bysource", - 'special-members': False, - 'private-members': False, - 'inherited-members': False, - 'undoc-members': True, - 'exclude-members': "", + "special-members": False, + "private-members": False, + "inherited-members": False, + "undoc-members": True, + "exclude-members": "", } # List of patterns, relative to source directory, that match files and diff --git a/examples/improve_video_using_suite2p_metrics.py b/examples/improve_video_using_suite2p_metrics.py index fbe0059..dabcdfc 100644 --- a/examples/improve_video_using_suite2p_metrics.py +++ b/examples/improve_video_using_suite2p_metrics.py @@ -6,35 +6,62 @@ from suite2p.io import BinaryFile from suite2p.registration.nonrigid import make_blocks, spatial_taper -from derotation.analysis.incremental_rotation_pipeline import ( - IncrementalPipeline, -) - -derotator = IncrementalPipeline("incremental_rotation") -derotator() +# derotator = IncrementalPipeline("incremental_rotation") +# derotator() -# extract luminance variations across angles -zero_rotation_mean_image = derotator.get_target_image(derotator.masked) -mean_images_across_angles = derotator.calculate_mean_images(derotator.masked) +# # extract luminance variations across angles +# zero_rotation_mean_image = derotator.get_target_image(derotator.masked) +# mean_images_across_angles = derotator.calculate_mean_images(derotator.masked) # load registered bin file of suite2p path_to_bin_file = Path( "/Users/lauraporta/local_data/rotation/230802_CAA_1120182/incremental/derotated/suite2p/plane0/data.bin" ) -shape_image = zero_rotation_mean_image.shape +shape_image = [256, 256] registered = BinaryFile( Ly=shape_image[0], Lx=shape_image[0], filename=path_to_bin_file +).file + +time, x, y = registered.shape +registered = registered.reshape(time, x * y) +model = PCA(n_components=10, random_state=0).fit(registered) + + +print("debug") + + +path_options = Path("/Users/lauraporta/local_data/laura_ops.npy") +ops = np.load(path_options, allow_pickle=True).item() + +bin_size = int( + max(1, ops["nframes"] // ops["nbinned"], np.round(ops["tau"] * ops["fs"])) ) -plt.imshow(registered[0]) -plt.show() -# load options +with BinaryFile(filename=path_to_bin_file, Ly=ops["Ly"], Lx=ops["Lx"]) as f: + registered = f.bin_movie( + bin_size=bin_size, + bad_frames=ops.get("badframes"), + y_range=ops["yrange"], + x_range=ops["xrange"], + ) + +mov_mean = registered.mean(axis=0) +registered -= mov_mean +model = PCA(n_components=8, random_state=0).fit(registered) + +result = (registered @ model.components_.T) @ model.components_ + + +# use PCA as in the suite2p code - this takes in account for non rigid +# registration, that we are not going to do path_options = Path("/Users/lauraporta/local_data/laura_ops.npy") -ops = np.load(path_options) +ops = np.load(path_options, allow_pickle=True).item() + +mov_mean = registered.mean(axis=0) +registered -= mov_mean -# use PCA as in the suite2p code block_size = [ops["block_size"][0] // 2, ops["block_size"][1] // 2] nframes, Ly, Lx = registered.shape yblock, xblock, _, block_size, _ = make_blocks(Ly, Lx, block_size=block_size) @@ -54,5 +81,19 @@ block_re[i] = (block @ model.components_.T) @ model.components_ norm[yblock[i][0] : yblock[i][-1], xblock[i][0] : xblock[i][-1]] += maskMul +reconstruction = np.zeros_like(registered) + +block_re = block_re.reshape(nblocks, nframes, Lyb, Lxb) +block_re *= maskMul +for i in range(nblocks): + reconstruction[ + :, yblock[i][0] : yblock[i][-1], xblock[i][0] : xblock[i][-1] + ] += block_re[i] +reconstruction /= norm +reconstruction += mov_mean + +plt.imshow(norm) +plt.show() + print("debug") diff --git a/examples/read_suite2p_output.py b/examples/read_suite2p_output.py index c45b191..94aee0a 100644 --- a/examples/read_suite2p_output.py +++ b/examples/read_suite2p_output.py @@ -1,12 +1,18 @@ +# %% from pathlib import Path import allensdk.brain_observatory.dff as dff_module +import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns from allensdk.brain_observatory.r_neuropil import NeuropilSubtract from scipy.io import loadmat +saving_path = Path( + "/Users/lauraporta/Source/github/neuroinformatics-unit/derotation/examples/figures/pdf/" +) + def neuropil_subtraction(f, f_neu): # use default parameters for all methods @@ -27,13 +33,13 @@ def neuropil_subtraction(f, f_neu): F_path = Path( - "/Users/lauraporta/local_data/rotation/230822_CAA_1120509/suite2p/plane0/F.npy" + "/Users/lauraporta/local_data/rotation/230802_CAA_1120182/full/derotated/CE/suite2p/plane0/F.npy" ) f = np.load(F_path) print(f.shape) Fneu_path = Path( - "/Users/lauraporta/local_data/rotation/230822_CAA_1120509/suite2p/plane0/Fneu.npy" + "/Users/lauraporta/local_data/rotation/230802_CAA_1120182/full/derotated/CE/suite2p/plane0/Fneu.npy" ) fneu = np.load(Fneu_path) print(fneu.shape) @@ -61,7 +67,7 @@ def neuropil_subtraction(f, f_neu): rotated_frames_path = Path( - "/Users/lauraporta/local_data/rotation/230822_CAA_1120509/derotated_image_stack_full.csv" + "/Users/lauraporta/local_data/rotation/230802_CAA_1120182/full/derotated/CE/derotated_image_stack_CE.csv" ) rotated_frames = pd.read_csv(rotated_frames_path) print(rotated_frames.head()) @@ -69,6 +75,8 @@ def neuropil_subtraction(f, f_neu): full_dataframe = pd.concat([dff, rotated_frames], axis=1) +# %% + subset = full_dataframe[ (full_dataframe["speed"] == 100) & (full_dataframe["direction"] == -1) ] @@ -101,3 +109,383 @@ def neuropil_subtraction(f, f_neu): print("debug") + +# %% + +# find where do rotations start +rotation_on = np.diff(full_dataframe["rotation_count"]) + + +def find_zero_chunks(arr): + zero_chunks = [] + start = None + + for i in range(len(arr)): + if arr[i] == 0 and start is None: + start = i + elif arr[i] != 0 and start is not None: + zero_chunks.append((start, i - 1)) + start = None + + # Check if the array ends with a chunk of zeros + if start is not None: + zero_chunks.append((start, len(arr) - 1)) + + return zero_chunks + + +starts_ends = find_zero_chunks(rotation_on) + +frames_before_rotation = 10 +# frames_after_rotation = 10 + +total_len = 70 + +full_dataframe["rotation_frames"] = np.zeros(len(full_dataframe)) +for i, (start, end) in enumerate(starts_ends): + frame_array = np.arange(total_len) + column_index_of_rotation_frames = full_dataframe.columns.get_loc( + "rotation_frames" + ) + full_dataframe.iloc[ + start + - frames_before_rotation : total_len + + start + - frames_before_rotation, + column_index_of_rotation_frames, + ] = frame_array + + # extend this value of speed and direction to all this range + this_speed = full_dataframe.loc[start, "speed"] + this_direction = full_dataframe.loc[start, "direction"] + + full_dataframe.iloc[ + start + - frames_before_rotation : total_len + + start + - frames_before_rotation, + full_dataframe.columns.get_loc("speed"), + ] = this_speed + full_dataframe.iloc[ + start + - frames_before_rotation : total_len + + start + - frames_before_rotation, + full_dataframe.columns.get_loc("direction"), + ] = this_direction + + +# directtion, change -1 to CCW and 1 to CW +full_dataframe["direction"] = np.where( + full_dataframe["direction"] == -1, "CCW", "CW" +) + + +# %% +# Single traces for every ROI +selected_range = (400, 2000) + +for roi in range(11): + roi_selected = full_dataframe.loc[ + :, [roi, "rotation_count", "speed", "direction"] + ] + + fig, ax = plt.subplots(figsize=(27, 5)) + ax.plot(roi_selected.loc[selected_range[0] : selected_range[1], roi]) + ax.set(xlabel="Frames", ylabel="ΔF/F") + + rotation_on = ( + np.diff( + roi_selected.loc[ + selected_range[0] : selected_range[1], "rotation_count" + ] + ) + == 0 + ) + + # add label at the beginning of every block of rotations + # if the previous was true, do not write the label + for i, rotation in enumerate(rotation_on): + if rotation and not rotation_on[i - 1]: + ax.text( + i + selected_range[0] + 3, + -500, + f"{int(roi_selected.loc[i + 5 + selected_range[0], 'speed'])}º/s\n{roi_selected.loc[i + 5 + selected_range[0], 'direction']}", + fontsize=10, + ) + + # add gray squares when the rotation is happening using the starst_ends + for start, end in starts_ends: + if start > selected_range[0] and end < selected_range[1]: + ax.axvspan(start, end, color="gray", alpha=0.2) + + fps = 6.74 + # change xticks to seconds + xticks = ax.get_xticks() + ax.set_xticks(xticks) + ax.set_xticklabels((xticks / fps).astype(int)) + # change x label + ax.set(xlabel="Seconds", ylabel="ΔF/F") + + ax.set_xlim(selected_range) + ax.set_ylim(-300, 300) + + # leave some gap between the axis and the plot + plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1) + + # remove top and right spines + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + plt.savefig(saving_path / f"dff_example_{roi}.pdf") + plt.close() + + +# %% +rois_selection = list(range(11)) + +# do it with seaborn +for roi in rois_selection: + fig, ax = plt.subplots(2, 2, figsize=(27, 10)) + for i, speed in enumerate([50, 100, 150, 200]): + sns.lineplot( + x="rotation_frames", + y=roi, + data=full_dataframe[(full_dataframe["speed"] == speed)], + hue="direction", + ax=ax[i // 2, i % 2], + ) + ax[i // 2, i % 2].set(xlabel="Frames", ylabel="ΔF/F") + ax[i // 2, i % 2].set_title(f"Speed: {speed}º/s") + ax[i // 2, i % 2].legend(title="Direction") + + # remove top and right spines + ax[i // 2, i % 2].spines["top"].set_visible(False) + ax[i // 2, i % 2].spines["right"].set_visible(False) + + # add vertical lines to show the start of the rotation + # start is always at 11, end at total len - 10 + ax[i // 2, i % 2].axvline( + x=frames_before_rotation, color="gray", linestyle="--" + ) + this_x_axis_len = ax[i // 2, i % 2].get_xlim()[1] + ax[i // 2, i % 2].axvline( + x=this_x_axis_len - frames_before_rotation, + color="gray", + linestyle="--", + ) + # to seconds + fps = 6.74 + xticks = ax[i // 2, i % 2].get_xticks() + ax[i // 2, i % 2].set_xticks(xticks) + ax[i // 2, i % 2].set_xticklabels(np.round(xticks / fps, 1)) + # change x label + ax[i // 2, i % 2].set(xlabel="Seconds", ylabel="ΔF/F") + + # unique y scale for all + # ax[i // 2, i % 2].set_ylim(-100, 100) + + plt.savefig(saving_path / f"roi_{roi}_speed_direction.pdf") + plt.close() + +# %% +# now another similar plot. two subplots. On the left CW, right CCW +# hue is speed. +rois_selection = list(range(11)) + + +custom_palette = sns.color_palette("dark:#5A9_r", 4) + +for roi in rois_selection: + fig, ax = plt.subplots(1, 2, figsize=(20, 10)) + for i, direction in enumerate(["CW", "CCW"]): + sns.lineplot( + x="rotation_frames", + y=roi, + data=full_dataframe[(full_dataframe["direction"] == direction)], + hue="speed", + palette=custom_palette, + ax=ax[i], + ) + ax[i].set_title(f"Direction: {direction}") + ax[i].legend(title="Speed") + + # remove top and right spines + ax[i].spines["top"].set_visible(False) + ax[i].spines["right"].set_visible(False) + + # add vertical lines to show the start of the rotation + # start is always at 11, end at total len - 10 + ax[i].axvline(x=frames_before_rotation, color="gray", linestyle="--") + + # change x axis to seconds + fps = 6.74 + xticks = ax[i].get_xticks() + ax[i].set_xticks(xticks) + ax[i].set_xticklabels(np.round(xticks / fps, 1)) + # change x label + ax[i].set(xlabel="Seconds", ylabel="ΔF/F") + + plt.savefig(saving_path / f"roi_{roi}_direction_speed.pdf") + plt.close() + + +# %% +# Same plot as above but only for clockwise rotations +# columns: different speeds +# rows: different rois [0, 8, 9] + +rois_selection = [0, 8, 9] +custom_palette = sns.color_palette("dark:#5A9_r", 4) + +for roi in rois_selection: + fig, ax = plt.subplots(1, 4, figsize=(20, 10)) + for i, speed in enumerate([50, 100, 150, 200]): + sns.lineplot( + x="rotation_frames", + y=roi, + data=full_dataframe[ + (full_dataframe["direction"] == "CW") + & (full_dataframe["speed"] == speed) + ], + ax=ax[i], + ) + ax[i].set_title(f"Speed: {speed}º/s") + ax[i].legend(title="Direction") + + # remove top and right spines + ax[i].spines["top"].set_visible(False) + ax[i].spines["right"].set_visible(False) + + # add vertical lines to show the start of the rotation + # start is always at 11, end at total len - 10 + ax[i].axvline(x=frames_before_rotation, color="gray", linestyle="--") + + # gray box + + # change x axis to seconds + fps = 6.74 + xticks = ax[i].get_xticks() + ax[i].set_xticks(xticks) + ax[i].set_xticklabels(np.round(xticks / fps, 1)) + # change x label + ax[i].set(xlabel="Seconds", ylabel="ΔF/F") + + # ylim: -70 : 320 + ax[i].set_ylim(-70, 320) + + plt.savefig(saving_path / f"roi_{roi}_CW_speed.pdf") + plt.close() + +# for the same selection, take the peak response and the std at the peak and plot them together in +# a scatter plot (x: speed, y: peak response, color: roi) + +# get the peak response and the std at the peak +peak_response = pd.DataFrame() +for roi in rois_selection: + for speed in [50, 100, 150, 200]: + subset = full_dataframe[ + (full_dataframe["direction"] == "CW") + & (full_dataframe["speed"] == speed) + ] + all_peaks = subset.loc[:, roi].max() + peak_response = peak_response.append( + { + "roi": roi, + "speed": speed, + "peak_response": np.max(all_peaks), + "std": subset.loc[:, roi].std(), + }, + ignore_index=True, + ) + + +fig, ax = plt.subplots(figsize=(20, 10)) +sns.scatterplot( + x="speed", + y="peak_response", + data=peak_response, + hue="roi", + ax=ax, +) + +ax.set(xlabel="Speed (º/s)", ylabel="Peak response (ΔF/F)") +ax.legend(title="ROI") +# connect the dots +for roi in rois_selection: + roi_data = peak_response[peak_response["roi"] == roi] + for i in range(0, 4): + ax.plot( + roi_data.loc[roi_data.index[i - 1 : i + 1], "speed"], + roi_data.loc[roi_data.index[i - 1 : i + 1], "peak_response"], + color="gray", + linestyle="--", + ) + + # add the std at the peak as error bars + ax.errorbar( + roi_data.loc[roi_data.index[i], "speed"], + roi_data.loc[roi_data.index[i], "peak_response"], + yerr=roi_data.loc[roi_data.index[i], "std"], + fmt="o", + color="gray", + ) + +# remove top and right spines +ax.spines["top"].set_visible(False) +ax.spines["right"].set_visible(False) + + +plt.savefig(saving_path / "peak_response_speed.pdf") + + +# %% + +# now similar plot but according to rotation angle +# for every angle, get the mean of the roi and plot it + +angles_to_pick = [0, 45, 90, 135, 180, 225, 270, 315] + + +tollerace_deg = 1 + +dataframe_copy_with_rounded_angles = full_dataframe.copy() +dataframe_copy_with_rounded_angles["rotation_angle"] = np.round( + full_dataframe["rotation_angle"] +) +# now np.abs +dataframe_copy_with_rounded_angles.loc[:, "rotation_angle"] = np.abs( + dataframe_copy_with_rounded_angles["rotation_angle"] +) + +for roi in rois_selection: + fig, ax = plt.subplots(figsize=(20, 10)) + mean_response_per_angle = ( + dataframe_copy_with_rounded_angles.loc[:, [roi, "rotation_angle"]] + .groupby("rotation_angle") + .mean() + ) + std = ( + dataframe_copy_with_rounded_angles.loc[:, [roi, "rotation_angle"]] + .groupby("rotation_angle") + .std() + ) + + sns.lineplot( + x="rotation_angle", y=roi, data=mean_response_per_angle, ax=ax + ) + + ax.fill_between( + mean_response_per_angle.index, + mean_response_per_angle[roi] - std[roi], + mean_response_per_angle[roi] + std[roi], + alpha=0.2, + ) + ax.set(xlabel="Rotation angle (º)", ylabel="ΔF/F") + ax.set_title(f"ROI {roi}") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + plt.savefig(saving_path / f"roi_{roi}_rotation_angle.pdf") + +# %%