Skip to content

Commit

Permalink
Modify how source IDs are filtered
Browse files Browse the repository at this point in the history
  • Loading branch information
wbrannon committed Jun 6, 2024
1 parent b0b692b commit 04a0b88
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 17 deletions.
20 changes: 15 additions & 5 deletions python/fusion_engine_client/analysis/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,10 @@ def __init__(self,

self.source_ids = source_ids.intersection(self.reader.source_ids)
# If the requested source IDs are unavailable, raise error.
if len(source_ids) == 0:
raise ValueError("Requested source ID(s) unavailable. Exiting.")
if len(self.source_ids) == 0:
self.logger.warning('Requested source IDs unavailable. Cannot extract data.')
log_reader = self.reader.get_log_reader()
log_reader.filter_in_place(None, clear_existing='source_id')

if time_axis in ('relative', 'rel'):
self.time_axis = 'relative'
Expand Down Expand Up @@ -482,7 +484,11 @@ def plot_pose(self):
return

# Read the pose data.
result = self.reader.read(message_types=[PoseMessage], source_ids=[min(self.source_ids)], **self.params)
if len(self.source_ids) > 0:
source_id = [min(self.source_ids)]
else:
source_id = self.source_ids
result = self.reader.read(message_types=[PoseMessage], source_ids=source_id, **self.params)
pose_data = result[PoseMessage.MESSAGE_TYPE]

if len(pose_data.p1_time) == 0:
Expand Down Expand Up @@ -697,7 +703,11 @@ def plot_solution_type(self):
return

# Read the pose data.
result = self.reader.read(message_types=[PoseMessage], source_ids=[min(self.source_ids)], **self.params)
if len(self.source_ids) > 0:
source_id = [min(self.source_ids)]
else:
source_id = self.source_ids
result = self.reader.read(message_types=[PoseMessage], source_ids=source_id, **self.params)
pose_data = result[PoseMessage.MESSAGE_TYPE]

if len(pose_data.p1_time) == 0:
Expand Down Expand Up @@ -883,7 +893,7 @@ def plot_map(self, mapbox_token):
"""!
@brief Plot a map of the position data.
"""
if self.output_dir is None:
if self.output_dir is None or len(self.source_ids) == 0:
return

mapbox_token = self.get_mapbox_token(mapbox_token)
Expand Down
33 changes: 21 additions & 12 deletions python/fusion_engine_client/parsers/mixed_log_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def parse_entry_at_index(self, index: file_index.FileIndexEntry):
def clear_filters(self):
self.filter_in_place(key=None, clear_existing=True)

def filter_in_place(self, key, clear_existing: bool = False,
def filter_in_place(self, key, clear_existing: Union[bool, str] = False,
source_ids: Union[List[int], set[int], tuple[int]] = None):
"""!
@brief Limit the returned messages by type or time.
Expand Down Expand Up @@ -483,15 +483,26 @@ def filter_in_place(self, key, clear_existing: bool = False,
# that we just read.
prev_offset_bytes = self.index.offset[self.next_index_elem - 1]

# If requested, clear previous filter criteria.
if clear_existing:
if self.index is None:
if type(clear_existing) == str:
# Verify input string and clear accordingly.
if clear_existing == 'message_type':
self.message_types = copy.deepcopy(self._original_message_types)
elif clear_existing == 'time_range':
self.time_range = copy.deepcopy(self._original_time_range)
self.remove_invalid_p1_time = False
self.requested_source_ids = None
elif clear_existing == 'source_id':
self.requested_source_ids = set()
else:
self.index = self._original_index
raise ValueError('Invalid clear_existing flag: %s' % clear_existing)
else:
# If requested, clear previous filter criteria.
if clear_existing:
if self.index is None:
self.message_types = copy.deepcopy(self._original_message_types)
self.time_range = copy.deepcopy(self._original_time_range)
self.remove_invalid_p1_time = False
self.requested_source_ids = None
else:
self.index = self._original_index

# Set requested source IDs.
if source_ids is not None:
Expand All @@ -504,7 +515,8 @@ def filter_in_place(self, key, clear_existing: bool = False,
'source IDs: {}'.format(unavailable_source_ids))
source_ids = list(source_ids.intersection(self.available_source_ids))
if len(source_ids) == 0:
raise ValueError("Requested source ID(s) unavailable. Exiting.")
self.logger.debug('Requested source IDs unavailable. Cannot extract data.')
self.filter_in_place(None, clear_existing='source_id')

if len(unavailable_source_ids) > 0:
self.logger.info('Extracting the following available requested source IDs: {}'.format(source_ids))
Expand Down Expand Up @@ -604,10 +616,7 @@ def _populate_available_source_ids(self, num_messages_to_read: int = 10):
self.available_source_ids = set()
# Loop over all message types and read N of each type.
for message_type in np.unique(self.index['type']):
try:
message_type = MessageType(message_type, raise_on_unrecognized=True)
except (KeyError, ValueError) as e:
continue
message_type = MessageType(message_type, raise_on_unrecognized=False)
self.filter_in_place(message_type)
num_messages_read = 0
while num_messages_read < num_messages_to_read:
Expand Down

0 comments on commit 04a0b88

Please sign in to comment.