diff --git a/movement/io/load_bboxes.py b/movement/io/load_bboxes.py index 220dc84c..1fbee4fc 100644 --- a/movement/io/load_bboxes.py +++ b/movement/io/load_bboxes.py @@ -321,7 +321,9 @@ def from_via_tracks_file( logger.debug(f"Validated VIA tracks .csv file {via_file.path}.") # Create an xarray.Dataset from the data - bboxes_arrays = _numpy_arrays_from_via_tracks_file(via_file.path) + bboxes_arrays = _numpy_arrays_from_via_tracks_file( + via_file.path, via_file.frame_regexp + ) ds = from_numpy( position_array=bboxes_arrays["position_array"], shape_array=bboxes_arrays["shape_array"], @@ -347,7 +349,9 @@ def from_via_tracks_file( return ds -def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict: +def _numpy_arrays_from_via_tracks_file( + file_path: Path, frame_regexp: str +) -> dict: """Extract numpy arrays from the input VIA tracks .csv file. The extracted numpy arrays are returned in a dictionary with the following @@ -370,6 +374,12 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict: file_path : pathlib.Path Path to the VIA tracks .csv file containing the bounding boxes' tracks. + frame_regexp : str + Regular expression pattern to extract the frame number from the + filename. The frame number is expected to be encoded in the filename + as an integer number led by at least one zero, followed by the file + extension. 
+ Returns ------- dict @@ -379,7 +389,7 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict: # Extract 2D dataframe from input data # (sort data by ID and frame number, and # fill empty frame-ID pairs with nans) - df = _df_from_via_tracks_file(file_path) + df = _df_from_via_tracks_file(file_path, frame_regexp) # Compute indices of the rows where the IDs switch bool_id_diff_from_prev = df["ID"].ne(df["ID"].shift()) # pandas series @@ -417,7 +427,9 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict: return array_dict -def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame: +def _df_from_via_tracks_file( + file_path: Path, frame_regexp: str +) -> pd.DataFrame: """Load VIA tracks .csv file as a dataframe. Read the VIA tracks .csv file as a pandas dataframe with columns: @@ -433,6 +445,10 @@ def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame: empty frames are filled in with NaNs. The coordinates of the bboxes are assumed to be in the image coordinate system (i.e., the top-left corner of a bbox is its corner with minimum x and y coordinates). + + If the frame number is not defined as a 'file_attribute' in the VIA + tracks .csv file, it is extracted from each filename as the first + capture group of the provided regexp. 
""" # Read VIA tracks .csv file as a pandas dataframe df_file = pd.read_csv(file_path, sep=",", header=0) @@ -443,7 +459,9 @@ def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame: "ID": _via_attribute_column_to_numpy( df_file, "region_attributes", ["track"], int ), - "frame_number": _extract_frame_number_from_via_tracks_df(df_file), + "frame_number": _extract_frame_number_from_via_tracks_df( + df_file, frame_regexp + ), "x": _via_attribute_column_to_numpy( df_file, "region_shape_attributes", ["x"], float ), @@ -508,7 +526,7 @@ def _extract_confidence_from_via_tracks_df(df) -> np.ndarray: return bbox_confidence -def _extract_frame_number_from_via_tracks_df(df) -> np.ndarray: +def _extract_frame_number_from_via_tracks_df(df, frame_regexp) -> np.ndarray: """Extract frame numbers from the VIA tracks input dataframe. Parameters @@ -517,14 +535,20 @@ def _extract_frame_number_from_via_tracks_df(df) -> np.ndarray: The VIA tracks input dataframe is the one obtained from ``df = pd.read_csv(file_path, sep=",", header=0)``. + frame_regexp : str + Regular expression pattern to extract the frame number from the + filename. The frame number is expected to be encoded in the filename + as an integer number led by at least one zero, followed by the file + extension. + Returns ------- np.ndarray A numpy array of size (n_frames, ) containing the frame numbers. In the VIA tracks .csv file, the frame number is expected to be defined as a 'file_attribute' , or encoded in the filename as an - integer number led by at least one zero, between "_" and ".", followed - by the file extension. + integer number led by at least one zero, followed by the file + extension. 
""" # Extract frame number from file_attributes if exists @@ -538,10 +562,9 @@ def _extract_frame_number_from_via_tracks_df(df) -> np.ndarray: ) # Else extract from filename else: - pattern = r"_(0\d*)\.\w+$" list_frame_numbers = [ - int(re.search(pattern, f).group(1)) # type: ignore - if re.search(pattern, f) + int(re.search(frame_regexp, f).group(1)) # type: ignore + if re.search(frame_regexp, f) else np.nan for f in df["filename"] ] diff --git a/movement/validators/files.py b/movement/validators/files.py index 8d013a95..6a910010 100644 --- a/movement/validators/files.py +++ b/movement/validators/files.py @@ -234,6 +234,11 @@ class ValidVIATracksCSV: ---------- path : pathlib.Path Path to the VIA tracks .csv file. + frame_regexp : str + Regular expression pattern to extract the frame number from the + filename. By default, the frame number is expected to be encoded in + the filename as an integer number led by at least one zero, followed + by the file extension. Raises ------ @@ -243,6 +248,7 @@ class ValidVIATracksCSV: """ path: Path = field(validator=validators.instance_of(Path)) + frame_regexp: str = r"(0\d*)\.\w+$" @path.validator def _file_contains_valid_header(self, attribute, value): @@ -281,8 +287,8 @@ def _file_contains_valid_frame_numbers(self, attribute, value): files. If the frame number is included as part of the image file name, then - it is expected as an integer led by at least one zero, between "_" and - ".", followed by the file extension. + it is expected as an integer led by at least one zero, followed by the + file extension. """ df = pd.read_csv(value, sep=",", header=0) @@ -309,10 +315,8 @@ def _file_contains_valid_frame_numbers(self, attribute, value): # else: extract frame number from filename. 
else: - pattern = r"_(0\d*)\.\w+$" - for f_i, f in enumerate(df["filename"]): - regex_match = re.search(pattern, f) + regex_match = re.search(self.frame_regexp, f) if regex_match: # if there is a pattern match list_frame_numbers.append( int(regex_match.group(1)) # type: ignore