From eb77310ea29172a63be8710e9fc12d6c8c2fb829 Mon Sep 17 00:00:00 2001 From: Jonathan Diamond Date: Thu, 28 Sep 2023 15:40:37 -0700 Subject: [PATCH] Make FileIndex flexible for different type keys. --- .../parsers/file_index.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/python/fusion_engine_client/parsers/file_index.py b/python/fusion_engine_client/parsers/file_index.py index ef33dc43..58519e81 100644 --- a/python/fusion_engine_client/parsers/file_index.py +++ b/python/fusion_engine_client/parsers/file_index.py @@ -8,6 +8,7 @@ import numpy as np from ..messages import MessageHeader, MessagePayload, MessageType, Timestamp +from ..utils.enum_utils import IntEnum from ..utils.numpy_utils import find_first from ..utils.time_range import TimeRange @@ -16,7 +17,8 @@ class FileIndexIterator(object): - def __init__(self, np_iterator): + def __init__(self, enum_class=None, np_iterator=None): + self.enum_class = enum_class self.np_iterator = np_iterator def __next__(self): @@ -24,7 +26,7 @@ def __next__(self): raise StopIteration() else: entry = next(self.np_iterator) - return FileIndexEntry(time=Timestamp(entry[0]), type=MessageType(entry[1], raise_on_unrecognized=False), + return FileIndexEntry(time=Timestamp(entry[0]), type=self.enum_class(entry[1]), offset=entry[2], message_index=entry[3]) @@ -124,7 +126,7 @@ class FileIndex(object): _DTYPE = np.dtype([('time', ' 0: # Append an EOF marker at the end of the data if data_path is specified. data = self._data - if data['type'][-1] != MessageType.INVALID and data_path is not None: + if data['type'][-1] != self.enum_class.INVALID and data_path is not None: file_size_bytes = os.stat(data_path).st_size - data = np.append(data, np.array((np.nan, int(MessageType.INVALID), file_size_bytes, -1), + data = np.append(data, np.array((np.nan, self.enum_class.INVALID, file_size_bytes, -1), dtype=FileIndex._DTYPE)) raw_data = FileIndex._to_raw(data) @@ -384,14 +388,14 @@ def __getitem__(self, key): elif len(self._data) == 0: return FileIndex() # Return entries for a specific message type. - elif isinstance(key, MessageType): + elif isinstance(key, IntEnum): idx = self._data['type'] == key return FileIndex(data=self._data[idx], t0=self.t0) elif MessagePayload.is_subclass(key): idx = self._data['type'] == key.get_type() return FileIndex(data=self._data[idx], t0=self.t0) # Return entries for a list of message types. - elif isinstance(key, (set, list, tuple)) and len(key) > 0 and isinstance(next(iter(key)), MessageType): + elif isinstance(key, (set, list, tuple)) and len(key) > 0 and isinstance(next(iter(key)), IntEnum): idx = np.isin(self._data['type'], [int(k) for k in key]) return FileIndex(data=self._data[idx], t0=self.t0) elif isinstance(key, (set, list, tuple)) and len(key) > 0 and MessagePayload.is_subclass(next(iter(key))): @@ -429,9 +433,9 @@ def __getitem__(self, key): def __iter__(self): if len(self._data) == 0: - return FileIndexIterator(None) + return FileIndexIterator() else: - return FileIndexIterator(iter(self._data)) + return FileIndexIterator(self.enum_class, iter(self._data)) @classmethod def get_path(cls, data_path): @@ -503,7 +507,7 @@ def from_file(self, data_path: str): self.append(message_type=header.message_type, offset_bytes=offset_bytes, p1_time=p1_time) return self.to_index() - def append(self, message_type: MessageType, offset_bytes: int, p1_time: Timestamp = None): + def append(self, message_type: IntEnum, offset_bytes: int, p1_time: Timestamp = None): """! @brief Add an entry to the index data being accumulated.