Skip to content

Commit

Permalink
Resolved additional source ID handling issues. (#320)
Browse files Browse the repository at this point in the history
# Changes
- Replaced `DataLoader.source_ids` with `get_available_source_ids()` helper function
- Allow integer values in `source_ids` arguments for convenience

# Fixes
- Fixed overlap of data with multiple sources in certain plots
  • Loading branch information
adamshapiro0 authored Jun 6, 2024
2 parents 935ebea + 73d7d52 commit a086a52
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 36 deletions.
37 changes: 16 additions & 21 deletions python/fusion_engine_client/analysis/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,20 +145,23 @@ def __init__(self,
# If source ID was unspecified, use _all_ source IDs found in the log. If source ID _was_ specified, use the
# intersection of the requested source ID(s) and the available source IDs.
if source_id is None:
self.source_ids = self.reader.source_ids
self.source_ids = self.reader.get_available_source_ids()
else:
source_ids = set(source_id)
unavailable_source_ids = source_ids.difference(self.reader.source_ids)
unavailable_source_ids = source_ids.difference(self.reader.get_available_source_ids())
if len(unavailable_source_ids) > 0:
self.logger.warning('Not all source IDs requested are available. Cannot extract the following '
'source IDs: {}'.format(unavailable_source_ids))

self.source_ids = source_ids.intersection(self.reader.source_ids)
# If the requested source IDs are unavailable, raise error.
self.source_ids = source_ids.intersection(self.reader.get_available_source_ids())
# If the requested pose source IDs are unavailable, warn.
if len(self.source_ids) == 0:
self.logger.warning('Requested source IDs unavailable. Cannot extract data.')
log_reader = self.reader.get_log_reader()
log_reader.filter_in_place(None, clear_existing='source_id')
self.logger.warning('Requested source IDs unavailable. Cannot extract pose data.')

if len(self.source_ids) > 0:
self.default_source_id = min(self.source_ids)
else:
self.default_source_id = 0

if time_axis in ('relative', 'rel'):
self.time_axis = 'relative'
Expand Down Expand Up @@ -484,11 +487,7 @@ def plot_pose(self):
return

# Read the pose data.
if len(self.source_ids) > 0:
source_id = [min(self.source_ids)]
else:
source_id = self.source_ids
result = self.reader.read(message_types=[PoseMessage], source_ids=source_id, **self.params)
result = self.reader.read(message_types=[PoseMessage], source_ids=self.default_source_id, **self.params)
pose_data = result[PoseMessage.MESSAGE_TYPE]

if len(pose_data.p1_time) == 0:
Expand Down Expand Up @@ -703,11 +702,7 @@ def plot_solution_type(self):
return

# Read the pose data.
if len(self.source_ids) > 0:
source_id = [min(self.source_ids)]
else:
source_id = self.source_ids
result = self.reader.read(message_types=[PoseMessage], source_ids=source_id, **self.params)
result = self.reader.read(message_types=[PoseMessage], source_ids=self.default_source_id, **self.params)
pose_data = result[PoseMessage.MESSAGE_TYPE]

if len(pose_data.p1_time) == 0:
Expand Down Expand Up @@ -831,7 +826,7 @@ def plot_pose_displacement(self):
return

# Read the pose data.
result = self.reader.read(message_types=[PoseMessage], source_ids=self.source_ids, **self.params)
result = self.reader.read(message_types=[PoseMessage], source_ids=self.default_source_id, **self.params)
pose_data = result[PoseMessage.MESSAGE_TYPE]

if len(pose_data.p1_time) == 0:
Expand Down Expand Up @@ -1567,7 +1562,7 @@ def _get_time_source(meas_type, data):
# If we have pose messages _and_ they contain body velocity, we can use that.
#
# Note that we are using this to compare vs wheel speeds, so we're only interested in forward speed here.
result = self.reader.read(message_types=[PoseMessage], source_ids=self.source_ids, **self.params)
result = self.reader.read(message_types=[PoseMessage], source_ids=self.default_source_id, **self.params)
pose_data = result[PoseMessage.MESSAGE_TYPE]
if len(pose_data.p1_time) != 0 and np.any(~np.isnan(pose_data.velocity_body_mps[0, :])):
nav_engine_p1_time = pose_data.p1_time
Expand Down Expand Up @@ -1821,7 +1816,7 @@ def plot_heading_measurements(self):

# Note that we read the pose data after heading, that way we don't bother reading pose data from disk if there's
# no heading data in the log.
result = self.reader.read(message_types=[PoseMessage], source_ids=self.source_ids, **self.params)
result = self.reader.read(message_types=[PoseMessage], source_ids=self.default_source_id, **self.params)
primary_pose_data = result[PoseMessage.MESSAGE_TYPE]

# Setup the figure.
Expand Down Expand Up @@ -2250,7 +2245,7 @@ def _set_data_summary(self):
log_duration_sec, processing_duration_sec, reduced_index = self._calculate_duration(return_index=True)

# Create a table with position solution type statistics.
result = self.reader.read(message_types=[PoseMessage], source_ids=self.source_ids, **self.params)
result = self.reader.read(message_types=[PoseMessage], source_ids=self.default_source_id, **self.params)
pose_data = result[PoseMessage.MESSAGE_TYPE]
num_pose_messages = len(pose_data.solution_type)
solution_type_count = {}
Expand Down
18 changes: 10 additions & 8 deletions python/fusion_engine_client/analysis/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from enum import Enum, auto
from typing import Dict, Iterable, Union
from typing import Dict, Iterable, Set, Union

from collections import deque
from datetime import datetime, timezone
from enum import Enum, auto

from gpstime import gpstime, unix2gps
import numpy as np
Expand Down Expand Up @@ -176,9 +176,6 @@ def open(self, path, save_index=True, ignore_index=False):
self.reader = MixedLogReader(input_file=path, save_index=save_index, ignore_index=ignore_index,
return_bytes=True, return_message_index=True)

# By default, use all available source IDs. Filtering by source ID may be done with the read() function.
self.source_ids = self.reader.available_source_ids

# Read the first message (with P1 time) in the file to set self.t0.
#
# Note that we explicitly set a start time since, if the time range is not specified, read() will include
Expand Down Expand Up @@ -237,8 +234,8 @@ def read(self, *args, **kwargs) \
be returned. If `None` or an empty list, read all available messages.
@param time_range An optional @ref TimeRange object specifying desired start and end time bounds of the data to
be read. See @ref TimeRange for more details.
@param source_ids An optional list of one or more source identifiers to be returned. If `None` or an empty list,
use all available source identifiers.
@param source_ids An optional list message source identifiers to be returned. If `None`, read messages from
available source identifiers.
@param show_progress If `True`, print the read progress every 10 MB (useful for large files).
@param ignore_cache If `True`, ignore any cached data from a previous @ref read() call, and reload the requested
Expand Down Expand Up @@ -307,7 +304,9 @@ def _read(self,

if source_ids is None:
source_ids = self.reader.get_available_source_ids()
if source_ids is not None:
elif isinstance(source_ids, int):
source_ids = {source_ids}
else:
source_ids = set(source_ids)

# Store the set of parameters used to perform this read along with the cache data. When doing reads for the
Expand Down Expand Up @@ -635,6 +634,9 @@ def get_log_reader(self) -> MixedLogReader:
def get_input_path(self):
return self.reader.input_file.name

def get_available_source_ids(self) -> Set[int]:
return self.reader.get_available_source_ids()

def _convert_time(self, conversion_type: TimeConversionType,
times: Union[Iterable[Union[datetime, gpstime, Timestamp, float]],
Union[datetime, gpstime, Timestamp, float]],
Expand Down
21 changes: 14 additions & 7 deletions python/fusion_engine_client/parsers/mixed_log_reader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, List, Union, Optional
from typing import Iterable, List, Set, Union, Optional

import copy
from datetime import datetime
Expand Down Expand Up @@ -46,9 +46,8 @@ def __init__(self, input_file, warn_on_gaps: bool = False, show_progress: bool =
be read. See @ref TimeRange for more details.
@param message_types A list of one or more @ref fusion_engine_client.messages.defs.MessageType "MessageTypes" to
be returned. If `None` or an empty list, read all available messages.
@param source_ids An optional list of one or more source identifiers to be used when extracting @ref PoseMessage
messages. If `None`, use all available source identifiers. If an empty list, use no @ref PoseMessage
messages.
@param source_ids An optional list message source identifiers to be returned. If `None`, read messages from
available source identifiers.
@param return_header If `True`, return the decoded @ref MessageHeader for each message.
@param return_payload If `True`, parse and return the payload for each message as a subclass of @ref
MessagePayload. Will return `None` if the payload cannot be parsed.
Expand Down Expand Up @@ -83,6 +82,8 @@ def __init__(self, input_file, warn_on_gaps: bool = False, show_progress: bool =
# The source IDs requested by the user. If none were requested, then use all of them.
if source_ids is None:
self.requested_source_ids = None
elif isinstance(source_ids, int):
self.requested_source_ids = {source_ids}
else:
self.requested_source_ids = set(source_ids)
# The source IDs that are available in the log. This will be populated below when
Expand Down Expand Up @@ -455,7 +456,7 @@ def clear_filters(self):
self.filter_in_place(key=None, clear_existing=True)

def filter_in_place(self, key, clear_existing: Union[bool, str] = False,
source_ids: Iterable[int] = None):
source_ids: Optional[Iterable[int]] = None):
"""!
@brief Limit the returned messages by type or time.
Expand All @@ -467,7 +468,13 @@ def filter_in_place(self, key, clear_existing: Union[bool, str] = False,
- An iterable listing one or more @ref MessageType%s to be returned
- A `slice` specifying the start/end of the desired absolute (P1) or relative time range
- A @ref TimeRange object
@param clear_existing If `True`, clear any previous filter criteria.
@param clear_existing One of the following:
- `True` - Clear any previous filter criteria
- `False` - Add the new criteria to existing filters
- `'message_type'` - Clear previous message type filtering
- `'time_range'` - Clear previous time range filtering
- `'source_id'` - Clear previous source identifier filtering
@param source_ids If set, limit results to messages using one of the specified source identifier values.
@return A reference to this class.
"""
Expand Down Expand Up @@ -611,7 +618,7 @@ def filter_out_invalid_p1_times(self, clear_existing: bool = False):

return self

def get_available_source_ids(self):
def get_available_source_ids(self) -> Set[int]:
return self.available_source_ids

def _populate_available_source_ids(self, num_messages_to_read: int = 10):
Expand Down

0 comments on commit a086a52

Please sign in to comment.