Skip to content

Commit

Permalink
Bring mock class back in the method where it was
Browse files Browse the repository at this point in the history
  • Loading branch information
lauraporta committed Dec 5, 2024
1 parent 0601588 commit 31a63a1
Showing 1 changed file with 32 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -181,38 +181,6 @@ def create_rotation_angles(
# -----------------------------------------------------


# Mock class to use the IncrementalPipeline
class MockIncrementalPipeline(IncrementalPipeline):
def __init__(self, rotated_stack_incremental, incremental_angles):
# Overwrite the constructor and provide the mock data
self.image_stack = rotated_stack_incremental
self.rot_deg_frame = incremental_angles[
:: rotated_stack_incremental.shape[1]
]
self.num_frames = rotated_stack_incremental.shape[0]

if __name__ == "__main__":
self.debugging_plots = True
self.debug_plots_folder = Path("debug/")
else:
self.debugging_plots = False

def calculate_mean_images(self, image_stack: np.ndarray) -> list:
# Overwrite original method as it is too bound
# to signal coming from a real motor
angles_subset = copy.deepcopy(self.rot_deg_frame)
rounded_angles = np.round(angles_subset)

mean_images = []
for i in np.arange(10, 360, 10):
images = image_stack[rounded_angles == i]
mean_image = np.mean(images, axis=0)

mean_images.append(mean_image)

return mean_images


def get_center_of_rotation(
rotated_stack_incremental: np.ndarray, incremental_angles: np.ndarray
) -> Tuple[int, int]:
Expand All @@ -239,10 +207,39 @@ def get_center_of_rotation(
The center of rotation
"""

# Mock class to use the IncrementalPipeline
class MockIncrementalPipeline(IncrementalPipeline):
def __init__(self):
# Overwrite the constructor and provide the mock data
self.image_stack = rotated_stack_incremental
self.rot_deg_frame = incremental_angles[
:: rotated_stack_incremental.shape[1]
]
self.num_frames = rotated_stack_incremental.shape[0]

if __name__ == "__main__":
self.debugging_plots = True
self.debug_plots_folder = Path("debug/")
else:
self.debugging_plots = False

def calculate_mean_images(self, image_stack: np.ndarray) -> list:
# Overwrite original method as it is too bound
# to signal coming from a real motor
angles_subset = copy.deepcopy(self.rot_deg_frame)
rounded_angles = np.round(angles_subset)

mean_images = []
for i in np.arange(10, 360, 10):
images = image_stack[rounded_angles == i]
mean_image = np.mean(images, axis=0)

mean_images.append(mean_image)

return mean_images

# Use the mock class to find the center of rotation
pipeline = MockIncrementalPipeline(
rotated_stack_incremental, incremental_angles
)
pipeline = MockIncrementalPipeline()
center_of_rotation = pipeline.find_center_of_rotation()

return center_of_rotation
Expand Down

0 comments on commit 31a63a1

Please sign in to comment.