Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
drcege committed Dec 6, 2023
1 parent a858e3d commit d97bce1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions data_juicer/ops/filter/face_area_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def compute_stats(self, sample, context=False):

# there is no image in this sample, so face_ratios is recorded as an empty array
if self.image_key not in sample or not sample[self.image_key]:
sample[Fields.stats][StatsKeys.face_ratios] = [0.0]
sample[Fields.stats][StatsKeys.face_ratios] = np.empty(0,
dtype=float)
return sample

# load images
Expand Down Expand Up @@ -108,7 +109,7 @@ def compute_stats(self, sample, context=False):
img_area = images[key].width * images[key].height
# Calculate the max face ratio for the current image
max_face_ratios.append(
max([w * h / img_area for _, _, w, h in dets]))
max([w * h / img_area for _, _, w, h in dets], default=0.0))
sample[Fields.stats][StatsKeys.face_ratios] = max_face_ratios

return sample
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/filter/test_face_area_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _run_face_area_filter(self,
column=[{}] * dataset.num_rows)
dataset = dataset.map(op.compute_stats, num_proc=num_proc)
dataset = dataset.filter(op.process, num_proc=num_proc)
dataset = dataset.remove_columns('__dj__stats__')
dataset = dataset.remove_columns(Fields.stats)
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)

Expand Down

0 comments on commit d97bce1

Please sign in to comment.