Skip to content

Commit

Permalink
WIP 🏗️: draft pipeline for adapting center of rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
lauraporta committed Sep 25, 2024
1 parent 6000609 commit a0c6017
Showing 1 changed file with 158 additions and 25 deletions.
183 changes: 158 additions & 25 deletions derotation/analysis/full_rotation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def plot_rotation_angles(self):

### ----------------- Derotation ----------------- ###
def shift_image_given_different_center_of_rotation(
self, image: np.ndarray, center_of_rotation: Tuple[int, int]
self, offset: int = 0
) -> np.ndarray:
"""Shifts the image to the center of rotation.
It is useful when the center of rotation is not at the center of the
Expand All @@ -811,17 +811,101 @@ def shift_image_given_different_center_of_rotation(
np.ndarray
The shifted image.
"""
x_center, y_center = center_of_rotation
x_center = int(x_center)
y_center = int(y_center)

x_shift = x_center - image.shape[0] // 2
y_shift = y_center - image.shape[1] // 2
def get_padding(center_of_rotation, image):
x_center, y_center = center_of_rotation
x_center = int(x_center)
y_center = int(y_center)

# real center of the image
x_center_image = int(image.shape[0] / 2)
y_center_image = int(image.shape[1] / 2)

x_shift = x_center_image - x_center
y_shift = y_center_image - y_center
if x_shift == 0 and y_shift == 0:
return image
elif x_shift > 0:
pad_left = 0
pad_right = x_shift // 2
else:
pad_left = -x_shift // 2
pad_right = 0

if y_shift > 0:
pad_top = 0
pad_bottom = y_shift // 2
else:
pad_top = -y_shift // 2
pad_bottom = 0

logging.info(f"Shifting image by {x_shift, y_shift}")
logging.info(
f"Padding image by {pad_top, pad_bottom, pad_left, pad_right}"
)

return pad_top, pad_bottom, pad_left, pad_right

def apply_padding(image, padding, offset):
pad_top, pad_bottom, pad_left, pad_right = padding

padded_image = np.pad(
image,
((pad_top, pad_bottom), (pad_left, pad_right)),
"constant",
constant_values=offset,
)

return padded_image

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(self.image_stack[0], cmap="viridis")
ax[0].set_title("Before shifting")
ax[0].scatter(
self.center_of_rotation[0], self.center_of_rotation[1], color="red"
)
ax[0].scatter(
int(self.image_stack[0].shape[0] / 2),
int(self.image_stack[0].shape[1] / 2),
color="green",
)
ax[0].axis("off")

padding = get_padding(self.center_of_rotation, self.image_stack[0])
self.padded_image_stack = np.asarray(
[apply_padding(img, padding, offset) for img in self.image_stack]
)
# save shifted array
tiff.imsave(
self.config["paths_read"]["path_to_tif"] + "shifted_raw.tif",
self.padded_image_stack,
)
logging.info(
"Image stack shifted, new center of rotation:"
+ f"{self.center_of_rotation}"
)
logging.info(
"Shifted image saved in "
+ f"{self.config['paths_write']['derotated_tiff_folder']}"
)
logging.info(
"Image was padded and has new shape: "
+ f"{self.padded_image_stack.shape}, original shape: "
+ f"{self.image_stack.shape}"
)

new_image_center = (
int(self.padded_image_stack[0].shape[0] / 2),
int(self.padded_image_stack[0].shape[1] / 2),
)
ax[1].imshow(self.padded_image_stack[0], cmap="viridis")
ax[1].set_title("After shifting")
ax[1].scatter(new_image_center[0], new_image_center[1], color="orange")
ax[1].axis("off")

shifted_image = np.roll(image, x_shift, axis=0)
shifted_image = np.roll(shifted_image, y_shift, axis=1)
plt.savefig(self.debug_plots_folder / "image_shift.png", dpi=300)

return shifted_image
return self.padded_image_stack

def rotate_frames_line_by_line(self) -> np.ndarray:
"""Rotates the image stack line by line, using the rotation angles
Expand Down Expand Up @@ -853,29 +937,78 @@ def rotate_frames_line_by_line(self) -> np.ndarray:

offset = self.find_image_offset(self.image_stack[0])

if self.center_of_rotation:
self.image_stack = (
self.shift_image_given_different_center_of_rotation(
self.image_stack, self.center_of_rotation
)
if self.center_of_rotation is not None:
self.padded_image_stack = (
self.shift_image_given_different_center_of_rotation(offset)
)
# save shifted array
tiff.imsave(
self.config["paths_write"]["derotated_tiff_folder"]
+ self.config["paths_write"]["saving_name"]
+ "_shifted_raw.tif",
rotated_image_stack = rotate_an_image_array_line_by_line(
self.padded_image_stack,
self.rot_deg_line,
blank_pixels_value=offset,
num_lines_per_frame=self.image_stack.shape[1],
# plotting_hook_line_addition=self.plotting_hook_for_derotation_line_addition,
plotting_hook_image_completed=self.plotting_hook_for_derotation_image_completed,
)
else:
rotated_image_stack = rotate_an_image_array_line_by_line(
self.image_stack,
self.rot_deg_line,
blank_pixels_value=offset,
# plotting_hook_line_addition=self.plotting_hook_for_derotation_line_addition,
plotting_hook_image_completed=self.plotting_hook_for_derotation_image_completed,
)

rotated_image_stack = rotate_an_image_array_line_by_line(
self.image_stack,
self.rot_deg_line,
blank_pixels_value=offset,
)

logging.info("✨ Image stack rotated ✨")
return rotated_image_stack

@staticmethod
def plotting_hook_for_derotation_line_addition(
rotated_filled_image, rotated_line, image_counter, line_counter, angle
):
fig, ax = plt.subplots(1, 2, figsize=(10, 10))

ax[0].imshow(rotated_filled_image, cmap="viridis")
ax[0].set_title(f"Frame {image_counter}")
ax[0].axis("off")

ax[1].imshow(rotated_line, cmap="viridis")
ax[1].set_title(f"Line {line_counter}, angle: {angle:.2f}")
ax[1].axis("off")

plt.savefig(
f"debug/lines/derotated_image_{image_counter}_line_{line_counter}.png",
dpi=300,
)
plt.close()

@staticmethod
def plotting_hook_for_derotation_image_completed(
rotated_image_stack, image_counter
):
"""Hook for plotting the image stack after derotation.
It is useful for debugging purposes.
"""
if image_counter == 149:
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
# plot maximum projection of the image stack
ax.imshow(
np.max(rotated_image_stack[:image_counter], axis=0),
cmap="viridis",
)
ax.axis("off")
plt.savefig("debug/max_projection.png", dpi=300)
plt.close()

plt.close()

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(rotated_image_stack[image_counter], cmap="viridis")
ax.axis("off")
plt.savefig(
f"debug/frames/derotated_image_{image_counter}.png", dpi=300
)
plt.close()

@staticmethod
def find_image_offset(img):
"""Find the "F0", also called "image offset" for a given image.
Expand Down

0 comments on commit a0c6017

Please sign in to comment.