diff --git a/python/fusion_engine_client/analysis/pose_compare.py b/python/fusion_engine_client/analysis/pose_compare.py index d8af45c6..a84f2011 100755 --- a/python/fusion_engine_client/analysis/pose_compare.py +++ b/python/fusion_engine_client/analysis/pose_compare.py @@ -38,11 +38,6 @@ SolutionTypeInfo = namedtuple('SolutionTypeInfo', ['name', 'style']) -_LOG_NAMES = [ - 'Test', - 'Reference', -] - _SOLUTION_TYPE_MAP = [{ SolutionType.Invalid: SolutionTypeInfo(name='Invalid', style={'color': 'black'}), SolutionType.Integrate: SolutionTypeInfo(name='Integrated', style={'color': 'cyan'}), @@ -122,7 +117,7 @@ def __init__(self, file_test: Union[DataLoader, str], file_reference: Union[DataLoader, str], ignore_index: bool = False, output_dir: str = None, prefix: str = '', time_range: TimeRange = None, max_messages: int = None, - time_axis: str = 'relative'): + time_axis: str = 'relative', test_device_name=None, reference_device_name=None): """! @brief Create an analyzer for the comparing the pose from two p1log files. @@ -139,6 +134,8 @@ def __init__(self, @param time_axis Specify the way in which time will be plotted: - `absolute`, `abs` - Absolute P1 or system timestamps - `relative`, `rel` - Elapsed time since the start of the log + @param test_device_name If set, label test data with this device name. + @param reference_device_name If set, label reference data with this device name. """ self.params = { 'time_range': time_range, @@ -166,10 +163,22 @@ def __init__(self, if len(self.reference_pose.p1_time) == 0: raise ValueError('Reference log did not contain pose data.') + self.test_data_name = 'Test' + if test_device_name: + self.test_data_name += ' ' + test_device_name + + self.reference_data_name = 'Reference' + if reference_device_name: + self.reference_data_name += ' ' + reference_device_name + + self.data_names = [ + self.test_data_name, + self.reference_data_name, + ] + self.output_dir = output_dir self.prefix = prefix - gps_time_test = self.test_pose.gps_time valid_gps_time = gps_time_test[np.isfinite(gps_time_test)] if len(valid_gps_time) == 0: @@ -226,7 +235,7 @@ def plot_solution_type(self): ticktext=['%s (%d)' % (e.name, e.value) for e in SolutionType], tickvals=[e.value for e in SolutionType]) - for name, pose_data in zip(_LOG_NAMES, (self.test_pose, self.reference_pose)): + for name, pose_data in zip(self.data_names, (self.test_pose, self.reference_pose)): time = pose_data.gps_time - float(self.t0) text = ["Time: %.3f sec (%.3f sec)" % (t, t + float(self.t0)) for t in time] @@ -256,7 +265,7 @@ def plot_map(self, mapbox_token): map_data = [] for i, pose_data in enumerate((self.test_pose, self.reference_pose)): - log_name = _LOG_NAMES[i] + log_name = self.data_names[i] # Remove invalid solutions. valid_idx = np.logical_and(~np.isnan(pose_data.p1_time), pose_data.solution_type != SolutionType.Invalid) @@ -326,7 +335,8 @@ def _plot_pose_error(self, time, solution_type, error_enu_m, std_enu_m): time_figure = make_subplots(rows=4, cols=1, print_grid=False, shared_xaxes=True, subplot_titles=['3D', 'East', 'North', 'Up']) - time_figure['layout'].update(showlegend=True, modebar_add=['v1hovermode']) + time_figure['layout'].update(showlegend=True, modebar_add=['v1hovermode'], + title=f'{self.test_data_name} Vs. {self.reference_data_name}') for i in range(4): time_figure['layout']['xaxis%d' % (i + 1)].update(title=self.gps_time_label, showticklabels=True) time_figure['layout']['yaxis1'].update(title="Error (m)") @@ -515,13 +525,13 @@ def _set_data_summary(self): log_b_end = log_b_p1_time[-1] data = [[], []] - data[0].append('Test Log Start') + data[0].append(self.test_data_name + ' Log Start') data[1].append(log_a_t0) - data[0].append('Reference Log Start') + data[0].append(self.reference_data_name + ' Log Start') data[1].append(log_b_t0) - data[0].append('Test Log Duration') + data[0].append(self.test_data_name + ' Log Duration') data[1].append(log_a_end - log_a_t0) - data[0].append('Reference Log Duration') + data[0].append(self.reference_data_name + ' Log Duration') data[1].append(log_b_end - log_b_t0) data[0].append('Matched GPS Epochs') data[1].append(len(matched_p1_time_a)) @@ -550,6 +560,8 @@ def get_stats(label, values): error_table = _data_to_table(columns, data, row_major=True, round_decimal_places=2) self.summary += f""" +