Skip to content

Commit

Permalink
Relax requirement for frame number
Browse files Browse the repository at this point in the history
  • Loading branch information
sfmig committed Sep 17, 2024
1 parent 0c3fa3c commit 64eb1de
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
45 changes: 34 additions & 11 deletions movement/io/load_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,9 @@ def from_via_tracks_file(
logger.debug(f"Validated VIA tracks .csv file {via_file.path}.")

# Create an xarray.Dataset from the data
bboxes_arrays = _numpy_arrays_from_via_tracks_file(via_file.path)
bboxes_arrays = _numpy_arrays_from_via_tracks_file(
via_file.path, via_file.frame_regexp
)
ds = from_numpy(
position_array=bboxes_arrays["position_array"],
shape_array=bboxes_arrays["shape_array"],
Expand All @@ -347,7 +349,9 @@ def from_via_tracks_file(
return ds


def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict:
def _numpy_arrays_from_via_tracks_file(
file_path: Path, frame_regexp: str
) -> dict:
"""Extract numpy arrays from the input VIA tracks .csv file.
The extracted numpy arrays are returned in a dictionary with the following
Expand All @@ -370,6 +374,12 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict:
file_path : pathlib.Path
Path to the VIA tracks .csv file containing the bounding boxes' tracks.
frame_regexp : str
Regular expression pattern to extract the frame number from the
filename. The frame number is expected to be encoded in the filename
as an integer number led by at least one zero, followed by the file
extension.
Returns
-------
dict
Expand All @@ -379,7 +389,7 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict:
# Extract 2D dataframe from input data
# (sort data by ID and frame number, and
# fill empty frame-ID pairs with nans)
df = _df_from_via_tracks_file(file_path)
df = _df_from_via_tracks_file(file_path, frame_regexp)

# Compute indices of the rows where the IDs switch
bool_id_diff_from_prev = df["ID"].ne(df["ID"].shift()) # pandas series
Expand Down Expand Up @@ -417,7 +427,9 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict:
return array_dict


def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame:
def _df_from_via_tracks_file(
file_path: Path, frame_regexp: str
) -> pd.DataFrame:
"""Load VIA tracks .csv file as a dataframe.
Read the VIA tracks .csv file as a pandas dataframe with columns:
Expand All @@ -433,6 +445,10 @@ def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame:
empty frames are filled in with NaNs. The coordinates of the bboxes
are assumed to be in the image coordinate system (i.e., the top-left
corner of a bbox is its corner with minimum x and y coordinates).
The frame number is extracted from the filename using the provided
regexp if it is not defined as a 'file_attribute' in the VIA tracks .csv
file.
"""
# Read VIA tracks .csv file as a pandas dataframe
df_file = pd.read_csv(file_path, sep=",", header=0)
Expand All @@ -443,7 +459,9 @@ def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame:
"ID": _via_attribute_column_to_numpy(
df_file, "region_attributes", ["track"], int
),
"frame_number": _extract_frame_number_from_via_tracks_df(df_file),
"frame_number": _extract_frame_number_from_via_tracks_df(
df_file, frame_regexp
),
"x": _via_attribute_column_to_numpy(
df_file, "region_shape_attributes", ["x"], float
),
Expand Down Expand Up @@ -508,7 +526,7 @@ def _extract_confidence_from_via_tracks_df(df) -> np.ndarray:
return bbox_confidence


def _extract_frame_number_from_via_tracks_df(df) -> np.ndarray:
def _extract_frame_number_from_via_tracks_df(df, frame_regexp) -> np.ndarray:
"""Extract frame numbers from the VIA tracks input dataframe.
Parameters
Expand All @@ -517,14 +535,20 @@ def _extract_frame_number_from_via_tracks_df(df) -> np.ndarray:
The VIA tracks input dataframe is the one obtained from
``df = pd.read_csv(file_path, sep=",", header=0)``.
frame_regexp : str
Regular expression pattern to extract the frame number from the
filename. The frame number is expected to be encoded in the filename
as an integer number led by at least one zero, followed by the file
extension.
Returns
-------
np.ndarray
A numpy array of size (n_frames, ) containing the frame numbers.
In the VIA tracks .csv file, the frame number is expected to be
defined as a 'file_attribute' , or encoded in the filename as an
integer number led by at least one zero, between "_" and ".", followed
by the file extension.
integer number led by at least one zero, followed by the file
extension.
"""
# Extract frame number from file_attributes if exists
Expand All @@ -538,10 +562,9 @@ def _extract_frame_number_from_via_tracks_df(df) -> np.ndarray:
)
# Else extract from filename
else:
pattern = r"_(0\d*)\.\w+$"
list_frame_numbers = [
int(re.search(pattern, f).group(1)) # type: ignore
if re.search(pattern, f)
int(re.search(frame_regexp, f).group(1)) # type: ignore
if re.search(frame_regexp, f)
else np.nan
for f in df["filename"]
]
Expand Down
14 changes: 9 additions & 5 deletions movement/validators/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ class ValidVIATracksCSV:
----------
path : pathlib.Path
Path to the VIA tracks .csv file.
frame_regexp : str
Regular expression pattern to extract the frame number from the
filename. By default, the frame number is expected to be encoded in
the filename as an integer number led by at least one zero, followed
by the file extension.
Raises
------
Expand All @@ -243,6 +248,7 @@ class ValidVIATracksCSV:
"""

path: Path = field(validator=validators.instance_of(Path))
frame_regexp: str = r"(0\d*)\.\w+$"

@path.validator
def _file_contains_valid_header(self, attribute, value):
Expand Down Expand Up @@ -281,8 +287,8 @@ def _file_contains_valid_frame_numbers(self, attribute, value):
files.
If the frame number is included as part of the image file name, then
it is expected as an integer led by at least one zero, between "_" and
".", followed by the file extension.
it is expected as an integer led by at least one zero, followed by the
file extension.
"""
df = pd.read_csv(value, sep=",", header=0)
Expand All @@ -309,10 +315,8 @@ def _file_contains_valid_frame_numbers(self, attribute, value):

# else: extract frame number from filename.
else:
pattern = r"_(0\d*)\.\w+$"

for f_i, f in enumerate(df["filename"]):
regex_match = re.search(pattern, f)
regex_match = re.search(self.frame_regexp, f)
if regex_match: # if there is a pattern match
list_frame_numbers.append(
int(regex_match.group(1)) # type: ignore
Expand Down

0 comments on commit 64eb1de

Please sign in to comment.