Skip to content

Commit

Permalink
WIP breaking up plotting scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Sep 10, 2024
1 parent 1c07942 commit 238d8ee
Show file tree
Hide file tree
Showing 8 changed files with 453 additions and 181 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ repos:
- id: mypy
additional_dependencies:
- types-setuptools
- types-pyyaml
- repo: https://github.com/mgedmin/check-manifest
rev: "0.49"
hooks:
Expand Down
181 changes: 0 additions & 181 deletions examples/BlackCap_plot_figures.py

This file was deleted.

12 changes: 12 additions & 0 deletions examples/plots/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
atlas_forge_dir: "/Volumes/neuroinformatics/neuroinformatics/atlas-forge"
species: "BlackCap"
template_name: "template_sym_res-25um_n-18"
resolution_um: 25
transform_types: ["rigid", "similarity", "affine", "nlin"]
num_iterations: 4
show_coronal_slice: 256
vmin_percentile: 0.1
vmax_percentile: 99.9
animation_fps: [2, 4, 8]
use4template_dir_suffix: "orig-asr_N4_aligned_padded_use4template"
example_subject: "sub-BC41o"
6 changes: 6 additions & 0 deletions examples/plots/plots.mplstyle
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pdf.fonttype : 42
ps.fonttype : 42
svg.fonttype : none
font.family : sans-serif
font.sans-serif : Barlow
savefig.dpi : 300
125 changes: 125 additions & 0 deletions examples/plots/template_building_stages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Imports
import os
from pathlib import Path

from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from utils import (
collect_coronal_slices,
collect_template_paths,
compute_vmin_vmax_across_slices,
load_config,
save_figure,
setup_directories,
)

# get path of this script's parent directory
current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
# Load matplotlib parameters (to allow for proper font export)
plt.style.use(current_dir / "plots.mplstyle")
# Load config file containing template building parameters
config = load_config(current_dir / "config.yaml")

# Setup directories based on config file
atlas_dir, template_dir, plots_dir = setup_directories(config)

# Load the list of transform types and number of iterations
transform_types = config["transform_types"]
n_transforms = len(transform_types)
n_iter = config["num_iterations"]
print("transform types: ", transform_types)
print("number of iterations: ", n_iter)

# Collect template images for each iteration and transform type
template_paths = collect_template_paths(template_dir, transform_types, n_iter)

# Collect coronal slices for each iteration and transform type
show_coronal_slice = config["show_coronal_slice"]
template_slices = collect_coronal_slices(template_paths, show_coronal_slice)

# Calculate vmin and vmax for all slices to ensure consistent scaling
vmin, vmax = compute_vmin_vmax_across_slices(
template_slices,
vmin_perc=config["vmin_percentile"],
vmax_perc=config["vmax_percentile"],
)

# Compute the aspect ratio of the 1st slice (should be same for all
width = template_slices["rigid iter-0"].shape[1]
height = template_slices["rigid iter-0"].shape[0]
aspect = width / height

# Plot all transform types and iterations in a grid
# Rows: transform types, Columns: iterations
figs, axs = plt.subplots(
n_transforms,
n_iter,
figsize=(2 * aspect * n_iter, 2 * n_transforms),
)

for t, transform_type in enumerate(transform_types):
for i in range(n_iter):
frame = template_slices[f"{transform_type} iter-{i}"]
ax = axs[t, i]
ax.imshow(frame, vmin=vmin, vmax=vmax, cmap="gray")

title = f"iter {i}" if t == 0 else ""
ax.set_title(title, fontsize="x-large")

ylabel = transform_type if i == 0 else ""
ylabel = ylabel.replace("nlin", "non-linear")
ax.set_ylabel(ylabel, fontsize="x-large")
ax.set_xticks([])
ax.set_yticks([])

figs.subplots_adjust(
left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05
)
save_figure(
figs,
plots_dir,
"template_across_transform_types_and_iterations",
)

# Create animation of template building progress

# Create figure and axis
fig, ax = plt.subplots(figsize=(8 * aspect, 8))
ax.set_xticks([])
ax.set_yticks([])

# Initialize the plot with the first frame
frame_list = list(template_slices.values())
frame_names = list(template_slices.keys())
img = ax.imshow(frame_list[0], vmin=vmin, vmax=vmax, cmap="gray")


def update(frame_index):
"""Update function for the animation"""
img.set_array(frame_list[frame_index])
transform, iteration = frame_names[frame_index].split()
transform = transform.replace("nlin", "non-linear")
ax.set_ylabel(transform, fontsize="x-large")
ax.set_title(f"iter {iteration}", fontsize="x-large")
return [img]


fig.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05)

# Create the animation
ani = FuncAnimation(
fig, # The figure object
update, # The update function
frames=len(template_slices), # Number of frames
interval=200, # Interval between frames in milliseconds
blit=True, # Use blitting for better performance
)

# Save the animation as a gif
for fps in config["animation_fps"]:
ani.save(
plots_dir / f"transforms_iterations_animation_fps-{fps}.gif",
writer="ffmpeg",
dpi=150,
fps=fps,
)
Loading

0 comments on commit 238d8ee

Please sign in to comment.