Skip to content

Commit

Permalink
mean of cached imaged calculated with ROI is moved around
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeSullivan7 committed Dec 3, 2024
1 parent c55541d commit 1e3f1cc
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 53 deletions.
89 changes: 38 additions & 51 deletions mantidimaging/gui/windows/live_viewer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,29 @@ class ImageCache:
cache_dict: dict = {}
image_list: list[Image_Data]
image_paths: set[str] = set()
mean: np.ndarray = np.array([])
roi: SensibleROI | None = None
param_to_calc: list[str] = []
max_cache_size: int = 100
buffer_size: int = 10
max_cache_size: int | None = None
buffer_size: int | None = None

def __init__(self):
pass
def __init__(self, max_cache_size=None, buffer_size=None):
self.max_cache_size = max_cache_size
self.buffer_size = buffer_size

def add_to_cache(self, image: Image_Data, image_array: np.ndarray):
if image.image_path not in self.cache_dict.keys():
self.cache_dict[image.image_path] = (image_array, image.image_modified_time)
if self.max_cache_size is not None:
if self.max_cache_size <= len(self.cache_dict):
self.remove_oldest_image()
self.cache_dict[image.image_path] = [image_array, image.image_modified_time]

def remove_from_cache(self, image: Image_Data):
if image.image_path in self.cache_dict.keys():
del self.cache_dict[image.image_path]

def remove_oldest_image(self):
ordered_times = sorted(self.get_cached_image_modified_times())
oldest_image_path = [path for path in self.cache_dict if self.cache_dict[path][1] == ordered_times[0]][0]
del self.cache_dict[oldest_image_path]

def load_image(self, image: Image_Data) -> np.ndarray:
if image.image_path in self.cache_dict.keys():
return self.cache_dict[image.image_path][0]
Expand All @@ -72,55 +78,20 @@ def load_image(self, image: Image_Data) -> np.ndarray:
def get_cache(self):
return self.cache_dict

def get_current_cache_size(self):
return len(self.cache_dict)

def get_cached_image_paths(self):
return list(self.cache_dict.keys())

def get_cached_image_arrays(self):
return list(self.cache_dict.values())[::, 0]
print(f"{[info[0] for info in list(self.cache_dict.values())]=}")
return np.stack([info[0] for info in list(self.cache_dict.values())])

def get_cached_image_modified_times(self):
return [info[1] for info in list(self.cache_dict.values())]


# def update_param_calculations(self) -> None:
# if 'mean' in self.param_to_calc:
# if len(self.mean) == len(self.image_list) - 1:
# self.add_last_mean()
# else:
# if self.roi:
# self.calc_mean_fully_roi()
# else:
# self.calc_mean_fully()
#
# def add_last_mean(self) -> None:
# if self.delayed_stack is not None:
# if self.roi:
# left, top, right, bottom = self.roi
# mean_to_add = dask.optimize(dask.array.mean(self.delayed_stack[-1, top:bottom,
# left:right]))[0].compute()
# else:
# mean_to_add = dask.optimize(dask.array.mean(self.delayed_stack[-1]))[0].compute()
# self.mean = np.append(self.mean, mean_to_add)
# self.calc_mean_buffer()

# def calc_mean_fully(self) -> None:
# if self.delayed_stack is not None:
# self.mean = dask.array.mean(self.delayed_stack, axis=(1, 2)).compute()
#
# def calc_mean_fully_roi(self):
# if self.delayed_stack is not None and self.image_list:
# left, top, right, bottom = self.roi
# current_cache_size = len(self.)
# self.mean = np.full(len(self.image_list), np.nan)
# np.put(self.mean, range(-current_cache_size, 0), self.calc_mean_cached_images(left, top, right, bottom))
#
# def calc_mean_cached_images(self, left, top, right, bottom):
# current_cache_size = self.get_computed_image.cache_info()[3]
# cache_stack = [
# self.get_computed_image(index)
# for index in range(self.selected_index - current_cache_size + 1, self.selected_index + 1, 1)
# ]
# cache_stack_array = np.stack(cache_stack)
# cache_stack_mean = np.mean(cache_stack_array[:, top:bottom, left:right], axis=(1, 2))
# return cache_stack_mean
#
# def calc_mean_buffer(self):
# nanInds = np.argwhere(np.isnan(self.mean))
# left, top, right, bottom = self.roi
Expand Down Expand Up @@ -235,6 +206,7 @@ def __init__(self, presenter: LiveViewerWindowPresenter):
self.mean_dict: dict[Path, float] = {}
self.roi: SensibleROI | None = None
self.image_cache = ImageCache()
self.mean_cached: np.ndarray = np.empty(0)

@property
def path(self) -> Path | None:
Expand Down Expand Up @@ -304,6 +276,21 @@ def calc_mean_fully(self) -> None:
for image in self.images:
self.add_mean(image, self.image_cache.load_image(image))

def calc_mean_cache(self) -> None:
if self.roi:
left, top, right, bottom = self.roi
self.mean_cached = np.mean(self.image_cache.get_cached_image_arrays()[:, top:bottom, left:right], axis=(1, 2))
else:
self.mean_cached = np.mean(self.image_cache.get_cached_image_arrays(), axis=(1, 2))

def update_mean_with_cached_images(self) -> None:
np.put(self.mean, range(-self.image_cache.get_current_cache_size(), 0), self.mean_cached)

def clear_and_update_mean_cache(self) -> None:
self.mean = np.full(len(self.images), np.nan)
self.calc_mean_cache()
self.update_mean_with_cached_images()


class ImageWatcher(QObject):
"""
Expand Down
6 changes: 4 additions & 2 deletions mantidimaging/gui/windows/live_viewer/presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def handle_deleted(self) -> None:

def update_image_list(self, images_list: list[Image_Data]) -> None:
"""Update the image in the view."""
# TODO: Might be a good idea to update and store the image list in the model so it can be cycled through
if not images_list:
self.handle_deleted()
self.view.set_load_as_dataset_enabled(False)
Expand Down Expand Up @@ -117,6 +116,8 @@ def display_image(self, image_data_obj: Image_Data) -> None:
self.model.set_roi(self.view.live_viewer.get_roi())
if image_data_obj.image_path not in self.model.mean_dict.keys():
self.model.add_mean(image_data_obj, image_data)
self.model.calc_mean_cache()
self.model.update_mean_with_cached_images()
self.view.show_most_recent_image(image_data)
self.update_spectrum(self.model.mean)
self.view.live_viewer.show_error(None)
Expand Down Expand Up @@ -172,5 +173,6 @@ def handle_roi_moved(self, force_new_spectrums: bool = False):
self.update_spectrum(self.model.mean)

def handle_roi_moved_start(self):
self.model.clear_mean_partial()
self.model.clear_mean()
self.model.clear_and_update_mean_cache()
self.update_spectrum(self.model.mean)

0 comments on commit 1e3f1cc

Please sign in to comment.