Skip to content

Commit

Permalink
Simplify frame number extraction test by splitting it
Browse files Browse the repository at this point in the history
  • Loading branch information
sfmig committed Nov 29, 2024
1 parent fe04736 commit 82cfb39
Showing 1 changed file with 54 additions and 15 deletions.
69 changes: 54 additions & 15 deletions tests/test_unit/test_load_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,38 +456,77 @@ def test_extract_confidence_from_via_tracks_df(
),
],
)
def test_extract_frame_number_from_via_tracks_df_filenames(
create_df_input_via_tracks,
via_file_path,
expected_frame_array,
):
"""Test that the function correctly extracts the frame number values from
the images' filenames.
"""
# create the dataframe with the frame number
df = create_df_input_via_tracks(
via_file_path,
small=True,
)

# the VIA tracks .csv files have no frames defined under the
# "file_attributes" so the frame numbers should be extracted
# from the filenames
assert not all(["frame" in row for row in df["file_attributes"]])

# extract frame number from df
frame_array = load_bboxes._extract_frame_number_from_via_tracks_df(df)

assert np.array_equal(frame_array, expected_frame_array)


@pytest.mark.parametrize(
"attribute_column_additions",
"via_file_path, attribute_column_additions, expected_frame_array",
[
(None), # taking "frame" from the images' filenames
({"file_attributes": []}), # taking "frame" from the "file_attributes"
(
pytest.DATA_PATHS.get("VIA_multiple-crabs_5-frames_labels.csv"),
{"file_attributes": [{"frame": 222}]},
np.ones(
3,
)
* 222,
),
(
pytest.DATA_PATHS.get("VIA_single-crab_MOCA-crab-1.csv"),
{
"file_attributes": [
{"frame": 218},
{"frame": 219},
{"frame": 220},
]
},
np.array([218, 219, 220]),
),
],
)
def test_extract_frame_number_from_via_tracks_df(
def test_extract_frame_number_from_via_tracks_df_file_attributes(
create_df_input_via_tracks,
via_file_path,
attribute_column_additions,
expected_frame_array,
):
"""Test that the function correctly extracts the frame number values from
the VIA dataframe.
"""
# If required: define the list of frames
# to append to the dataframe as file attributes
if attribute_column_additions:
attribute_column_additions["file_attributes"] = [
{"frame": f.item()} for f in expected_frame_array
]
the file attributes column.
# create the dataframe with the frame number
# (either from the file name or from the file attributes)
The frame number defined under the "file_attributes" column
should take precedence over the frame numbers encoded in the filenames.
"""
# Create the dataframe with the frame number stored in
# the file_attributes column
df = create_df_input_via_tracks(
via_file_path,
small=True,
attribute_column_additions=attribute_column_additions,
)

# extract frame number from df
# extract frame number from the dataframe
# (should take precedence over the frame numbers in the filenames)
frame_array = load_bboxes._extract_frame_number_from_via_tracks_df(df)

assert np.array_equal(frame_array, expected_frame_array)
Expand Down

0 comments on commit 82cfb39

Please sign in to comment.