From dc1c0b087292a315e9208c258a835025798fbade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=87=E6=A2=A6?= Date: Sat, 21 Sep 2024 17:54:43 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=9C=A8=E6=8C=87?= =?UTF-8?q?=E5=AE=9A=E5=AD=97=E6=AE=B5=E6=A0=B9=E6=8D=AE=E5=80=BC=E5=9F=9F?= =?UTF-8?q?=E8=8C=83=E5=9B=B4=E8=BF=9B=E8=A1=8C=E6=95=B0=E6=8D=AE=E7=AD=9B?= =?UTF-8?q?=E9=80=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../range_specified_field_selector.py | 55 +++++++++++-------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/data_juicer/ops/selector/range_specified_field_selector.py b/data_juicer/ops/selector/range_specified_field_selector.py index 55243b50f..d51381dac 100644 --- a/data_juicer/ops/selector/range_specified_field_selector.py +++ b/data_juicer/ops/selector/range_specified_field_selector.py @@ -1,4 +1,4 @@ -import heapq +import bisect from typing import Optional from pydantic import Field, PositiveInt @@ -17,6 +17,8 @@ class RangeSpecifiedFieldSelector(Selector): def __init__( self, field_key: str = '', + lower_value: float = None, + upper_value: float = None, lower_percentile: Optional[Annotated[float, Field(ge=0, le=1)]] = None, upper_percentile: Optional[Annotated[float, @@ -57,6 +59,8 @@ def __init__( """ super().__init__(*args, **kwargs) self.field_key = field_key + self.lower_value = lower_value + self.upper_value = upper_value self.lower_percentile = lower_percentile self.upper_percentile = upper_percentile self.lower_rank = lower_rank @@ -66,21 +70,10 @@ def process(self, dataset): if len(dataset) <= 1 or not self.field_key: return dataset - if self.lower_percentile is None and self.lower_rank is None: + if self.lower_value is None and self.upper_value is None and \ + self.lower_percentile is None and self.upper_percentile is None \ + and self.lower_rank is None and self.upper_rank is None: return dataset - if self.upper_percentile is None and self.upper_rank is None: - return dataset - - lower_bound, upper_bound = 0, len(dataset) - if self.lower_percentile is not None: - lower_bound = int(self.lower_percentile * len(dataset)) - if self.lower_rank is not None: - lower_bound = max(lower_bound, self.lower_rank) - if self.upper_percentile is not None: - upper_bound = int(self.upper_percentile * len(dataset)) - if self.upper_rank is not None: - upper_bound = min(upper_bound, self.upper_rank) - upper_bound = max(lower_bound, upper_bound) field_keys = self.field_key.split('.') assert field_keys[0] in dataset.features.keys( @@ -102,13 +95,29 @@ def get_field_value_list(cur_dataset, field_keys): return field_value_list field_value_list = get_field_value_list(dataset, field_keys) - select_index = heapq.nsmallest(int(upper_bound), range(len(dataset)), - field_value_list.__getitem__) - sub_dataset = dataset.select(select_index) + field_value_list, indices = zip( + *sorted(list(zip(field_value_list, range(len(field_value_list)))))) + + lower_bound, upper_bound = 0, len(dataset) - 1 + if self.lower_value is not None: + lower_bound = bisect.bisect_left(field_value_list, + self.lower_value) + if self.lower_percentile is not None: + lower_bound = max(lower_bound, + int(self.lower_percentile * len(dataset))) + if self.lower_rank is not None: + lower_bound = max(lower_bound, self.lower_rank) + if self.upper_value is not None: + upper_bound = bisect.bisect_right(field_value_list, + self.upper_value) - 1 + if self.upper_percentile is not None: + upper_bound = min(upper_bound, + int(self.upper_percentile * len(dataset))) + if self.upper_rank is not None: + upper_bound = min(upper_bound, self.upper_rank) + upper_bound = max(lower_bound, upper_bound) + + select_index = indices[lower_bound:upper_bound + 1] - field_value_list = get_field_value_list(sub_dataset, field_keys) - select_index = heapq.nlargest(int(upper_bound - lower_bound), - range(len(sub_dataset)), - field_value_list.__getitem__) + return dataset.select(select_index) - return sub_dataset.select(select_index) From a4a1181690ee1e93577f9fc29581ba5e397a4ece Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=87=E6=A2=A6?= Date: Sat, 21 Sep 2024 18:04:49 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=96=B0=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E7=9A=84=E5=8D=95=E6=B5=8B=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_range_specified_field_selector.py | 153 ++++++++++++++++++ 1 file changed, 153 insertions(+) diff --git a/tests/ops/selector/test_range_specified_field_selector.py b/tests/ops/selector/test_range_specified_field_selector.py index b0dd77a1e..fa23a9e7b 100644 --- a/tests/ops/selector/test_range_specified_field_selector.py +++ b/tests/ops/selector/test_range_specified_field_selector.py @@ -16,6 +16,159 @@ def _run_range_selector(self, dataset: Dataset, target_list, op): target_list = sorted(target_list, key=lambda x: x['text']) self.assertEqual(res_list, target_list) + def test_value_select(self): + ds_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'count': 162, + 'meta': { + 'suffix': '.txt', + 'key1': { + 'key2': { + 'count': None + }, + 'count': 23 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }, { + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 551 + }, + 'count': 78 + } + } + }, { + 'text': '这是一个测试', + 'count': None, + 'meta': { + 'suffix': '.py', + 'key1': { + 'key2': { + 'count': 89 + }, + 'count': 3 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 32 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 33 + }, + 'count': 33 + } + } + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言', + 'count': 666, + 'meta': { + 'suffix': '.xml', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }] + tgt_list = [{ + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }] + dataset = Dataset.from_list(ds_list) + op = RangeSpecifiedFieldSelector(field_key='meta.key1.count', + lower_value=63, + upper_value=67) + self._run_range_selector(dataset, tgt_list, op) + def test_percentile_select(self): ds_list = [{ 'text': 'Today is Sun', From b2d9a1ebb6c1fb16715d965b8d8c7ef28cc183fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=87=E6=A2=A6?= Date: Sat, 21 Sep 2024 18:10:57 +0800 Subject: [PATCH 3/3] add pre-commit --- data_juicer/ops/selector/range_specified_field_selector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/data_juicer/ops/selector/range_specified_field_selector.py b/data_juicer/ops/selector/range_specified_field_selector.py index d51381dac..4a1a52bd1 100644 --- a/data_juicer/ops/selector/range_specified_field_selector.py +++ b/data_juicer/ops/selector/range_specified_field_selector.py @@ -120,4 +120,3 @@ def get_field_value_list(cur_dataset, field_keys): select_index = indices[lower_bound:upper_bound + 1] return dataset.select(select_index) -