Skip to content

Commit

Permalink
Add New Data Types for Simularium Frame Data (#157)
Browse files Browse the repository at this point in the history
* add new data objects for simularium file data for octopus

* reorganize small data classes to be in one file

* fix end index

* update comments, return types

* add checks to make sure requested frame number is valid

* add logs for debugging

* initialize frame_metadata to [] before appending to it

* initialize frame metadata and block indices to empty

* code cleanup

* formatting fixes

* update comments

* remove unused imports

* add some more test

* add __eq__ to TrajectoryData for testing purposes

* reorganize test to improve readibility

* lint fixes

* use existing test data files instead of bringing in new ones

* remove file_name as a parameter

* remove file_name as a parameter in test

* move data classes that are only used by BinaryData to that file

* use the defined constants instead of the values for indexing in binary data

* Compute get_n_agents separately instead of building a new AgentData object

* remove unused import

* use BinaryBlockInfo to track block indices, rather than using a new data type

* remove unused import

* add abstract method decorators

* lint fixes
  • Loading branch information
ascibisz authored Oct 30, 2023
1 parent cdc9339 commit 5d482f1
Show file tree
Hide file tree
Showing 7 changed files with 461 additions and 0 deletions.
4 changes: 4 additions & 0 deletions simulariumio/data_objects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@
from .model_meta_data import ModelMetaData # noqa: F401
from .histogram_plot_data import HistogramPlotData # noqa: F401
from .scatter_plot_data import ScatterPlotData # noqa: F401
from .json_data import JsonData # noqa: F401
from .binary_data import BinaryData # noqa: F401
from .simularium_file_data import SimulariumFileData # noqa: F401
from .frame_data import FrameData # noqa: F401
184 changes: 184 additions & 0 deletions simulariumio/data_objects/binary_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from typing import Dict, List, Tuple
import numpy as np

from .frame_data import FrameData
from .input_file_data import InputFileData
from .trajectory_data import TrajectoryData
from .simularium_file_data import SimulariumFileData
from ..constants import BINARY_BLOCK_TYPE, BINARY_SETTINGS
from ..readers import BinaryBlockInfo, SimulariumBinaryReader


class BinaryData(SimulariumFileData):
def __init__(self, file_contents: bytes):
"""
This object holds binary encoded simulation trajectory file's
data while staying close to the original file format
Parameters
----------
file_contents : bytes
A byte array containing the data of an open .simularium file
"""
self.file_contents = InputFileData(file_contents=file_contents)
self.file_data = SimulariumBinaryReader._binary_data_from_source(
self.file_contents
)
self.frame_metadata: List[FrameMetadata] = []
self.block_info: BinaryBlockInfo = None
# Maps block type id to block index
self.block_indices: Dict[int, int] = {}
self._parse_file()

def _parse_file(self):
# Read offset and length for each data block
self.block_info = SimulariumBinaryReader._parse_binary_header(
self.file_data.byte_view
)
for block_index in range(self.block_info.n_blocks):
block_type_id = SimulariumBinaryReader._binary_block_type(
block_index, self.block_info, self.file_data.int_view
)
self.block_indices[block_type_id] = block_index

# Extract each frame's metadata
spatial_block_index = self.block_indices[
BINARY_BLOCK_TYPE.SPATIAL_DATA_BINARY.value
]
block_offset = self.block_info.block_offsets[spatial_block_index]
spatial_block_offset = (
int(block_offset / BINARY_SETTINGS.BYTES_PER_VALUE)
+ BINARY_SETTINGS.BLOCK_HEADER_N_VALUES
)
n_frames = self.file_data.int_view[spatial_block_offset + 1]
current_frame_offset = (
spatial_block_offset
+ BINARY_SETTINGS.SPATIAL_BLOCK_HEADER_CONSTANT_N_VALUES
+ BINARY_SETTINGS.SPATIAL_BLOCK_HEADER_N_VALUES_PER_FRAME * n_frames
)
for i in range(n_frames):
offset = (
self.file_data.int_view[
spatial_block_offset
+ BINARY_SETTINGS.SPATIAL_BLOCK_HEADER_CONSTANT_N_VALUES
+ BINARY_SETTINGS.SPATIAL_BLOCK_HEADER_N_VALUES_PER_FRAME * i
]
+ block_offset
)
length = self.file_data.int_view[
spatial_block_offset
+ BINARY_SETTINGS.SPATIAL_BLOCK_HEADER_CONSTANT_N_VALUES
+ BINARY_SETTINGS.SPATIAL_BLOCK_HEADER_N_VALUES_PER_FRAME * i
+ 1
]
frame_number = self.file_data.int_view[current_frame_offset]
time = self.file_data.float_view[current_frame_offset + 1]
self.frame_metadata.append(
FrameMetadata(offset, length, frame_number, time)
)
current_frame_offset += int(length / BINARY_SETTINGS.BYTES_PER_VALUE)

def get_frame_at_index(self, frame_number: int) -> FrameData:
"""
Return frame data for frame at index. If there is no frame at the index,
return None.
"""
if frame_number < 0 or frame_number >= len(self.frame_metadata):
# invalid frame number requested
return None

metadata: FrameMetadata = self.frame_metadata[frame_number]
start, end = metadata.get_start_end_indices()
data = self.file_data.byte_view[start:end]
return FrameData(
frame_number=frame_number,
n_agents=self.file_data.int_view[
int(start / BINARY_SETTINGS.BYTES_PER_VALUE) + 2
],
time=metadata.time,
data=data,
)

def get_index_for_time(self, time: float) -> int:
"""
Return index for frame closest to a given timestamp
"""
closest_frame = -1
min_dist = np.inf
for frame in self.frame_metadata:
dist = abs(frame.time - time)
if dist < min_dist:
min_dist = dist
closest_frame = frame.frame_number
else:
# if dist is increasing, we've passed the closest frame
break

# frame index must be <= self.get_num_frames() - 1
return min(closest_frame, self.get_num_frames() - 1)

def get_trajectory_info(self) -> Dict:
"""
Return trajectory info block for trajectory, as dict
"""
block_index = self.block_indices[BINARY_BLOCK_TYPE.TRAJ_INFO_JSON.value]
return SimulariumBinaryReader._binary_block_json(
block_index, self.block_info, self.file_data.byte_view
)

def get_plot_data(self) -> Dict:
"""
Return plot data block for trajectory, as dict
"""
block_index = self.block_indices[BINARY_BLOCK_TYPE.PLOT_DATA_JSON.value]
return SimulariumBinaryReader._binary_block_json(
block_index, self.block_info, self.file_data.byte_view
)

def get_trajectory_data_object(self) -> TrajectoryData:
"""
Return the data of the trajectory, as a TrajectoryData object
"""
trajectory_dict = SimulariumBinaryReader.load_binary(self.file_contents)
return TrajectoryData.from_buffer_data(trajectory_dict)

def get_file_contents(self) -> bytes:
"""
Return raw file data, as bytes
"""
return self.file_contents.get_contents()

def get_num_frames(self) -> int:
"""
Return number of frames in the trajectory
"""
return len(self.frame_metadata)


class FrameMetadata:
def __init__(self, offset: int, length: int, frame_number: int, time: float):
"""
This object holds metadata for a single frame of simularium data
Parameters
----------
offset : int
Number of bytes the block is offset from the start of the byte array
length : int
Number of bytes in the block
frame_number : int
Index of frame in the simulation
time : float
Elapsed simulation time of the frame
"""
self.offset = offset
self.length = length
self.frame_number = frame_number
self.time = time

def get_start_end_indices(self) -> Tuple[int, int]:
"""
Return the start and end indicies for the data block
"""
end = self.offset + self.length
return self.offset, end
26 changes: 26 additions & 0 deletions simulariumio/data_objects/frame_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Dict, Union


class FrameData:
def __init__(
self, frame_number: int, n_agents: int, time: float, data: Union[bytes, Dict]
):
"""
This object holds frame data for a single frame of simularium data
Parameters
----------
frame_number : int
Index of frame in the simulation
n_agents : int
Number of agents included in the frame
time : float
Elapsed simulation time of the frame
data : bytes or dict
Spatial data for the frame, as a byte array for binary encoded
.simularium files or as a dict for JSON .simularium files
"""
self.frame_number = frame_number
self.n_agents = n_agents
self.time = time
self.data = data
108 changes: 108 additions & 0 deletions simulariumio/data_objects/json_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import Dict, List
import json
import numpy as np

from .frame_data import FrameData
from .simularium_file_data import SimulariumFileData
from .trajectory_data import TrajectoryData
from ..constants import V1_SPATIAL_BUFFER_STRUCT


class JsonData(SimulariumFileData):
def __init__(self, file_contents: str):
"""
This object holds JSON encoded simulation trajectory file's
data while staying close to the original file format
Parameters
----------
file_contents : str
A string of the data of an open .simularium file
"""
self.data = json.loads(file_contents)
self.n_agents = JsonData._get_n_agents(self.data)

def _get_n_agents(data: Dict) -> List[int]:
# return number of agents in each timestamp as a list
n_agents = []
bundle_data = data["spatialData"]["bundleData"]
for time_index in range(data["trajectoryInfo"]["totalSteps"]):
frame_data = bundle_data[time_index]["data"]
agent_index = 0
buffer_index = 0
while buffer_index + V1_SPATIAL_BUFFER_STRUCT.NSP_INDEX < len(frame_data):
n_subpoints = int(
frame_data[buffer_index + V1_SPATIAL_BUFFER_STRUCT.NSP_INDEX]
)
# length of one agents spatial data = SP_INDEX + n_subpoints
buffer_index += n_subpoints + V1_SPATIAL_BUFFER_STRUCT.SP_INDEX
agent_index += 1
n_agents.append(agent_index)
return n_agents

def get_frame_at_index(self, frame_number: int) -> FrameData:
"""
Return frame data for frame at index. If there is no frame at the index,
return None.
"""
if frame_number < 0 or frame_number >= len(
self.data["spatialData"]["bundleData"]
):
# invalid frame number requested
return None

frame_data = self.data["spatialData"]["bundleData"][frame_number]
return FrameData(
frame_number=frame_number,
n_agents=self.n_agents[frame_number],
time=frame_data["time"],
data=frame_data["data"],
)

def get_index_for_time(self, time: float) -> int:
"""
Return index for frame closest to a given timestamp
"""
closest_frame = -1
min_dist = np.inf
for frame in self.data["spatialData"]["bundleData"]:
dist = abs(frame["time"] - time)
if dist < min_dist:
min_dist = dist
closest_frame = frame["frameNumber"]
else:
# if dist is increasing, we've passed the closest frame
break

# frame index must be <= self.get_num_frames() - 1
return min(closest_frame, self.get_num_frames() - 1)

def get_trajectory_info(self) -> Dict:
"""
Return trajectory info block for trajectory, as dict
"""
return self.data["trajectoryInfo"]

def get_plot_data(self) -> Dict:
"""
Return plot data block for trajectory, as dict
"""
return self.data["plotData"]

def get_trajectory_data_object(self) -> TrajectoryData:
"""
Return the data of the trajectory, as a TrajectoryData object
"""
return TrajectoryData.from_buffer_data(self.data)

def get_file_contents(self) -> Dict:
"""
Return raw file data, as a dict
"""
return self.data

def get_num_frames(self) -> int:
"""
Return number of frames in the trajectory
"""
return len(self.data["spatialData"]["bundleData"])
37 changes: 37 additions & 0 deletions simulariumio/data_objects/simularium_file_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Dict, Union
from abc import ABC, abstractmethod
from .trajectory_data import TrajectoryData
from .frame_data import FrameData


class SimulariumFileData(ABC):
def __init__(self, file_contents: Union[str, bytes]):
pass

@abstractmethod
def get_frame_at_index(self, frame_number: int) -> Union[FrameData, None]:
pass

@abstractmethod
def get_index_for_time(self, time: float) -> int:
pass

@abstractmethod
def get_trajectory_info(self) -> Dict:
pass

@abstractmethod
def get_plot_data(self) -> Dict:
pass

@abstractmethod
def get_trajectory_data_object(self) -> TrajectoryData:
pass

@abstractmethod
def get_file_contents(self) -> Union[Dict, bytes]:
pass

@abstractmethod
def get_num_frames(self) -> int:
pass
11 changes: 11 additions & 0 deletions simulariumio/data_objects/trajectory_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,14 @@ def __deepcopy__(self, memo):
plots=copy.deepcopy(self.plots, memo),
)
return result

def __eq__(self, other):
if isinstance(other, TrajectoryData):
return (
self.meta_data == other.meta_data
and self.agent_data == other.agent_data
and self.time_units == other.time_units
and self.spatial_units == other.spatial_units
and self.plots == other.plots
)
return False
Loading

0 comments on commit 5d482f1

Please sign in to comment.