diff --git a/examples/plots/atlas_annotations.py b/examples/plots/atlas_annotations.py index 1d27cbb..7aba4a8 100644 --- a/examples/plots/atlas_annotations.py +++ b/examples/plots/atlas_annotations.py @@ -2,7 +2,7 @@ This script needs: - A local path to where the BrainGlobe atlas is stored -- A path to a .csv file specifying RGB values for each region in the atlas +- A path to a .csv file specifying RGB color values for each region """ # %% # Imports @@ -28,12 +28,15 @@ annotation_path = atlas_dir / "annotation.tiff" structures_csv_path = atlas_dir / "structures.csv" +# get path of this script's parent directory +current_dir = Path(os.path.dirname(os.path.abspath(__file__))) +# Load matplotlib parameters (to allow for proper font export) +plt.style.use(current_dir / "plots.mplstyle") # Path to the csv file containing the RGB values for each region # Expected columns are: "acronym", "R", "G", "B" # The .csv is in the same folder as this script colors_csv_filename = "oldenburg_blackcap_colors.csv" -currenct_script_dir = Path(os.path.dirname(os.path.abspath(__file__))) -colors_csv_path = currenct_script_dir / colors_csv_filename +colors_csv_path = current_dir / colors_csv_filename # Path to save the plots save_dir = Path.home() / "Downloads" @@ -47,126 +50,87 @@ # Load both "structures" and "colors" csv files structures = pd.read_csv(structures_csv_path) colors = pd.read_csv(colors_csv_path, dtype={"R": int, "G": int, "B": int}) + +# Prepare structure-to-color mapping + # Merge the two dataframes on the "acronym" column # Maintain the order of the "colors" dataframe -structures = pd.merge( +struct2col_df = pd.merge( structures, colors, on="acronym", how="right", validate="one_to_one" ) -# add an alpha column with all areas being 0.8 -structures.loc[:, "A"] = 1.0 -# Keep only the columns needed for the colormap -structures = structures[["id", "acronym", "R", "G", "B", "A"]] -# Divide RGB values by 255 (normalise to 0 - 1) -for col in ["R", "G", "B"]: - structures[col] = structures[col] / 255 +struct2col_df.loc[:, "A"] = 0.6 # Add alpha column for transparency +for col in ["R", "G", "B"]: # Normalise RGB values to 0-1 range + struct2col_df[col] = round(struct2col_df[col] / 255, 3) +# Keep only necessary columns +struct2col_df = struct2col_df[["id", "acronym", "R", "G", "B", "A"]] # Add an id=0 (empty areas), with RGBA values (0, 0, 0, 0) new_row = pd.DataFrame( {"id": [0], "acronym": "empty", "R": [0], "G": [0], "B": [0], "A": [0]} ) -structures = pd.concat([new_row, structures], ignore_index=True) -# Sort daraframe by id (increasing id values) -structures = structures.sort_values(by="id") -# assign new_id in increasing order -structures.loc[:, "new_id"] = range(len(structures)) -# Save the new dataframe to a csv file here -structures.to_csv(currenct_script_dir / "colors.csv", index=False) - -# Create a dictionary mapping the id to the RGBA values -id_to_rgba = { - row["new_id"]: (row["R"], row["G"], row["B"], row["A"]) - for _, row in structures.iterrows() +struct2col_df = pd.concat([new_row, struct2col_df], ignore_index=True) +# Add a mononically increasing color id column +n_colors = len(struct2col_df) +struct2col_df.loc[:, "color_id"] = range(n_colors) +struct2col_df.head(n_colors) + +# %% +# Construct a new colormap using the structure-to-color mapping +# Map the monotonically increasing color_id to the RGBA values +struct_to_rgba = { + row["color_id"]: (row["R"], row["G"], row["B"], row["A"]) + for _, row in struct2col_df.iterrows() } -# Create a colormap using the RGBA values -annotation_cmap = mcolors.ListedColormap([id_to_rgba[id] for id in id_to_rgba]) -# Create a normalization for the colormap based on the id values -annotation_cmap_norm = mcolors.BoundaryNorm( - list(id_to_rgba.keys()), annotation_cmap.N +atlas_cmap = mcolors.ListedColormap( + [struct_to_rgba[color_id] for color_id in struct_to_rgba], + name="blackap_atlas_v1.1", +) +# Create a normalization for the colormap based on the color id values +atlas_cmap_norm = mcolors.BoundaryNorm( + list(np.arange(n_colors + 1) - 0.5), # list bin edges bracketing the ids + atlas_cmap.N, # number of colors in the colormap ) # Remap the annotation image to the new ids -for id in structures["id"].unique(): - new_id = structures.loc[structures["id"] == id, "new_id"].values[0] - annotation_img[annotation_img == id] = new_id +atlas_overlay = annotation_img.copy() +for id in struct2col_df["id"]: + atlas_overlay[annotation_img == id] = struct2col_df[ + struct2col_df["id"] == id + ]["color_id"].values[0] # %% -# Define funciton for plotting - - -def plot_slices( - reference: np.ndarray, - slices: list[int], - annotation=np.ndarray | None, - axis: int = 0, - vmin_perc: float = 1, - vmax_perc: float = 99, - save_path: Path | None = None, -): - """Plot slices from a 3D image with optional annotation overlay. - - The slices are shown in a single column. - - Parameters - ---------- - reference : np.ndarray - Reference 3D image to plot slices from. - slices : list[int] - List of slice indices to plot. - annotation : np.ndarray, optional - Annotation image to overlay on the reference image, by default None. - If supplied, must have the same shape as the reference image. - axis : int, optional - Axis along which to take slices, by default 0. - vmin_perc : float, optional - Lower percentile for reference image, by default 1. - vmax_perc : float, optional - Upper percentile for reference image, by default 99. - save_path : Path, optional - Path to save the plot, by default None (does not save). - """ - n_slices = len(slices) - ref_slices = [reference.take(s, axis=axis) for s in slices] - height, width = ref_slices[0].shape - fig_width = width / 100 - fig_height = height / 100 * n_slices - fig, ax = plt.subplots(n_slices, 1, figsize=(fig_width, fig_height)) - - if annotation is not None: - ann_slices = [annotation.take(s, axis=axis) for s in slices] - # Make the left half of each slice to 0 - for ann_slice in ann_slices: - ann_slice[:, : width // 2] = 0 - - for i in range(n_slices): - ref_frame = ref_slices[i] - ax[i].imshow( - ref_frame, - cmap="gray", - vmin=np.percentile(ref_frame, vmin_perc), - vmax=np.percentile(ref_frame, vmax_perc), - ) - - if annotation is not None: - ann_frame = ann_slices[i] - ax[i].imshow( - ann_frame, - cmap=annotation_cmap, - norm=annotation_cmap_norm, - ) - ax[i].axis("off") - - fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0) - if save_path: - save_dir, save_name = save_path.parent, save_path.name.split(".")[0] - save_figure(fig, save_dir, save_name) - - -# %% -# Save plot - -plot_slices( - reference_img, - slices=[120, 260, 356], - annotation=annotation_img, - save_path=save_dir / "test.png", -) +# Plot the reference image with the atlas overlay + +slices = [120, 260, 356] +n_slices = len(slices) +ref_slices = [reference_img.take(s, axis=0) for s in slices] + +height, width = ref_slices[0].shape +fig_width = width / 100 +fig_height = height / 100 * n_slices +fig, ax = plt.subplots(n_slices, 1, figsize=(fig_width, fig_height)) + +ann_slices = [atlas_overlay.take(s, axis=0) for s in slices] +for ann_slice in ann_slices: + ann_slice[:, : width // 2] = 0 # Make the left half of each slice to 0 + +for i in range(n_slices): + ref_frame = ref_slices[i] + ax[i].imshow( + ref_frame, + cmap="gray", + vmin=np.percentile(ref_frame, 1), + vmax=np.percentile(ref_frame, 99), + ) + + ann_frame = ann_slices[i] + ax[i].imshow( + ann_frame, + cmap=atlas_cmap, + norm=atlas_cmap_norm, + interpolation="nearest", + ) + ax[i].axis("off") + +fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0) +save_figure(fig, save_dir, "annotations_overlaid_on_reference") diff --git a/examples/plots/colors.csv b/examples/plots/colors.csv deleted file mode 100644 index 1926ff9..0000000 --- a/examples/plots/colors.csv +++ /dev/null @@ -1,26 +0,0 @@ -id,acronym,R,G,B,A,new_id -0,empty,0.0,0.0,0.0,0.0,0 -1,P,0.19215686274509805,0.6862745098039216,0.9607843137254902,1.0,1 -2,St,0.9450980392156862,0.796078431372549,0.22745098039215686,1.0,2 -3,Di,0.9411764705882353,0.3568627450980392,0.07058823529411765,1.0,3 -4,Me,0.2784313725490196,0.4823529411764706,0.9490196078431372,1.0,4 -5,Pn,0.8745098039215686,0.8745098039215686,0.21568627450980393,1.0,5 -6,Cb,0.6745098039215687,0.984313725490196,0.2196078431372549,1.0,6 -10,PrVd,0.8862745098039215,0.2627450980392157,0.0392156862745098,1.0,7 -17,PrVv,0.996078431372549,0.6,0.17254901960784313,1.0,8 -18,SpVl,0.23921568627450981,0.20784313725490197,0.5450980392156862,1.0,9 -19,SpVm,0.18823529411764706,0.07058823529411765,0.23137254901960785,1.0,10 -20,Ento,0.13333333333333333,0.9215686274509803,0.6666666666666666,1.0,11 -30,HP,0.47843137254901963,0.01568627450980392,0.011764705882352941,1.0,12 -40,N,0.2196078431372549,0.15294117647058825,0.42745098039215684,1.0,13 -50,H,0.7843137254901961,0.9372549019607843,0.20392156862745098,1.0,14 -60,M,0.5607843137254902,1.0,0.28627450980392155,1.0,15 -70,A,0.12549019607843137,0.7803921568627451,0.8745098039215686,1.0,16 -80,OB,0.24705882352941178,0.9647058823529412,0.5411764705882353,1.0,17 -90,CDL,0.2549019607843137,0.5882352941176471,1.0,1.0,18 -110,N.Rot,0.27450980392156865,0.3803921568627451,0.8392156862745098,1.0,19 -140,Gld,0.2549019607843137,0.27450980392156865,0.6745098039215687,1.0,20 -220,OT,0.1843137254901961,0.9450980392156862,0.6078431372549019,1.0,21 -305,CN-HP,0.9882352941176471,0.7019607843137254,0.21176470588235294,1.0,22 -402,NFT,0.9803921568627451,0.4823529411764706,0.12156862745098039,1.0,23 -505,CN-H,0.8156862745098039,0.1843137254901961,0.0196078431372549,1.0,24