diff --git a/mantidimaging/gui/windows/live_viewer/model.py b/mantidimaging/gui/windows/live_viewer/model.py index 92aa9620d26..0d15134a224 100644 --- a/mantidimaging/gui/windows/live_viewer/model.py +++ b/mantidimaging/gui/windows/live_viewer/model.py @@ -45,9 +45,9 @@ class ImageCache: image_list: list[Image_Data] image_paths: set[str] = set() max_cache_size: int | None = None - buffer_size: int | None = None + buffer_size: int = 10 - def __init__(self, max_cache_size=None, buffer_size=None): + def __init__(self, max_cache_size=None, buffer_size=10): self.max_cache_size = max_cache_size self.buffer_size = buffer_size @@ -62,10 +62,13 @@ def remove_from_cache(self, image: Image_Data): if image.image_path in self.cache_dict.keys(): del self.cache_dict[image.image_path] - def remove_oldest_image(self): + def get_oldest_image(self): ordered_times = sorted(self.get_cached_image_modified_times()) oldest_image_path = [path for path in self.cache_dict if self.cache_dict[path][1] == ordered_times[0]][0] - del self.cache_dict[oldest_image_path] + return oldest_image_path + + def remove_oldest_image(self): + del self.cache_dict[self.get_oldest_image()] def load_image(self, image: Image_Data) -> np.ndarray: if image.image_path in self.cache_dict.keys(): @@ -85,34 +88,12 @@ def get_cached_image_paths(self): return list(self.cache_dict.keys()) def get_cached_image_arrays(self): - print(f"{[info[0] for info in list(self.cache_dict.values())]=}") return np.stack([info[0] for info in list(self.cache_dict.values())]) def get_cached_image_modified_times(self): return [info[1] for info in list(self.cache_dict.values())] - # def calc_mean_buffer(self): - # nanInds = np.argwhere(np.isnan(self.mean)) - # left, top, right, bottom = self.roi - # if nanInds.size > 0: - # print(f"{self.mean=}") - # if nanInds.size < self.buffer_size: - # buffer_start = 0 - # else: - # buffer_start = nanInds.size - self.buffer_size - # dask_mean = dask.optimize( - # dask.array.mean(self.delayed_stack[buffer_start:nanInds.size, top:bottom, left:right], - # axis=(1, 2)))[0].compute() - # np.put(self.mean, range(buffer_start, nanInds.size), dask_mean) - - def delete_all_data(self): - pass - - def add_param_to_calc(self, param_name: str): - self.param_to_calc.append(param_name) - - class Image_Data: """ Image Data Class to store represent image data. @@ -205,7 +186,7 @@ def __init__(self, presenter: LiveViewerWindowPresenter): self.mean: np.ndarray = np.empty(0) self.mean_dict: dict[Path, float] = {} self.roi: SensibleROI | None = None - self.image_cache = ImageCache() + self.image_cache = ImageCache(max_cache_size=10) self.mean_cached: np.ndarray = np.empty(0) @property @@ -291,6 +272,23 @@ def clear_and_update_mean_cache(self) -> None: self.calc_mean_cache() self.update_mean_with_cached_images() + def calc_mean_buffer(self): + nanInds = np.argwhere(np.isnan(self.mean)) + if self.roi: + left, top, right, bottom = self.roi + else: + left, top, right, bottom = (0, 0, -1, -1) + if nanInds.size > 0: + oldest_image_modified_time = self.image_cache.cache_dict[self.image_cache.get_oldest_image()][1] + all_modified_times = [image.image_modified_time for image in self.images] + norm_mod_times = [mod_time - oldest_image_modified_time for mod_time in all_modified_times] + oldest_image_index = norm_mod_times.index(0.0) + for ind in range(len(nanInds) - 1, len(nanInds) - 1 - self.image_cache.buffer_size, -1): + if ind < 0: + break + buffer_mean = np.mean(load_image_from_path(self.images[ind].image_path)[top:bottom, left:right]) + np.put(self.mean, ind, buffer_mean) + class ImageWatcher(QObject): """ diff --git a/mantidimaging/gui/windows/live_viewer/presenter.py b/mantidimaging/gui/windows/live_viewer/presenter.py index cb91b6eae83..d5a907c8d5d 100644 --- a/mantidimaging/gui/windows/live_viewer/presenter.py +++ b/mantidimaging/gui/windows/live_viewer/presenter.py @@ -118,6 +118,7 @@ def display_image(self, image_data_obj: Image_Data) -> None: self.model.add_mean(image_data_obj, image_data) self.model.calc_mean_cache() self.model.update_mean_with_cached_images() + self.model.calc_mean_buffer() self.view.show_most_recent_image(image_data) self.update_spectrum(self.model.mean) self.view.live_viewer.show_error(None) @@ -169,7 +170,7 @@ def handle_roi_moved(self, force_new_spectrums: bool = False): roi = self.view.live_viewer.get_roi() self.model.set_roi(roi) self.model.clear_mean() - self.model.calc_mean_fully() + self.model.clear_and_update_mean_cache() self.update_spectrum(self.model.mean) def handle_roi_moved_start(self):