Skip to content

Commit

Permalink
Make FileIndex flexible for different type keys.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonathan Diamond committed Sep 28, 2023
1 parent 18c57a2 commit eb77310
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions python/fusion_engine_client/parsers/file_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np

from ..messages import MessageHeader, MessagePayload, MessageType, Timestamp
from ..utils.enum_utils import IntEnum
from ..utils.numpy_utils import find_first
from ..utils.time_range import TimeRange

Expand All @@ -16,15 +17,16 @@


class FileIndexIterator(object):
def __init__(self, np_iterator):
def __init__(self, enum_class=None, np_iterator=None):
self.enum_class = enum_class
self.np_iterator = np_iterator

def __next__(self):
if self.np_iterator is None:
raise StopIteration()
else:
entry = next(self.np_iterator)
return FileIndexEntry(time=Timestamp(entry[0]), type=MessageType(entry[1], raise_on_unrecognized=False),
return FileIndexEntry(time=Timestamp(entry[0]), type=self.enum_class(entry[1]),
offset=entry[2], message_index=entry[3])


Expand Down Expand Up @@ -124,7 +126,7 @@ class FileIndex(object):
_DTYPE = np.dtype([('time', '<f8'), ('type', '<u2'), ('offset', '<u8'), ('message_index', '<u8')])

def __init__(self, index_path: str = None, data_path: str = None, delete_on_error=True,
data: Union[np.ndarray, list] = None, t0: Timestamp = None):
data: Union[np.ndarray, list] = None, t0: Timestamp = None, enum_class = MessageType):
"""!
@brief Construct a new @ref FileIndex instance.
Expand All @@ -136,7 +138,9 @@ def __init__(self, index_path: str = None, data_path: str = None, delete_on_erro
@param data A NumPy `ndarray` or Python `list` containing information about each FusionEngine message in the
`.p1log` file. For internal use.
@param t0 The P1 time corresponding with the start of the `.p1log` file, if known. For internal use.
@param enum_class The @ref IntEnum class used to identify the messages.
"""
self.enum_class = enum_class
if data is None:
self._data = None
else:
Expand Down Expand Up @@ -191,7 +195,7 @@ def load(self, index_path, data_path=None, delete_on_error=True):
if not os.path.exists(data_path):
# If the user didn't explicitly set data_path and the default file doesn't exist, it is not considered
# an error.
if self._data['type'][-1] == MessageType.INVALID:
if self._data['type'][-1] == self.enum_class.INVALID:
self._data = self._data[:-1]
return
elif not os.path.exists(data_path):
Expand All @@ -216,7 +220,7 @@ def load(self, index_path, data_path=None, delete_on_error=True):
# Get the last entry in the index. If its message type is INVALID, it's a special marker at the end of the
# index file indicating the size of the binary data file when the index was created. If it exists, we can
# use it to check if the data file size has changed.
if self.type[-1] == MessageType.INVALID:
if self.type[-1] == self.enum_class.INVALID:
expected_data_file_size = self.offset[-1]
self._data = self._data[:-1]

Expand Down Expand Up @@ -267,9 +271,9 @@ def save(self, index_path: str, data_path: str):
if len(self._data) > 0:
# Append an EOF marker at the end of the data if data_path is specified.
data = self._data
if data['type'][-1] != MessageType.INVALID and data_path is not None:
if data['type'][-1] != self.enum_class.INVALID and data_path is not None:
file_size_bytes = os.stat(data_path).st_size
data = np.append(data, np.array((np.nan, int(MessageType.INVALID), file_size_bytes, -1),
data = np.append(data, np.array((np.nan, self.enum_class.INVALID, file_size_bytes, -1),
dtype=FileIndex._DTYPE))

raw_data = FileIndex._to_raw(data)
Expand Down Expand Up @@ -384,14 +388,14 @@ def __getitem__(self, key):
elif len(self._data) == 0:
return FileIndex()
# Return entries for a specific message type.
elif isinstance(key, MessageType):
elif isinstance(key, IntEnum):
idx = self._data['type'] == key
return FileIndex(data=self._data[idx], t0=self.t0)
elif MessagePayload.is_subclass(key):
idx = self._data['type'] == key.get_type()
return FileIndex(data=self._data[idx], t0=self.t0)
# Return entries for a list of message types.
elif isinstance(key, (set, list, tuple)) and len(key) > 0 and isinstance(next(iter(key)), MessageType):
elif isinstance(key, (set, list, tuple)) and len(key) > 0 and isinstance(next(iter(key)), IntEnum):
idx = np.isin(self._data['type'], [int(k) for k in key])
return FileIndex(data=self._data[idx], t0=self.t0)
elif isinstance(key, (set, list, tuple)) and len(key) > 0 and MessagePayload.is_subclass(next(iter(key))):
Expand Down Expand Up @@ -429,9 +433,9 @@ def __getitem__(self, key):

def __iter__(self):
if len(self._data) == 0:
return FileIndexIterator(None)
return FileIndexIterator()
else:
return FileIndexIterator(iter(self._data))
return FileIndexIterator(self.enum_class, iter(self._data))

@classmethod
def get_path(cls, data_path):
Expand Down Expand Up @@ -503,7 +507,7 @@ def from_file(self, data_path: str):
self.append(message_type=header.message_type, offset_bytes=offset_bytes, p1_time=p1_time)
return self.to_index()

def append(self, message_type: MessageType, offset_bytes: int, p1_time: Timestamp = None):
def append(self, message_type: IntEnum, offset_bytes: int, p1_time: Timestamp = None):
"""!
@brief Add an entry to the index data being accumulated.
Expand Down

0 comments on commit eb77310

Please sign in to comment.