Skip to content

Commit

Permalink
Simplify transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
lochhh committed Nov 28, 2024
1 parent c1420ce commit e03dd51
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def _ds_from_sleap_analysis_file(

with h5py.File(file.path, "r") as f:
# transpose to shape: (n_frames, n_space, n_keypoints, n_tracks)
tracks = f["tracks"][:].transpose((3, 1, 2, 0))
tracks = f["tracks"][:].transpose(3, 1, 2, 0)
# Create an array of NaNs for the confidence scores
scores = np.full(tracks.shape[:1] + tracks.shape[2:], np.nan)
individual_names = [n.decode() for n in f["track_names"][:]] or None
Expand Down
6 changes: 3 additions & 3 deletions movement/io/save_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ def _ds_to_dlc_style_df(
"""
# Concatenate the pose tracks and confidence scores into one array
# and reverse the order of the dimensions except for the time dimension
tracks_with_scores = np.concatenate(
(
ds.position.data,
ds.confidence.data[:, np.newaxis, ...],
),
axis=1,
)
# Reverse the order of the dimensions except for the time dimension
transpose_order = [0] + list(range(tracks_with_scores.ndim - 1, 0, -1))
tracks_with_scores = tracks_with_scores.transpose(transpose_order)

Expand Down Expand Up @@ -324,8 +324,8 @@ def to_sleap_analysis_file(ds: xr.Dataset, file_path: str | Path) -> None:
pos_x = ds.position.sel(space="x").values
# Mask denoting which individuals are present in each frame
track_occupancy = (~np.all(np.isnan(pos_x), axis=1)).astype(int)
tracks = np.transpose(ds.position.data, (3, 1, 2, 0))
point_scores = np.transpose(ds.confidence.data, (2, 1, 0))
tracks = ds.position.data.transpose(3, 1, 2, 0)
point_scores = ds.confidence.data.T
instance_scores = np.full((n_individuals, n_frames), np.nan, dtype=float)
tracking_scores = np.full((n_individuals, n_frames), np.nan, dtype=float)
labels_path = (
Expand Down

0 comments on commit e03dd51

Please sign in to comment.