-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add New Data Types for Simularium Frame Data (#157)
* add new data objects for simularium file data for octopus * reorganize small data classes to be in one file * fix end index * update comments, return types * add checks to make sure requested frame number is valid * add logs for debugging * initialize frame_metadata to [] before appending to it * initialize frame metadata and block indices to empty * code cleanup * formatting fixes * update comments * remove unused imports * add some more tests * add __eq__ to TrajectoryData for testing purposes * reorganize test to improve readability * lint fixes * use existing test data files instead of bringing in new ones * remove file_name as a parameter * remove file_name as a parameter in test * move data classes that are only used by BinaryData to that file * use the defined constants instead of the values for indexing in binary data * Compute get_n_agents separately instead of building a new AgentData object * remove unused import * use BinaryBlockInfo to track block indices, rather than using a new data type * remove unused import * add abstract method decorators * lint fixes
- Loading branch information
Showing
7 changed files
with
461 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
from typing import Dict, List, Tuple | ||
import numpy as np | ||
|
||
from .frame_data import FrameData | ||
from .input_file_data import InputFileData | ||
from .trajectory_data import TrajectoryData | ||
from .simularium_file_data import SimulariumFileData | ||
from ..constants import BINARY_BLOCK_TYPE, BINARY_SETTINGS | ||
from ..readers import BinaryBlockInfo, SimulariumBinaryReader | ||
|
||
|
||
class BinaryData(SimulariumFileData):
    def __init__(self, file_contents: bytes):
        """
        This object holds binary encoded simulation trajectory file's
        data while staying close to the original file format

        Parameters
        ----------
        file_contents : bytes
            A byte array containing the data of an open .simularium file
        """
        self.file_contents = InputFileData(file_contents=file_contents)
        self.file_data = SimulariumBinaryReader._binary_data_from_source(
            self.file_contents
        )
        # One entry per frame, in the order the frames are listed in the
        # spatial data block header; filled in by _parse_file()
        self.frame_metadata: List[FrameMetadata] = []
        # Set by _parse_file()
        self.block_info: BinaryBlockInfo = None
        # Maps block type id to block index
        self.block_indices: Dict[int, int] = {}
        self._parse_file()

    def _parse_file(self):
        """
        Read the binary header to find every block, then walk the spatial
        data block to record the offset, length, frame number, and time
        of each frame in self.frame_metadata
        """
        # Read offset and length for each data block
        self.block_info = SimulariumBinaryReader._parse_binary_header(
            self.file_data.byte_view
        )
        for block_index in range(self.block_info.n_blocks):
            block_type_id = SimulariumBinaryReader._binary_block_type(
                block_index, self.block_info, self.file_data.int_view
            )
            self.block_indices[block_type_id] = block_index

        # Extract each frame's metadata
        int_view = self.file_data.int_view
        spatial_block_index = self.block_indices[
            BINARY_BLOCK_TYPE.SPATIAL_DATA_BINARY.value
        ]
        block_offset = self.block_info.block_offsets[spatial_block_index]
        # Convert the block's byte offset to a value index (exact integer
        # division, no float round trip) and skip the block header
        spatial_block_offset = (
            int(block_offset) // BINARY_SETTINGS.BYTES_PER_VALUE
            + BINARY_SETTINGS.BLOCK_HEADER_N_VALUES
        )
        n_frames = int_view[spatial_block_offset + 1]
        # Index of the first per-frame header entry
        frame_table_offset = (
            spatial_block_offset
            + BINARY_SETTINGS.SPATIAL_BLOCK_HEADER_CONSTANT_N_VALUES
        )
        # The first frame's data starts right after the per-frame entries
        current_frame_offset = (
            frame_table_offset
            + BINARY_SETTINGS.SPATIAL_BLOCK_HEADER_N_VALUES_PER_FRAME * n_frames
        )
        for i in range(n_frames):
            entry = (
                frame_table_offset
                + BINARY_SETTINGS.SPATIAL_BLOCK_HEADER_N_VALUES_PER_FRAME * i
            )
            # block_offset is added so the stored offset is measured from
            # the start of the byte array
            offset = int_view[entry] + block_offset
            length = int_view[entry + 1]
            # Each frame's data begins with its frame number, then its time
            frame_number = int_view[current_frame_offset]
            time = self.file_data.float_view[current_frame_offset + 1]
            self.frame_metadata.append(
                FrameMetadata(offset, length, frame_number, time)
            )
            current_frame_offset += int(length) // BINARY_SETTINGS.BYTES_PER_VALUE

    def get_frame_at_index(self, frame_number: int) -> FrameData:
        """
        Return frame data for frame at index. If there is no frame at the index,
        return None.

        Parameters
        ----------
        frame_number : int
            Position of the requested frame in the list of frames
        """
        if frame_number < 0 or frame_number >= len(self.frame_metadata):
            # invalid frame number requested
            return None

        metadata: FrameMetadata = self.frame_metadata[frame_number]
        start, end = metadata.get_start_end_indices()
        data = self.file_data.byte_view[start:end]
        return FrameData(
            frame_number=frame_number,
            # the agent count is the third value of the frame's data,
            # after the frame number and the time
            n_agents=self.file_data.int_view[
                int(start) // BINARY_SETTINGS.BYTES_PER_VALUE + 2
            ],
            time=metadata.time,
            data=data,
        )

    def get_index_for_time(self, time: float) -> int:
        """
        Return index for frame closest to a given timestamp

        The result is the frame's position in the list of frames, i.e. a
        value that can be passed to get_frame_at_index(). If there are no
        frames, -1 is returned.
        """
        closest_index = -1
        min_dist = np.inf
        for index, frame in enumerate(self.frame_metadata):
            dist = abs(frame.time - time)
            if dist < min_dist:
                min_dist = dist
                # use the position in the list rather than the frame number
                # stored in the file, since frames are looked up by position
                closest_index = index
            else:
                # if dist is increasing, we've passed the closest frame
                break

        # a list position is always <= self.get_num_frames() - 1
        return closest_index

    def get_trajectory_info(self) -> Dict:
        """
        Return trajectory info block for trajectory, as dict
        """
        block_index = self.block_indices[BINARY_BLOCK_TYPE.TRAJ_INFO_JSON.value]
        return SimulariumBinaryReader._binary_block_json(
            block_index, self.block_info, self.file_data.byte_view
        )

    def get_plot_data(self) -> Dict:
        """
        Return plot data block for trajectory, as dict
        """
        block_index = self.block_indices[BINARY_BLOCK_TYPE.PLOT_DATA_JSON.value]
        return SimulariumBinaryReader._binary_block_json(
            block_index, self.block_info, self.file_data.byte_view
        )

    def get_trajectory_data_object(self) -> TrajectoryData:
        """
        Return the data of the trajectory, as a TrajectoryData object
        """
        trajectory_dict = SimulariumBinaryReader.load_binary(self.file_contents)
        return TrajectoryData.from_buffer_data(trajectory_dict)

    def get_file_contents(self) -> bytes:
        """
        Return raw file data, as bytes
        """
        return self.file_contents.get_contents()

    def get_num_frames(self) -> int:
        """
        Return number of frames in the trajectory
        """
        return len(self.frame_metadata)
|
||
|
||
class FrameMetadata:
    def __init__(self, offset: int, length: int, frame_number: int, time: float):
        """
        Metadata locating a single frame of simularium data inside
        the file's byte array

        Parameters
        ----------
        offset : int
            Number of bytes between the start of the byte array and
            the start of the frame's block
        length : int
            Size of the frame's block, in bytes
        frame_number : int
            Index of frame in the simulation
        time : float
            Elapsed simulation time of the frame
        """
        self.offset, self.length = offset, length
        self.frame_number, self.time = frame_number, time

    def get_start_end_indices(self) -> Tuple[int, int]:
        """
        Return the start and end indices for the data block,
        where the end index is the offset plus the length
        """
        start = self.offset
        return start, start + self.length
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from typing import Dict, Union | ||
|
||
|
||
class FrameData:
    def __init__(
        self, frame_number: int, n_agents: int, time: float, data: Union[bytes, Dict]
    ):
        """
        Container for everything known about one frame of simularium data

        Parameters
        ----------
        frame_number : int
            Index of frame in the simulation
        n_agents : int
            Number of agents included in the frame
        time : float
            Elapsed simulation time of the frame
        data : bytes or dict
            Spatial data for the frame: a byte array when it comes from a
            binary encoded .simularium file, a dict when it comes from a
            JSON .simularium file
        """
        self.frame_number, self.n_agents = frame_number, n_agents
        self.time, self.data = time, data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from typing import Dict, List | ||
import json | ||
import numpy as np | ||
|
||
from .frame_data import FrameData | ||
from .simularium_file_data import SimulariumFileData | ||
from .trajectory_data import TrajectoryData | ||
from ..constants import V1_SPATIAL_BUFFER_STRUCT | ||
|
||
|
||
class JsonData(SimulariumFileData):
    def __init__(self, file_contents: str):
        """
        This object holds JSON encoded simulation trajectory file's
        data while staying close to the original file format

        Parameters
        ----------
        file_contents : str
            A string of the data of an open .simularium file
        """
        self.data = json.loads(file_contents)
        # Number of agents in each frame, in the same order as bundleData
        self.n_agents = JsonData._get_n_agents(self.data)

    @staticmethod
    def _get_n_agents(data: Dict) -> List[int]:
        """
        Return the number of agents in each frame, as a list with one
        entry per item of data["spatialData"]["bundleData"]
        """
        n_agents = []
        # Iterate the frames that are actually present, so the result always
        # lines up with the list that get_frame_at_index() indexes into
        for bundle in data["spatialData"]["bundleData"]:
            frame_data = bundle["data"]
            agent_count = 0
            buffer_index = 0
            while buffer_index + V1_SPATIAL_BUFFER_STRUCT.NSP_INDEX < len(frame_data):
                n_subpoints = int(
                    frame_data[buffer_index + V1_SPATIAL_BUFFER_STRUCT.NSP_INDEX]
                )
                # length of one agent's spatial data = SP_INDEX + n_subpoints
                buffer_index += n_subpoints + V1_SPATIAL_BUFFER_STRUCT.SP_INDEX
                agent_count += 1
            n_agents.append(agent_count)
        return n_agents

    def get_frame_at_index(self, frame_number: int) -> FrameData:
        """
        Return frame data for frame at index. If there is no frame at the index,
        return None.

        Parameters
        ----------
        frame_number : int
            Position of the requested frame in the list of frames
        """
        bundle_data = self.data["spatialData"]["bundleData"]
        if frame_number < 0 or frame_number >= len(bundle_data):
            # invalid frame number requested
            return None

        frame_data = bundle_data[frame_number]
        return FrameData(
            frame_number=frame_number,
            n_agents=self.n_agents[frame_number],
            time=frame_data["time"],
            data=frame_data["data"],
        )

    def get_index_for_time(self, time: float) -> int:
        """
        Return index for frame closest to a given timestamp

        The result is the frame's position in the list of frames, i.e. a
        value that can be passed to get_frame_at_index(). If there are no
        frames, -1 is returned.
        """
        closest_index = -1
        min_dist = np.inf
        for index, frame in enumerate(self.data["spatialData"]["bundleData"]):
            dist = abs(frame["time"] - time)
            if dist < min_dist:
                min_dist = dist
                # use the position in the list rather than the stored
                # frameNumber, since frames are looked up by position
                closest_index = index
            else:
                # if dist is increasing, we've passed the closest frame
                break

        # a list position is always <= self.get_num_frames() - 1
        return closest_index

    def get_trajectory_info(self) -> Dict:
        """
        Return trajectory info block for trajectory, as dict
        """
        return self.data["trajectoryInfo"]

    def get_plot_data(self) -> Dict:
        """
        Return plot data block for trajectory, as dict
        """
        return self.data["plotData"]

    def get_trajectory_data_object(self) -> TrajectoryData:
        """
        Return the data of the trajectory, as a TrajectoryData object
        """
        return TrajectoryData.from_buffer_data(self.data)

    def get_file_contents(self) -> Dict:
        """
        Return raw file data, as a dict
        """
        return self.data

    def get_num_frames(self) -> int:
        """
        Return number of frames in the trajectory
        """
        return len(self.data["spatialData"]["bundleData"])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Dict, Union | ||
from abc import ABC, abstractmethod | ||
from .trajectory_data import TrajectoryData | ||
from .frame_data import FrameData | ||
|
||
|
||
class SimulariumFileData(ABC):
    def __init__(self, file_contents: Union[str, bytes]):
        """
        Common interface for objects that hold the contents of a
        .simularium file; this base initializer does nothing itself

        Parameters
        ----------
        file_contents : str or bytes
            The data of an open .simularium file
        """

    @abstractmethod
    def get_frame_at_index(self, frame_number: int) -> Union[FrameData, None]:
        """
        Return frame data for frame at index, or None if there is
        no frame at the index
        """

    @abstractmethod
    def get_index_for_time(self, time: float) -> int:
        """
        Return index for frame closest to a given timestamp
        """

    @abstractmethod
    def get_trajectory_info(self) -> Dict:
        """
        Return trajectory info block for trajectory, as dict
        """

    @abstractmethod
    def get_plot_data(self) -> Dict:
        """
        Return plot data block for trajectory, as dict
        """

    @abstractmethod
    def get_trajectory_data_object(self) -> TrajectoryData:
        """
        Return the data of the trajectory, as a TrajectoryData object
        """

    @abstractmethod
    def get_file_contents(self) -> Union[Dict, bytes]:
        """
        Return raw file data
        """

    @abstractmethod
    def get_num_frames(self) -> int:
        """
        Return number of frames in the trajectory
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.