diff --git a/data_juicer/ops/selector/range_specified_field_selector.py b/data_juicer/ops/selector/range_specified_field_selector.py index 55243b50f..4a1a52bd1 100644 --- a/data_juicer/ops/selector/range_specified_field_selector.py +++ b/data_juicer/ops/selector/range_specified_field_selector.py @@ -1,4 +1,4 @@ -import heapq +import bisect from typing import Optional from pydantic import Field, PositiveInt @@ -17,6 +17,8 @@ class RangeSpecifiedFieldSelector(Selector): def __init__( self, field_key: str = '', + lower_value: float = None, + upper_value: float = None, lower_percentile: Optional[Annotated[float, Field(ge=0, le=1)]] = None, upper_percentile: Optional[Annotated[float, @@ -57,6 +59,8 @@ def __init__( """ super().__init__(*args, **kwargs) self.field_key = field_key + self.lower_value = lower_value + self.upper_value = upper_value self.lower_percentile = lower_percentile self.upper_percentile = upper_percentile self.lower_rank = lower_rank @@ -66,21 +70,10 @@ def process(self, dataset): if len(dataset) <= 1 or not self.field_key: return dataset - if self.lower_percentile is None and self.lower_rank is None: + if self.lower_value is None and self.upper_value is None and \ + self.lower_percentile is None and self.upper_percentile is None \ + and self.lower_rank is None and self.upper_rank is None: return dataset - if self.upper_percentile is None and self.upper_rank is None: - return dataset - - lower_bound, upper_bound = 0, len(dataset) - if self.lower_percentile is not None: - lower_bound = int(self.lower_percentile * len(dataset)) - if self.lower_rank is not None: - lower_bound = max(lower_bound, self.lower_rank) - if self.upper_percentile is not None: - upper_bound = int(self.upper_percentile * len(dataset)) - if self.upper_rank is not None: - upper_bound = min(upper_bound, self.upper_rank) - upper_bound = max(lower_bound, upper_bound) field_keys = self.field_key.split('.') assert field_keys[0] in dataset.features.keys( @@ -102,13 +95,28 @@ def get_field_value_list(cur_dataset, field_keys): return field_value_list field_value_list = get_field_value_list(dataset, field_keys) - select_index = heapq.nsmallest(int(upper_bound), range(len(dataset)), - field_value_list.__getitem__) - sub_dataset = dataset.select(select_index) + field_value_list, indices = zip( + *sorted(list(zip(field_value_list, range(len(field_value_list)))))) + + lower_bound, upper_bound = 0, len(dataset) - 1 + if self.lower_value is not None: + lower_bound = bisect.bisect_left(field_value_list, + self.lower_value) + if self.lower_percentile is not None: + lower_bound = max(lower_bound, + int(self.lower_percentile * len(dataset))) + if self.lower_rank is not None: + lower_bound = max(lower_bound, self.lower_rank) + if self.upper_value is not None: + upper_bound = bisect.bisect_right(field_value_list, + self.upper_value) - 1 + if self.upper_percentile is not None: + upper_bound = min(upper_bound, + int(self.upper_percentile * len(dataset))) + if self.upper_rank is not None: + upper_bound = min(upper_bound, self.upper_rank) + upper_bound = max(lower_bound, upper_bound) - field_value_list = get_field_value_list(sub_dataset, field_keys) - select_index = heapq.nlargest(int(upper_bound - lower_bound), - range(len(sub_dataset)), - field_value_list.__getitem__) + select_index = indices[lower_bound:upper_bound + 1] - return sub_dataset.select(select_index) + return dataset.select(select_index) diff --git a/tests/ops/selector/test_range_specified_field_selector.py b/tests/ops/selector/test_range_specified_field_selector.py index b0dd77a1e..fa23a9e7b 100644 --- a/tests/ops/selector/test_range_specified_field_selector.py +++ b/tests/ops/selector/test_range_specified_field_selector.py @@ -16,6 +16,159 @@ def _run_range_selector(self, dataset: Dataset, target_list, op): target_list = sorted(target_list, key=lambda x: x['text']) self.assertEqual(res_list, target_list) + def test_value_select(self): + ds_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'count': 162, + 'meta': { + 'suffix': '.txt', + 'key1': { + 'key2': { + 'count': None + }, + 'count': 23 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }, { + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 551 + }, + 'count': 78 + } + } + }, { + 'text': '这是一个测试', + 'count': None, + 'meta': { + 'suffix': '.py', + 'key1': { + 'key2': { + 'count': 89 + }, + 'count': 3 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 32 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 33 + }, + 'count': 33 + } + } + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言', + 'count': 666, + 'meta': { + 'suffix': '.xml', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }] + tgt_list = [{ + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }] + dataset = Dataset.from_list(ds_list) + op = RangeSpecifiedFieldSelector(field_key='meta.key1.count', + lower_value=63, + upper_value=67) + self._run_range_selector(dataset, tgt_list, op) + def test_percentile_select(self): ds_list = [{ 'text': 'Today is Sun',