Skip to content

Commit

Permalink
tags specified field
Browse files Browse the repository at this point in the history
  • Loading branch information
BeachWang committed Dec 20, 2024
1 parent cf01e7e commit 0ba6459
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 1 deletion.
4 changes: 3 additions & 1 deletion data_juicer/ops/selector/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .frequency_specified_field_selector import FrequencySpecifiedFieldSelector
from .random_selector import RandomSelector
from .range_specified_field_selector import RangeSpecifiedFieldSelector
from .tags_specified_field_selector import TagsSpecifiedFieldSelector
from .topk_specified_field_selector import TopkSpecifiedFieldSelector

__all__ = [
'FrequencySpecifiedFieldSelector', 'RandomSelector',
'RangeSpecifiedFieldSelector', 'TopkSpecifiedFieldSelector'
'RangeSpecifiedFieldSelector', 'TagsSpecifiedFieldSelector',
'TopkSpecifiedFieldSelector'
]
54 changes: 54 additions & 0 deletions data_juicer/ops/selector/tags_specified_field_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numbers
from typing import List

from ..base_op import OPERATORS, Selector


@OPERATORS.register_module('tags_specified_field_selector')
class TagsSpecifiedFieldSelector(Selector):
"""Selector to select samples based on the tags of specified
field."""

def __init__(self,
field_key: str = '',
target_tags: List[str] = None,
*args,
**kwargs):
"""
Initialization method.
:param field_key: Selector based on the specified value
corresponding to the target key. The target key
corresponding to multi-level field information need to be
separated by '.'.
:param target_tags: Target tags to be select.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.field_key = field_key
self.target_tags = set(target_tags)

def process(self, dataset):
if len(dataset) <= 1 or not self.field_key:
return dataset

field_keys = self.field_key.split('.')
assert field_keys[0] in dataset.features.keys(
), "'{}' not in {}".format(field_keys[0], dataset.features.keys())

selected_index = []
for i, item in enumerate(dataset[field_keys[0]]):
field_value = item
for key in field_keys[1:]:
assert key in field_value.keys(), "'{}' not in {}".format(
key, field_value.keys())
field_value = field_value[key]
assert field_value is None or isinstance(
field_value, str) or isinstance(
field_value, numbers.Number
), 'The {} item is not String, Numbers or NoneType'.format(i)
if field_value in self.target_tags:
selected_index.append(i)

return dataset.select(selected_index)
63 changes: 63 additions & 0 deletions tests/ops/selector/test_tags_specified_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import unittest

from data_juicer.core.data import NestedDataset as Dataset

from data_juicer.ops.selector.tags_specified_field_selector import \
TagsSpecifiedFieldSelector
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase


class TagsSpecifiedFieldSelectorTest(DataJuicerTestCaseBase):

def _run_tag_selector(self, dataset: Dataset, target_list, op):
dataset = op.process(dataset)
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)

def test_tag_select(self):
ds_list = [{
'text': 'a',
'meta': {
'sentiment': 'happy',
}
}, {
'text': 'b',
'meta': {
'sentiment': 'happy',
}
}, {
'text': 'c',
'meta': {
'sentiment': 'sad',
}
}, {
'text': 'd',
'meta': {
'sentiment': 'angry',
}
}]
tgt_list = [{
'text': 'a',
'meta': {
'sentiment': 'happy',
}
}, {
'text': 'b',
'meta': {
'sentiment': 'happy',
}
}, {
'text': 'c',
'meta': {
'sentiment': 'sad',
}
}]
dataset = Dataset.from_list(ds_list)
op = TagsSpecifiedFieldSelector(
field_key='meta.sentiment',
target_tags=['happy', 'sad'])
self._run_tag_selector(dataset, tgt_list, op)


if __name__ == '__main__':
unittest.main()

0 comments on commit 0ba6459

Please sign in to comment.