From 04a0b8829edf100455ad0ff3912abed47f34820f Mon Sep 17 00:00:00 2001 From: William Brannon Date: Wed, 5 Jun 2024 21:29:42 -0600 Subject: [PATCH] Modify how source IDs are filtered --- .../fusion_engine_client/analysis/analyzer.py | 20 ++++++++--- .../parsers/mixed_log_reader.py | 33 ++++++++++++------- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/python/fusion_engine_client/analysis/analyzer.py b/python/fusion_engine_client/analysis/analyzer.py index ddb7e865..a164374d 100755 --- a/python/fusion_engine_client/analysis/analyzer.py +++ b/python/fusion_engine_client/analysis/analyzer.py @@ -155,8 +155,10 @@ def __init__(self, self.source_ids = source_ids.intersection(self.reader.source_ids) # If the requested source IDs are unavailable, raise error. - if len(source_ids) == 0: - raise ValueError("Requested source ID(s) unavailable. Exiting.") + if len(self.source_ids) == 0: + self.logger.warning('Requested source IDs unavailable. Cannot extract data.') + log_reader = self.reader.get_log_reader() + log_reader.filter_in_place(None, clear_existing='source_id') if time_axis in ('relative', 'rel'): self.time_axis = 'relative' @@ -482,7 +484,11 @@ def plot_pose(self): return # Read the pose data. - result = self.reader.read(message_types=[PoseMessage], source_ids=[min(self.source_ids)], **self.params) + if len(self.source_ids) > 0: + source_id = [min(self.source_ids)] + else: + source_id = self.source_ids + result = self.reader.read(message_types=[PoseMessage], source_ids=source_id, **self.params) pose_data = result[PoseMessage.MESSAGE_TYPE] if len(pose_data.p1_time) == 0: @@ -697,7 +703,11 @@ def plot_solution_type(self): return # Read the pose data. - result = self.reader.read(message_types=[PoseMessage], source_ids=[min(self.source_ids)], **self.params) + if len(self.source_ids) > 0: + source_id = [min(self.source_ids)] + else: + source_id = self.source_ids + result = self.reader.read(message_types=[PoseMessage], source_ids=source_id, **self.params) pose_data = result[PoseMessage.MESSAGE_TYPE] if len(pose_data.p1_time) == 0: @@ -883,7 +893,7 @@ def plot_map(self, mapbox_token): """! @brief Plot a map of the position data. """ - if self.output_dir is None: + if self.output_dir is None or len(self.source_ids) == 0: return mapbox_token = self.get_mapbox_token(mapbox_token) diff --git a/python/fusion_engine_client/parsers/mixed_log_reader.py b/python/fusion_engine_client/parsers/mixed_log_reader.py index 89b11b48..9d8b45b8 100644 --- a/python/fusion_engine_client/parsers/mixed_log_reader.py +++ b/python/fusion_engine_client/parsers/mixed_log_reader.py @@ -452,7 +452,7 @@ def parse_entry_at_index(self, index: file_index.FileIndexEntry): def clear_filters(self): self.filter_in_place(key=None, clear_existing=True) - def filter_in_place(self, key, clear_existing: bool = False, + def filter_in_place(self, key, clear_existing: Union[bool, str] = False, source_ids: Union[List[int], set[int], tuple[int]] = None): """! @brief Limit the returned messages by type or time. @@ -483,15 +483,26 @@ def filter_in_place(self, key, clear_existing: bool = False, # that we just read. prev_offset_bytes = self.index.offset[self.next_index_elem - 1] - # If requested, clear previous filter criteria. - if clear_existing: - if self.index is None: + if type(clear_existing) == str: + # Verify input string and clear accordingly. + if clear_existing == 'message_type': self.message_types = copy.deepcopy(self._original_message_types) + elif clear_existing == 'time_range': self.time_range = copy.deepcopy(self._original_time_range) - self.remove_invalid_p1_time = False - self.requested_source_ids = None + elif clear_existing == 'source_id': + self.requested_source_ids = set() else: - self.index = self._original_index + raise ValueError('Invalid clear_existing flag: %s' % clear_existing) + else: + # If requested, clear previous filter criteria. + if clear_existing: + if self.index is None: + self.message_types = copy.deepcopy(self._original_message_types) + self.time_range = copy.deepcopy(self._original_time_range) + self.remove_invalid_p1_time = False + self.requested_source_ids = None + else: + self.index = self._original_index # Set requested source IDs. if source_ids is not None: @@ -504,7 +515,8 @@ def filter_in_place(self, key, clear_existing: bool = False, 'source IDs: {}'.format(unavailable_source_ids)) source_ids = list(source_ids.intersection(self.available_source_ids)) if len(source_ids) == 0: - raise ValueError("Requested source ID(s) unavailable. Exiting.") + self.logger.debug('Requested source IDs unavailable. Cannot extract data.') + self.filter_in_place(None, clear_existing='source_id') if len(unavailable_source_ids) > 0: self.logger.info('Extracting the following available requested source IDs: {}'.format(source_ids)) @@ -604,10 +616,7 @@ def _populate_available_source_ids(self, num_messages_to_read: int = 10): self.available_source_ids = set() # Loop over all message types and read N of each type. for message_type in np.unique(self.index['type']): - try: - message_type = MessageType(message_type, raise_on_unrecognized=True) - except (KeyError, ValueError) as e: - continue + message_type = MessageType(message_type, raise_on_unrecognized=False) self.filter_in_place(message_type) num_messages_read = 0 while num_messages_read < num_messages_to_read: