diff --git a/README.md b/README.md
index 518e54713..385a9aef1 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@ We provide a [playground](http://8.138.149.181/) with a managed JupyterLab. [Try
[Platform for AI of Alibaba Cloud (PAI)](https://www.aliyun.com/product/bigdata/learn) has cited our work and integrated Data-Juicer into its data processing products. PAI is an AI Native large model and AIGC engineering platform that provides dataset management, computing power management, model tool chain, model development, model training, model deployment, and AI asset management. For documentation on data processing, please refer to: [PAI-Data Processing for Large Models](https://help.aliyun.com/zh/pai/user-guide/components-related-to-data-processing-for-foundation-models/?spm=a2c4g.11186623.0.0.3e9821a69kWdvX).
Data-Juicer is being actively updated and maintained. We will periodically enhance and add more features, data recipes and datasets.
-We welcome you to join us (via issues, PRs, [Slack](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8253f30mgpjw) channel, [DingDing](https://qr.dingtalk.com/action/joingroup?spm=a2c22.12281976.0.0.7a8253f30mgpjw&code=v1,k1,C0DI7CwRFrg7gJP5aMC95FUmsNuwuKJboT62BqP5DAk=&_dt_no_comment=1&origin=11) group, ...), in promoting data-model co-development along with research and applications of (multimodal) LLMs!
+We welcome you to join us (via issues, PRs, [Slack](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8253f30mgpjw) channel, [DingDing](https://qr.dingtalk.com/action/joingroup?code=v1,k1,YFIXM2leDEk7gJP5aMC95AfYT+Oo/EP/ihnaIEhMyJM=&_dt_no_comment=1&origin=11) group, ...), in promoting data-model co-development along with research and applications of (multimodal) LLMs!
----
diff --git a/README_ZH.md b/README_ZH.md
index 366fcb004..54ba1445f 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -27,7 +27,7 @@ Data-Juicer 是一个一站式**多模态**数据处理系统,旨在为大语
[阿里云人工智能平台 PAI](https://www.aliyun.com/product/bigdata/learn) 已引用我们的工作,将Data-Juicer的能力集成到PAI的数据处理产品中。PAI提供包含数据集管理、算力管理、模型工具链、模型开发、模型训练、模型部署、AI资产管理在内的功能模块,为用户提供高性能、高稳定、企业级的大模型工程化能力。数据处理的使用文档请参考:[PAI-大模型数据处理](https://help.aliyun.com/zh/pai/user-guide/components-related-to-data-processing-for-foundation-models/?spm=a2c4g.11186623.0.0.3e9821a69kWdvX)。
-Data-Juicer正在积极更新和维护中,我们将定期强化和新增更多的功能和数据菜谱。热烈欢迎您加入我们(issues/PRs/[Slack频道](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8275bc8g7ypp) /[钉钉群](https://qr.dingtalk.com/action/joingroup?spm=a2c22.12281976.0.0.7a8275bc8g7ypp&code=v1,k1,C0DI7CwRFrg7gJP5aMC95FUmsNuwuKJboT62BqP5DAk=&_dt_no_comment=1&origin=11)/...),一起推进LLM-数据的协同开发和研究!
+Data-Juicer正在积极更新和维护中,我们将定期强化和新增更多的功能和数据菜谱。热烈欢迎您加入我们(issues/PRs/[Slack频道](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8275bc8g7ypp) /[钉钉群](https://qr.dingtalk.com/action/joingroup?code=v1,k1,YFIXM2leDEk7gJP5aMC95AfYT+Oo/EP/ihnaIEhMyJM=&_dt_no_comment=1&origin=11)/...),一起推进LLM-数据的协同开发和研究!
----
diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 9811b0e97..42d1e779e 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -77,6 +77,68 @@ process:
- clean_ip_mapper: # remove ip addresses from text.
- clean_links_mapper: # remove web links from text.
- clean_copyright_mapper: # remove copyright comments.
+ - dialog_intent_detection_mapper: # Mapper to generate user's intent labels in dialog.
+ api_model: 'gpt-4o' # API model name.
+ intent_candidates: null # The output intent candidates. Use the intent labels of the open domain if it is None.
+ max_round: 10 # The max num of round in the dialog to build the prompt.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+ system_prompt: null # System prompt for the task.
+ query_template: null # Template for query part to build the input prompt.
+ response_template: null # Template for response part to build the input prompt.
+ candidate_template: null # Template for intent candidates to build the input prompt.
+ analysis_template: null # Template for analysis part to build the input prompt.
+ labels_template: null # Template for labels to build the input prompt.
+ analysis_pattern: null # Pattern to parse the return intent analysis.
+ labels_pattern: null # Pattern to parse the return intent labels.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
+ - dialog_sentiment_detection_mapper: # Mapper to generate user's sentiment labels in dialog.
+ api_model: 'gpt-4o' # API model name.
+ max_round: 10 # The max num of round in the dialog to build the prompt.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+ system_prompt: null # System prompt for the task.
+ query_template: null # Template for query part to build the input prompt.
+ response_template: null # Template for response part to build the input prompt.
+ analysis_template: null # Template for analysis part to build the input prompt.
+ labels_template: null # Template for labels part to build the input prompt.
+ analysis_pattern: null # Pattern to parse the return sentiment analysis.
+ labels_pattern: null # Pattern to parse the return sentiment labels.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
+ - dialog_sentiment_intensity_mapper: # Mapper to predict user's sentiment intensity (from -5 to 5 in default prompt) in dialog.
+ api_model: 'gpt-4o' # API model name.
+ max_round: 10 # The max num of round in the dialog to build the prompt.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+ system_prompt: null # System prompt for the task.
+ query_template: null # Template for query part to build the input prompt.
+ response_template: null # Template for response part to build the input prompt.
+ analysis_template: null # Template for analysis part to build the input prompt.
+ intensity_template: null # Template for intensity part to build the input prompt.
+ analysis_pattern: null # Pattern to parse the return sentiment analysis.
+ intensity_pattern: null # Pattern to parse the return sentiment intensity.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
+ - dialog_topic_detection_mapper: # Mapper to generate user's topic labels in dialog.
+ api_model: 'gpt-4o' # API model name.
+ max_round: 10 # The max num of round in the dialog to build the prompt.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+ system_prompt: null # System prompt for the task.
+ query_template: null # Template for query part to build the input prompt.
+ response_template: null # Template for response part to build the input prompt.
+ analysis_template: null # Template for analysis part to build the input prompt.
+ labels_template: null # Template for labels part to build the input prompt.
+ analysis_pattern: null # Pattern to parse the return topic analysis.
+ labels_pattern: null # Pattern to parse the return topic labels.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
- expand_macro_mapper: # expand macro definitions in Latex text.
- extract_entity_attribute_mapper: # Extract attributes for given entities from the text.
api_model: 'gpt-4o' # API model name.
@@ -277,6 +339,21 @@ process:
- python_lambda_mapper: # executing Python lambda function on data samples.
lambda_str: '' # A string representation of the lambda function to be executed on data samples. If empty, the identity function is used.
batched: False # A boolean indicating whether to process input data in batches.
+ - query_intent_detection_mapper: # Mapper to predict user's Intent label in query.
+ hf_model: 'bespin-global/klue-roberta-small-3i4k-intent-classification' # Hugginface model ID to predict intent label.
+ zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en' # Translation model from Chinese to English. If not None, translate the query from Chinese to English.
+ model_params: {} # model param for hf_model.
+ zh_to_en_model_params: {} # model param for zh_to_hf_model.
+ - query_sentiment_detection_mapper: # Mapper to predict user's sentiment label ('negative', 'neutral' and 'positive') in query.
+ hf_model: 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' # Hugginface model ID to predict sentiment label.
+ zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en' # Translation model from Chinese to English. If not None, translate the query from Chinese to English.
+ model_params: {} # model param for hf_model.
+ zh_to_en_model_params: {} # model param for zh_to_hf_model.
+ - query_topic_detection_mapper: # Mapper to predict user's topic label in query.
+ hf_model: 'dstefa/roberta-base_topic_classification_nyt_news' # Hugginface model ID to predict topic label.
+ zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en' # Translation model from Chinese to English. If not None, translate the query from Chinese to English.
+ model_params: {} # model param for hf_model.
+ zh_to_en_model_params: {} # model param for zh_to_hf_model.
- relation_identity_mapper: # identify relation between two entity in the text.
api_model: 'gpt-4o' # API model name.
source_entity: '孙悟空' # The source entity of the relation to be dentified.
@@ -715,6 +792,9 @@ process:
upper_percentile: # the upper bound of the percentile to be sampled
lower_rank: # the lower rank of the percentile to be sampled
upper_rank: # the upper rank of the percentile to be sampled
+ - tags_specified_field_selector: # Selector to select samples based on the tags of specified field.
+ field_key: '__dj__meta__.query_sentiment_label' # the target keys corresponding to multi-level field information need to be separated by '.'
+ target_tags: ['happy', 'sad'] # Target tags to be select.
- topk_specified_field_selector: # selector to select top samples based on the sorted specified field
field_key: '' # the target keys corresponding to multi-level field information need to be separated by '.'
top_ratio: # ratio of selected top samples
@@ -723,6 +803,7 @@ process:
# Grouper ops.
- naive_grouper: # Group all samples to one batched sample.
+ - naive_reverse_grouper: # Split one batched sample to samples.
- key_value_grouper: # Group samples to batched samples according values in given keys.
group_by_keys: null # Group samples according values in the keys. Support for nested keys such as "__dj__stats__.text_len". It is [self.text_key] in default.
@@ -744,6 +825,20 @@ process:
try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
+ - meta_tags_aggregator: # Merge similar meta tags to one tag.
+ api_model: 'gpt-4o' # API model name.
+ meta_tag_key: '__dj__meta__.query_sentiment_label' # The key of the meta tag to be mapped.
+ target_tags: ['开心', '难过', '其他'] # The tags that is supposed to be mapped to.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+ system_prompt: null # The system prompt.
+ input_template: null # The input template.
+ target_tag_template: null # The tap template for target tags.
+ tag_template: null # The tap template for each tag and its frequency.
+ output_pattern: null # The output pattern.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
- most_relavant_entities_aggregator: # Extract entities closely related to a given entity from some texts, and sort them in descending order of importance.
api_model: 'gpt-4o' # API model name.
entity: '孙悟空' # The given entity.
diff --git a/data_juicer/ops/aggregator/__init__.py b/data_juicer/ops/aggregator/__init__.py
index 4afe2974a..8aa87cbbd 100644
--- a/data_juicer/ops/aggregator/__init__.py
+++ b/data_juicer/ops/aggregator/__init__.py
@@ -1,8 +1,9 @@
from .entity_attribute_aggregator import EntityAttributeAggregator
+from .meta_tags_aggregator import MetaTagsAggregator
from .most_relavant_entities_aggregator import MostRelavantEntitiesAggregator
from .nested_aggregator import NestedAggregator
__all__ = [
- 'NestedAggregator', 'EntityAttributeAggregator',
+ 'NestedAggregator', 'MetaTagsAggregator', 'EntityAttributeAggregator',
'MostRelavantEntitiesAggregator'
]
diff --git a/data_juicer/ops/aggregator/entity_attribute_aggregator.py b/data_juicer/ops/aggregator/entity_attribute_aggregator.py
index 96fbbb63f..16ec5fd07 100644
--- a/data_juicer/ops/aggregator/entity_attribute_aggregator.py
+++ b/data_juicer/ops/aggregator/entity_attribute_aggregator.py
@@ -8,14 +8,10 @@
from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
is_string_list, nested_access,
nested_set)
-from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
from .nested_aggregator import NestedAggregator
-torch = LazyLoader('torch', 'torch')
-vllm = LazyLoader('vllm', 'vllm')
-
OP_NAME = 'entity_attribute_aggregator'
diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py
new file mode 100644
index 000000000..808ef73da
--- /dev/null
+++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py
@@ -0,0 +1,222 @@
+import re
+from typing import Dict, List, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Aggregator
+from data_juicer.utils.common_utils import is_string_list
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+OP_NAME = 'meta_tags_aggregator'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class MetaTagsAggregator(Aggregator):
+ """
+ Merge similar meta tags to one tag.
+ """
+
+ DEFAULT_SYSTEM_PROMPT = ('给定一些标签以及这些标签出现的频次,合并意思相近的标签。\n'
+ '要求:\n'
+ '- 任务分为两种情况,一种是给定合并后的标签,需要将合并前的标签映射到'
+ '这些标签。如果给定的合并后的标签中有类似“其他”这种标签,将无法归类的'
+ '标签合并到“其他”。以下是这种情况的一个样例:\n'
+ '合并后的标签应限定在[科技, 健康, 其他]中。\n'
+ '| 合并前标签 | 频次 |\n'
+ '| ------ | ------ |\n'
+ '| 医疗 | 20 |\n'
+ '| 信息技术 | 16 |\n'
+ '| 学习 | 19 |\n'
+ '| 气候变化 | 22 |\n'
+ '| 人工智能 | 11 |\n'
+ '| 养生 | 17 |\n'
+ '| 科学创新 | 10 |\n'
+ '\n'
+ '## 分析:“信息技术”、“人工智能”、“科学创新”都属于“科技”类别,“医疗'
+ '”和“养生”跟“健康”有关联,“学习”、“气候变化”和“科技”还有“健康”关'
+ '联不强,应该被归为“其他”。\n'
+ '## 标签合并:\n'
+ '** 医疗归类为健康 **\n'
+ '** 信息技术归类为科技 **\n'
+ '** 学习归类为其他 **\n'
+ '** 气候变化归类为其他 **\n'
+ '** 人工智能归类为科技 **\n'
+ '** 养生归类为健康 **\n'
+ '** 科学创新归类为科技 **\n'
+ '- 另外一种情况没有事先给定合并后的标签,需要生成合理的标签类别:'
+ '| 合并前标签 | 频次 |\n'
+ '| ------ | ------ |\n'
+ '| 医疗 | 20 |\n'
+ '| 信息技术 | 16 |\n'
+ '| 学习 | 2 |\n'
+ '| 气候变化 | 1 |\n'
+ '| 人工智能 | 11 |\n'
+ '| 养生 | 17 |\n'
+ '| 科学创新 | 10 |\n'
+ '\n'
+ '## 分析:“信息技术”、“人工智能”、“科学创新”这三个标签比较相近,归为'
+ '同一类,都属于“科技”类别,“医疗”和“养生”都跟“健康”有关系,可以归'
+ '类为“健康”,“学习”和“气候变化”跟其他标签关联度不强,且频次较低,'
+ '统一归类为“其他”。\n'
+ '## 标签合并:\n'
+ '** 医疗归类为健康 **\n'
+ '** 信息技术归类为科技 **\n'
+ '** 学习归类为其他 **\n'
+ '** 气候变化归类为其他 **\n'
+ '** 人工智能归类为科技 **\n'
+ '** 养生归类为健康 **\n'
+ '** 科学创新归类为科技 **\n')
+
+ DEFAULT_INPUT_TEMPLATE = ('{target_tag_str}'
+ '| 合并前标签 | 频次 |\n'
+ '| ------ | ------ |\n'
+ '{tag_strs}')
+ DEFAULT_TARGET_TAG_TEMPLATE = '合并后的标签应限定在[{target_tags}]中。\n'
+ DEFAULT_TAG_TEMPLATE = '| {tag} | {cnt} |'
+
+ DEFAULT_OUTPUT_PATTERN = r'\*\*\s*(\w+)归类为(\w+)\s*\*\*'
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ meta_tag_key: str = MetaKeys.dialog_sentiment_labels,
+ target_tags: Optional[List[str]] = None,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt: Optional[str] = None,
+ input_template: Optional[str] = None,
+ target_tag_template: Optional[str] = None,
+ tag_template: Optional[str] = None,
+ output_pattern: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+ :param api_model: API model name.
+ :param meta_tag_key: The key of the meta tag to be mapped.
+ :param target_tags: The tags that is supposed to be mapped to.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt: The system prompt.
+ :param input_template: The input template.
+ :param target_tag_template: The tap template for target tags.
+ :param tag_template: The tap template for each tag and its
+ frequency.
+ :param output_pattern: The output pattern.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ self.meta_tag_key = meta_tag_key
+
+ self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+ self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+ target_tag_template = target_tag_template or \
+ self.DEFAULT_TARGET_TAG_TEMPLATE
+ self.tag_template = tag_template or self.DEFAULT_TAG_TEMPLATE
+ self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
+
+ self.target_tag_str = ''
+ if target_tags:
+ self.target_tag_str = target_tag_template.format(
+ target_tags=', '.join(target_tags))
+
+ self.sampling_params = sampling_params
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ return_processor=True,
+ **model_params)
+
+ self.try_num = try_num
+
+ def parse_output(self, response):
+ pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
+ matches = pattern.findall(response)
+ tag_map = {tag1: tag2 for tag1, tag2 in matches}
+ return tag_map
+
+ def meta_map(self, meta_cnts, rank=None):
+
+ model, _ = get_model(self.model_key, rank, self.use_cuda())
+
+ tag_strs = [
+ self.tag_template.format(tag=k, cnt=meta_cnts[k])
+ for k in meta_cnts
+ ]
+ input_prompt = self.input_template.format(
+ target_tag_str=self.target_tag_str, tag_strs='\n'.join(tag_strs))
+
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': input_prompt
+ }]
+ tag_map = {}
+ for i in range(self.try_num):
+ try:
+ response = model(messages, **self.sampling_params)
+ tag_map = self.parse_output(response)
+ if len(tag_map) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+
+ return tag_map
+
+ def process_single(self, sample=None, rank=None):
+
+ if Fields.meta not in sample:
+ logger.warning('Not any meta in the sample!')
+ return sample
+
+ metas = sample[Fields.meta]
+ # if not batched sample
+ if not isinstance(metas, list):
+ logger.warning('Not a batched sample!')
+ return sample
+
+ meta_cnts = {}
+
+ def update_dict(key):
+ if key in meta_cnts:
+ meta_cnts[key] += 1
+ else:
+ meta_cnts[key] = 1
+
+ for meta in metas:
+ tag = meta[self.meta_tag_key]
+ if isinstance(tag, str):
+ update_dict(tag)
+ elif is_string_list(tag):
+ for t in tag:
+ update_dict(t)
+ else:
+ logger.warning('Meta tag must be string or list of string!')
+ return sample
+
+ tag_map = self.meta_map(meta_cnts, rank=rank)
+ for i in range(len(metas)):
+ tag = metas[i][self.meta_tag_key]
+ if isinstance(tag, str) and tag in tag_map:
+ metas[i][self.meta_tag_key] = tag_map[tag]
+ elif is_string_list(tag):
+ metas[i][self.meta_tag_key] = [
+ tag_map[t] if t in tag_map else t for t in tag
+ ]
+
+ return sample
diff --git a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py
index 69e1a209c..7ca49f505 100644
--- a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py
+++ b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py
@@ -7,14 +7,10 @@
from data_juicer.ops.base_op import OPERATORS, Aggregator
from data_juicer.utils.common_utils import (is_string_list, nested_access,
nested_set)
-from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
from ..common import split_text_by_punctuation
-torch = LazyLoader('torch', 'torch')
-vllm = LazyLoader('vllm', 'vllm')
-
OP_NAME = 'most_relavant_entities_aggregator'
diff --git a/data_juicer/ops/aggregator/nested_aggregator.py b/data_juicer/ops/aggregator/nested_aggregator.py
index 124eb1470..ab25e057d 100644
--- a/data_juicer/ops/aggregator/nested_aggregator.py
+++ b/data_juicer/ops/aggregator/nested_aggregator.py
@@ -6,12 +6,8 @@
from data_juicer.ops.base_op import OPERATORS, Aggregator
from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
is_string_list, nested_access)
-from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
-torch = LazyLoader('torch', 'torch')
-vllm = LazyLoader('vllm', 'vllm')
-
OP_NAME = 'nested_aggregator'
diff --git a/data_juicer/ops/grouper/__init__.py b/data_juicer/ops/grouper/__init__.py
index 048b305e4..f81ba6aec 100644
--- a/data_juicer/ops/grouper/__init__.py
+++ b/data_juicer/ops/grouper/__init__.py
@@ -1,4 +1,5 @@
from .key_value_grouper import KeyValueGrouper
from .naive_grouper import NaiveGrouper
+from .naive_reverse_grouper import NaiveReverseGrouper
-__all__ = ['NaiveGrouper', 'KeyValueGrouper']
+__all__ = ['KeyValueGrouper', 'NaiveGrouper', 'NaiveReverseGrouper']
diff --git a/data_juicer/ops/grouper/naive_reverse_grouper.py b/data_juicer/ops/grouper/naive_reverse_grouper.py
new file mode 100644
index 000000000..2535205b9
--- /dev/null
+++ b/data_juicer/ops/grouper/naive_reverse_grouper.py
@@ -0,0 +1,26 @@
+from ..base_op import OPERATORS, Grouper, convert_dict_list_to_list_dict
+
+
+@OPERATORS.register_module('naive_reverse_grouper')
+class NaiveReverseGrouper(Grouper):
+ """Split batched samples to samples. """
+
+ def __init__(self, *args, **kwargs):
+ """
+ Initialization method.
+
+ :param args: extra args
+ :param kwargs: extra args
+ """
+ super().__init__(*args, **kwargs)
+
+ def process(self, dataset):
+
+ if len(dataset) == 0:
+ return dataset
+
+ samples = []
+ for sample in dataset:
+ samples.extend(convert_dict_list_to_list_dict(sample))
+
+ return samples
diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py
index 9b86b83dc..8ffe7cc8e 100644
--- a/data_juicer/ops/mapper/__init__.py
+++ b/data_juicer/ops/mapper/__init__.py
@@ -8,6 +8,10 @@
from .clean_html_mapper import CleanHtmlMapper
from .clean_ip_mapper import CleanIpMapper
from .clean_links_mapper import CleanLinksMapper
+from .dialog_intent_detection_mapper import DialogIntentDetectionMapper
+from .dialog_sentiment_detection_mapper import DialogSentimentDetectionMapper
+from .dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper
+from .dialog_topic_detection_mapper import DialogTopicDetectionMapper
from .expand_macro_mapper import ExpandMacroMapper
from .extract_entity_attribute_mapper import ExtractEntityAttributeMapper
from .extract_entity_relation_mapper import ExtractEntityRelationMapper
@@ -33,6 +37,9 @@
from .punctuation_normalization_mapper import PunctuationNormalizationMapper
from .python_file_mapper import PythonFileMapper
from .python_lambda_mapper import PythonLambdaMapper
+from .query_intent_detection_mapper import QueryIntentDetectionMapper
+from .query_sentiment_detection_mapper import QuerySentimentDetectionMapper
+from .query_topic_detection_mapper import QueryTopicDetectionMapper
from .relation_identity_mapper import RelationIdentityMapper
from .remove_bibliography_mapper import RemoveBibliographyMapper
from .remove_comments_mapper import RemoveCommentsMapper
@@ -71,6 +78,8 @@
'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper',
'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper',
'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper',
+ 'DialogIntentDetectionMapper', 'DialogSentimentDetectionMapper',
+ 'DialogSentimentIntensityMapper', 'DialogTopicDetectionMapper',
'ExpandMacroMapper', 'ExtractEntityAttributeMapper',
'ExtractEntityRelationMapper', 'ExtractEventMapper',
'ExtractKeywordMapper', 'ExtractNicknameMapper',
@@ -81,18 +90,20 @@
'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper',
'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper',
'PairPreferenceMapper', 'PunctuationNormalizationMapper',
- 'PythonFileMapper', 'PythonLambdaMapper', 'RelationIdentityMapper',
- 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper',
- 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper',
- 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper',
- 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper',
- 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper',
- 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper',
- 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper',
- 'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper',
- 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper',
- 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper',
- 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper',
- 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper',
- 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper'
+ 'PythonFileMapper', 'PythonLambdaMapper', 'QuerySentimentDetectionMapper',
+ 'QueryIntentDetectionMapper', 'QueryTopicDetectionMapper',
+ 'RelationIdentityMapper', 'RemoveBibliographyMapper',
+ 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper',
+ 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper',
+ 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper',
+ 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper',
+ 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper',
+ 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper',
+ 'VideoCaptioningFromVideoMapper', 'VideoExtractFramesMapper',
+ 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper',
+ 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper',
+ 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper',
+ 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper',
+ 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper',
+ 'WhitespaceNormalizationMapper'
]
diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py
index 8480ee899..bf9686409 100644
--- a/data_juicer/ops/mapper/calibrate_qa_mapper.py
+++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py
@@ -55,6 +55,8 @@ def __init__(self,
:param reference_template: Template for formatting the reference text.
:param qa_pair_template: Template for formatting question-answer pairs.
:param output_pattern: Regular expression for parsing model output.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
:param model_params: Parameters for initializing the API model.
:param sampling_params: Extra parameters passed to the API call.
e.g {'temperature': 0.9, 'top_p': 0.95}
diff --git a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py
new file mode 100644
index 000000000..7c8cba9ed
--- /dev/null
+++ b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py
@@ -0,0 +1,216 @@
+import re
+from typing import Dict, List, Optional
+
+from loguru import logger
+from pydantic import NonNegativeInt, PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Mapper
+from data_juicer.utils.common_utils import nested_set
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+OP_NAME = 'dialog_intent_detection_mapper'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class DialogIntentDetectionMapper(Mapper):
+ """
+ Mapper to generate user's intent labels in dialog. Input from
+ history_key, query_key and response_key. Output lists of
+ labels and analysis for queries in the dialog, which is
+ store in 'dialog_intent_labels' and
+ 'dialog_intent_labels_analysis' in Data-Juicer meta field.
+ """
+
+ DEFAULT_SYSTEM_PROMPT = (
+ '请判断用户和LLM多轮对话中用户的意图。\n'
+ '要求:\n'
+ '- 需要先进行分析,然后列出用户所具有的意图,下面是一个样例,请模仿样例格式输出'
+ '。\n'
+ '用户:你好,我最近对人工智能很感兴趣,能给我讲讲什么是机器学习吗?\n'
+ '意图分析:用户在请求信息,希望了解有关机器学习的基础知识。\n'
+ '意图类别:信息查找\n'
+ 'LLM:你好!当然可以。机器学习是一种人工智能方法,允许计算机通过数据自动改进和学习。\n'
+ '用户:听起来很有趣,有没有推荐的入门书籍或资料?\n'
+ '意图分析:用户在请求建议,希望获取关于机器学习的入门资源。\n'
+ '意图类别:请求建议\n'
+ 'LLM:有很多不错的入门书籍和资源。一本常被推荐的书是《Python机器学习实践》(Python'
+ ' Machine Learning),它涵盖了基础知识和一些实际案例。此外,您还可以参考Coursera'
+ '或edX上的在线课程,这些课程提供了系统的学习路径。\n'
+ '用户:谢谢你的建议!我还想知道,学习机器学习需要什么样的数学基础?\n'
+ '意图分析:用户在寻求信息,希望了解学习机器学习所需的前提条件,特别是在数学方面。\n'
+ '意图类别:信息查找\n'
+ 'LLM:学习机器学习通常需要一定的数学基础,特别是线性代数、概率论和统计学。这些数学领'
+ '域帮助理解算法的工作原理和数据模式分析。如果您对这些主题不太熟悉,建议先从相关基础'
+ '书籍或在线资源开始学习。\n'
+ '用户:明白了,我会先补习这些基础知识。再次感谢你的帮助!\n'
+ '意图分析:用户表达感谢,并表示计划付诸行动来补充所需的基础知识。\n'
+ '意图类别:其他')
+ DEFAULT_QUERY_TEMPLATE = '用户:{query}\n'
+ DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n'
+ DEFAULT_CANDIDATES_TEMPLATE = '备选意图类别:[{candidate_str}]'
+ DEFAULT_ANALYSIS_TEMPLATE = '意图分析:{analysis}\n'
+ DEFAULT_LABELS_TEMPLATE = '意图类别:{labels}\n'
+ DEFAULT_ANALYSIS_PATTERN = '意图分析:(.*?)\n'
+ DEFAULT_LABELS_PATTERN = '意图类别:(.*?)($|\n)'
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ intent_candidates: Optional[List[str]] = None,
+ max_round: NonNegativeInt = 10,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt: Optional[str] = None,
+ query_template: Optional[str] = None,
+ response_template: Optional[str] = None,
+ candidate_template: Optional[str] = None,
+ analysis_template: Optional[str] = None,
+ labels_template: Optional[str] = None,
+ analysis_pattern: Optional[str] = None,
+ labels_pattern: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+
+ :param api_model: API model name.
+ :param intent_candidates: The output intent candidates. Use the
+ intent labels of the open domain if it is None.
+ :param max_round: The max num of round in the dialog to build the
+ prompt.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt: System prompt for the task.
+ :param query_template: Template for query part to build the input
+ prompt.
+ :param response_template: Template for response part to build the
+ input prompt.
+ :param candidate_template: Template for intent candidates to
+ build the input prompt.
+ :param analysis_template: Template for analysis part to build the
+ input prompt.
+ :param labels_template: Template for labels to build the
+ input prompt.
+ :param analysis_pattern: Pattern to parse the return intent
+ analysis.
+ :param labels_pattern: Pattern to parse the return intent
+ labels.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ self.intent_candidates = intent_candidates
+ self.max_round = max_round
+
+ self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+ self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE
+ self.response_template = response_template or \
+ self.DEFAULT_RESPONSE_TEMPLATE
+ self.candidate_template = candidate_template or \
+ self.DEFAULT_CANDIDATES_TEMPLATE
+ self.analysis_template = analysis_template or \
+ self.DEFAULT_ANALYSIS_TEMPLATE
+ self.labels_template = labels_template or \
+ self.DEFAULT_LABELS_TEMPLATE
+ self.analysis_pattern = analysis_pattern or \
+ self.DEFAULT_ANALYSIS_PATTERN
+ self.labels_pattern = labels_pattern or \
+ self.DEFAULT_LABELS_PATTERN
+
+ self.sampling_params = sampling_params
+
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ **model_params)
+
+ self.try_num = try_num
+
+ def build_input(self, history, query):
+
+ if self.intent_candidates:
+ input_prompt = self.candidate_template.format(
+ candidate_str=','.join(self.intent_candidates))
+ else:
+ input_prompt = ''
+
+ if self.max_round > 0:
+ input_prompt += ''.join(history[-self.max_round * 4:])
+
+ input_prompt += self.query_template.format(query=query[0])
+
+ return input_prompt
+
+ def parse_output(self, response):
+ analysis = ''
+ labels = ''
+
+ match = re.search(self.analysis_pattern, response)
+ if match:
+ analysis = match.group(1)
+
+ match = re.search(self.labels_pattern, response)
+ if match:
+ labels = match.group(1)
+
+ return analysis, labels
+
+ def process_single(self, sample, rank=None):
+ client = get_model(self.model_key, rank=rank)
+
+ analysis_list = []
+ labels_list = []
+ history = []
+
+ dialog = sample[self.history_key]
+ if self.query_key in sample and sample[self.query_key]:
+ if self.response_key in sample and sample[self.response_key]:
+ dialog.append(
+ (sample[self.query_key], sample[self.response_key]))
+ else:
+ dialog.append((sample[self.query_key], ''))
+
+ for qa in dialog:
+ input_prompt = self.build_input(history, qa)
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt,
+ }, {
+ 'role': 'user',
+ 'content': input_prompt,
+ }]
+
+ for _ in range(self.try_num):
+ try:
+ response = client(messages, **self.sampling_params)
+ analysis, labels = self.parse_output(response)
+ if len(analysis) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+
+ analysis_list.append(analysis)
+ labels_list.append(labels)
+
+ history.append(self.query_template.format(query=qa[0]))
+ history.append(self.analysis_template.format(analysis=analysis))
+ history.append(self.labels_template.format(labels=labels))
+ history.append(self.response_template.format(response=qa[1]))
+
+ analysis_key = f'{Fields.meta}.{MetaKeys.dialog_intent_labels_analysis}' # noqa: E501
+ sample = nested_set(sample, analysis_key, analysis_list)
+ labels_key = f'{Fields.meta}.{MetaKeys.dialog_intent_labels}'
+ sample = nested_set(sample, labels_key, labels_list)
+
+ return sample
diff --git a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py
new file mode 100644
index 000000000..33bccc5ce
--- /dev/null
+++ b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py
@@ -0,0 +1,195 @@
+import re
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import NonNegativeInt, PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Mapper
+from data_juicer.utils.common_utils import nested_set
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+OP_NAME = 'dialog_sentiment_detection_mapper'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class DialogSentimentDetectionMapper(Mapper):
+ """
+ Mapper to generate user's sentiment labels in dialog. Input from
+ history_key, query_key and response_key. Output lists of
+ labels and analysis for queries in the dialog, which is
+ store in 'dialog_sentiment_labels' and
+ 'dialog_sentiment_labels_analysis' in Data-Juicer meta field.
+ """
+
+ DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户所具有的情绪。\n'
+ '要求:\n'
+ '- 需要先进行分析,然后罗列用户所具有的情绪,下面是一个样例,请模仿样例格式输出'
+ '。\n'
+ '用户:最近工作压力好大,我觉得整个人都快被压垮了。\n'
+ '情感分析:用户的言语中透露出明显的压力和疲惫感,可能还夹杂着一些无助和焦虑。\n'
+ '情感:压力、疲惫、无助、焦虑\n'
+ 'LLM:听起来你真的承受了很多,面临这种情况确实不容易。有没有考虑过找一些放松的'
+ '方式,比如听音乐或者散步来减轻压力呢?\n'
+ '用户:试过了,但是好像没什么效果,每天的事情都堆积如山。\n'
+ '情感分析:用户感到无力解决现状,有挫败感,并且对尝试放松的方式失去信心。\n'
+ '情感:无力、挫败\n'
+ 'LLM:我理解你的感受,有时候压力积累到一定程度确实让人难以承受。或许你可以尝试'
+ '规划一下时间,把任务分成小块来完成,这样可能会减少一些压力感。\n'
+ '用户:这个主意不错,我会试着让自己更有条理一些,谢谢你的建议。\n'
+ '情感分析:用户对建议表现出认同和感激,同时展现出试图积极面对问题的态度。\n'
+ '情感:认同、感激、积极\n'
+ 'LLM:不用谢,我很高兴能帮到你。记得给自己一些时间去适应新的计划,有任何需要'
+ '随时可以跟我说哦!\n')
+ DEFAULT_QUERY_TEMPLATE = '用户:{query}\n'
+ DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n'
+ DEFAULT_ANALYSIS_TEMPLATE = '情感分析:{analysis}\n'
+ DEFAULT_LABELS_TEMPLATE = '情感:{labels}\n'
+ DEFAULT_ANALYSIS_PATTERN = '情感分析:(.*?)\n'
+ DEFAULT_LABELS_PATTERN = '情感:(.*?)($|\n)'
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ max_round: NonNegativeInt = 10,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt: Optional[str] = None,
+ query_template: Optional[str] = None,
+ response_template: Optional[str] = None,
+ analysis_template: Optional[str] = None,
+ labels_template: Optional[str] = None,
+ analysis_pattern: Optional[str] = None,
+ labels_pattern: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+
+ :param api_model: API model name.
+ :param max_round: The max num of round in the dialog to build the
+ prompt.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt: System prompt for the task.
+ :param query_template: Template for query part to build the input
+ prompt.
+ :param response_template: Template for response part to build the
+ input prompt.
+ :param analysis_template: Template for analysis part to build the
+ input prompt.
+ :param labels_template: Template for labels part to build the
+ input prompt.
+ :param analysis_pattern: Pattern to parse the return sentiment
+ analysis.
+ :param labels_pattern: Pattern to parse the return sentiment
+ labels.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ self.max_round = max_round
+
+ self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+ self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE
+ self.response_template = response_template or \
+ self.DEFAULT_RESPONSE_TEMPLATE
+ self.analysis_template = analysis_template or \
+ self.DEFAULT_ANALYSIS_TEMPLATE
+ self.labels_template = labels_template or \
+ self.DEFAULT_LABELS_TEMPLATE
+ self.analysis_pattern = analysis_pattern or \
+ self.DEFAULT_ANALYSIS_PATTERN
+ self.labels_pattern = labels_pattern or \
+ self.DEFAULT_LABELS_PATTERN
+
+ self.sampling_params = sampling_params
+
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ **model_params)
+
+ self.try_num = try_num
+
+ def build_input(self, history, query):
+ if self.max_round > 0:
+ input_prompt = ''.join(history[-self.max_round * 4:])
+ else:
+ input_prompt = ''
+ input_prompt += self.query_template.format(query=query[0])
+
+ return input_prompt
+
+ def parse_output(self, response):
+ analysis = ''
+ labels = ''
+
+ match = re.search(self.analysis_pattern, response)
+ if match:
+ analysis = match.group(1)
+
+ match = re.search(self.labels_pattern, response)
+ if match:
+ labels = match.group(1)
+
+ return analysis, labels
+
+ def process_single(self, sample, rank=None):
+ client = get_model(self.model_key, rank=rank)
+
+ analysis_list = []
+ labels_list = []
+ history = []
+
+ dialog = sample[self.history_key]
+ if self.query_key in sample and sample[self.query_key]:
+ if self.response_key in sample and sample[self.response_key]:
+ dialog.append(
+ (sample[self.query_key], sample[self.response_key]))
+ else:
+ dialog.append((sample[self.query_key], ''))
+
+ for qa in dialog:
+ input_prompt = self.build_input(history, qa)
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt,
+ }, {
+ 'role': 'user',
+ 'content': input_prompt,
+ }]
+
+ for _ in range(self.try_num):
+ try:
+ response = client(messages, **self.sampling_params)
+ analysis, labels = self.parse_output(response)
+ if len(analysis) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+
+ analysis_list.append(analysis)
+ labels_list.append(labels)
+
+ history.append(self.query_template.format(query=qa[0]))
+ history.append(self.analysis_template.format(analysis=analysis))
+ history.append(self.labels_template.format(labels=labels))
+ history.append(self.response_template.format(response=qa[1]))
+
+ analysis_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_labels_analysis}' # noqa: E501
+ sample = nested_set(sample, analysis_key, analysis_list)
+ labels_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_labels}'
+ sample = nested_set(sample, labels_key, labels_list)
+
+ return sample
diff --git a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py
new file mode 100644
index 000000000..198314ee3
--- /dev/null
+++ b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py
@@ -0,0 +1,207 @@
+import re
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import NonNegativeInt, PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Mapper
+from data_juicer.utils.common_utils import nested_set
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+OP_NAME = 'dialog_sentiment_intensity_mapper'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class DialogSentimentIntensityMapper(Mapper):
+ """
+ Mapper to predict user's sentiment intensity (from -5 to 5 in default
+ prompt) in dialog. Input from history_key, query_key and
+ response_key. Output lists of intensities and analysis for queries in
+ the dialog, which is store in 'dialog_sentiment_intensity' and
+ 'dialog_sentiment_intensity_analysis' in Data-Juicer meta field.
+ """
+
+ DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户的情绪变化。\n'
+ '要求:\n'
+ '- 用户情绪值是-5到5之间到整数,-5表示极度负面,5表示极度正面,'
+ '-5到5之间数值表示情绪从负面逐渐到正面的变化过程,0代表情呈绪中性。\n'
+ '- 需要先进行分析,然后确定用户的情绪值,下面是一个样例,请模仿样例格式输出。\n'
+ '用户:你好,我对可持续发展的定义有点模糊,帮我解释一下?\n'
+ '情绪分析:刚开始,还没得到LLM回复,用户情绪呈中性。\n'
+ '情绪值:0\n'
+ 'LLM:当然可以!可持续发展是指在满足当代人的需求的同时,不损害子孙后代满足其自'
+ '身需求的能力的发展模式。它包括经济发展、社会发展和环境保护三个主要方面。通过合'
+ '理利用资源和保护环境,我们可以确保未来的世代也能享有健全的生态系统和经济制度。\n'
+ '用户:谢谢你的解释!那你能告诉我一些普通人可以采取的可持续生活方式吗?\n'
+ '情绪分析:对回答感到满意,情绪正面。\n'
+ '情绪值:1\n'
+ 'LLM:当然可以,普通人可以通过减少一次性产品的使用、选择公共交通或拼车、节约用'
+ '水、以及支持本地和可持续发展的企业等方式来践行可持续生活。此外,关注垃圾分类和'
+ '多用电子账单也是不错的选择。\n'
+ '用户:你提到支持本地企业,这一点我很感兴趣。能详细说说为什么这对可持续发展有促'
+ '进作用吗?\n'
+ '情绪分析:觉得回答实用且具体,情绪进一步转好。\n'
+ '情绪值:2\n'
+ 'LLM:呃,我最近发现了一部新电影,讲述了一个关于外星人和地球土著合作保护环境的'
+ '故事。虽然它是科幻片,但很有启发性,推荐你去看看。\n'
+ '用户:什么吗,根本是答非所问。\n'
+ '情绪分析:LLM没有回应问题而是提到无关内容,导致用户情绪直线下降。\n'
+ '情绪值:-2\n'
+ 'LLM:抱歉刚才的偏题!支持本地企业有助于减少长途运输产生的碳足迹,使供应链更加'
+ '环保。此外,本地企业也更有可能采用可持续的生产方式,同时促进社区经济的繁荣。\n'
+ '用户:还行吧,算你能够掰回来。\n'
+ '情绪分析:问题得到解答,问题偏题得到纠正,情绪稍有好转。\n'
+ '情绪值:-1\n')
+ DEFAULT_QUERY_TEMPLATE = '用户:{query}\n'
+ DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n'
+ DEFAULT_ANALYSIS_TEMPLATE = '情绪分析:{analysis}\n'
+ DEFAULT_INTENSITY_TEMPLATE = '情绪值:{intensity}\n'
+ DEFAULT_ANALYSIS_PATTERN = '情绪分析:(.*?)\n'
+ DEFAULT_INTENSITY_PATTERN = '情绪值:(.*?)($|\n)'
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ max_round: NonNegativeInt = 10,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt: Optional[str] = None,
+ query_template: Optional[str] = None,
+ response_template: Optional[str] = None,
+ analysis_template: Optional[str] = None,
+ intensity_template: Optional[str] = None,
+ analysis_pattern: Optional[str] = None,
+ intensity_pattern: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+
+ :param api_model: API model name.
+ :param max_round: The max num of round in the dialog to build the
+ prompt.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt: System prompt for the task.
+ :param query_template: Template for query part to build the input
+ prompt.
+ :param response_template: Template for response part to build the
+ input prompt.
+ :param analysis_template: Template for analysis part to build the
+ input prompt.
+ :param intensity_template: Template for intensity part to build the
+ input prompt.
+ :param analysis_pattern: Pattern to parse the return sentiment
+ analysis.
+ :param intensity_pattern: Pattern to parse the return sentiment
+ intensity.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ self.max_round = max_round
+
+ self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+ self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE
+ self.response_template = response_template or \
+ self.DEFAULT_RESPONSE_TEMPLATE
+ self.analysis_template = analysis_template or \
+ self.DEFAULT_ANALYSIS_TEMPLATE
+ self.intensity_template = intensity_template or \
+ self.DEFAULT_INTENSITY_TEMPLATE
+ self.analysis_pattern = analysis_pattern or \
+ self.DEFAULT_ANALYSIS_PATTERN
+ self.intensity_pattern = intensity_pattern or \
+ self.DEFAULT_INTENSITY_PATTERN
+
+ self.sampling_params = sampling_params
+
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ **model_params)
+
+ self.try_num = try_num
+
+ def build_input(self, history, query):
+ if self.max_round > 0:
+ input_prompt = ''.join(history[-self.max_round * 4:])
+ else:
+ input_prompt = ''
+ input_prompt += self.query_template.format(query=query[0])
+
+ return input_prompt
+
+ def parse_output(self, response):
+ analysis = ''
+ intensity = 0
+
+ match = re.search(self.analysis_pattern, response)
+ if match:
+ analysis = match.group(1)
+
+ match = re.search(self.intensity_pattern, response)
+ if match:
+ intensity = int(match.group(1))
+
+ return analysis, intensity
+
+ def process_single(self, sample, rank=None):
+ client = get_model(self.model_key, rank=rank)
+
+ analysis_list = []
+ intensities = []
+ history = []
+
+ dialog = sample[self.history_key]
+ if self.query_key in sample and sample[self.query_key]:
+ if self.response_key in sample and sample[self.response_key]:
+ dialog.append(
+ (sample[self.query_key], sample[self.response_key]))
+ else:
+ dialog.append((sample[self.query_key], ''))
+
+ for qa in dialog:
+ input_prompt = self.build_input(history, qa)
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt,
+ }, {
+ 'role': 'user',
+ 'content': input_prompt,
+ }]
+
+ for _ in range(self.try_num):
+ try:
+ response = client(messages, **self.sampling_params)
+ analysis, intensity = self.parse_output(response)
+ if len(analysis) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+
+ analysis_list.append(analysis)
+ intensities.append(intensity)
+
+ history.append(self.query_template.format(query=qa[0]))
+ history.append(self.analysis_template.format(analysis=analysis))
+ history.append(self.intensity_template.format(intensity=intensity))
+ history.append(self.response_template.format(response=qa[1]))
+
+ analysis_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_intensity_analysis}' # noqa: E501
+ sample = nested_set(sample, analysis_key, analysis_list)
+ intensity_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_intensity}'
+ sample = nested_set(sample, intensity_key, intensities)
+
+ return sample
diff --git a/data_juicer/ops/mapper/dialog_topic_detection_mapper.py b/data_juicer/ops/mapper/dialog_topic_detection_mapper.py
new file mode 100644
index 000000000..7e8ee0b54
--- /dev/null
+++ b/data_juicer/ops/mapper/dialog_topic_detection_mapper.py
@@ -0,0 +1,200 @@
+import re
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import NonNegativeInt, PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Mapper
+from data_juicer.utils.common_utils import nested_set
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+OP_NAME = 'dialog_topic_detection_mapper'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class DialogTopicDetectionMapper(Mapper):
+ """
+ Mapper to generate user's topic labels in dialog. Input from
+ history_key, query_key and response_key. Output lists of
+ labels and analysis for queries in the dialog, which is
+ store in 'dialog_sentiment_labels' and
+ 'dialog_sentiment_labels_analysis' in Data-Juicer meta field.
+ """
+
+ DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户所讨论的话题。\n'
+ '要求:\n'
+ '- 针对用户的每个query,需要先进行分析,然后列出用户正在讨论的话题,下面是'
+ '一个样例,请模仿样例格式输出。\n'
+ '用户:你好,今天我们来聊聊秦始皇吧。\n'
+ '话题分析:用户提到秦始皇,这是中国历史上第一位皇帝。\n'
+ '话题类别:历史\n'
+ 'LLM:当然可以,秦始皇是中国历史上第一个统一全国的皇帝,他在公元前221年建'
+ '立了秦朝,并采取了一系列重要的改革措施,如统一文字、度量衡和货币等。\n'
+ '用户:秦始皇修建的长城和现在的长城有什么区别?\n'
+ '话题分析:用户提到秦始皇修建的长城,并将其与现代长城进行比较,涉及建筑历史'
+ '和地理位置。\n'
+ '话题类别:历史'
+ 'LLM:秦始皇时期修建的长城主要是为了抵御北方游牧民族的入侵,它的规模和修建'
+ '技术相对较为简陋。现代人所看到的长城大部分是明朝时期修建和扩建的,明长城不'
+ '仅规模更大、结构更坚固,而且保存得比较完好。\n'
+ '用户:有意思,那么长城的具体位置在哪些省份呢?\n'
+ '话题分析:用户询问长城的具体位置,涉及到地理知识。\n'
+ '话题类别:地理\n'
+ 'LLM:长城横跨中国北方多个省份,主要包括河北、山西、内蒙古、宁夏、陕西、甘'
+ '肃和北京等。每一段长城都建在关键的战略位置,以便最大限度地发挥其防御作用'
+ '。\n')
+ DEFAULT_QUERY_TEMPLATE = '用户:{query}\n'
+ DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n'
+ DEFAULT_ANALYSIS_TEMPLATE = '话题分析:{analysis}\n'
+ DEFAULT_LABELS_TEMPLATE = '话题类别:{labels}\n'
+ DEFAULT_ANALYSIS_PATTERN = '话题分析:(.*?)\n'
+ DEFAULT_LABELS_PATTERN = '话题类别:(.*?)($|\n)'
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ max_round: NonNegativeInt = 10,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt: Optional[str] = None,
+ query_template: Optional[str] = None,
+ response_template: Optional[str] = None,
+ analysis_template: Optional[str] = None,
+ labels_template: Optional[str] = None,
+ analysis_pattern: Optional[str] = None,
+ labels_pattern: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+
+ :param api_model: API model name.
+ :param max_round: The max num of round in the dialog to build the
+ prompt.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt: System prompt for the task.
+ :param query_template: Template for query part to build the input
+ prompt.
+ :param response_template: Template for response part to build the
+ input prompt.
+ :param analysis_template: Template for analysis part to build the
+ input prompt.
+ :param labels_template: Template for labels part to build the
+ input prompt.
+ :param analysis_pattern: Pattern to parse the return sentiment
+ analysis.
+ :param labels_pattern: Pattern to parse the return sentiment
+ labels.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ self.max_round = max_round
+
+ self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+ self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE
+ self.response_template = response_template or \
+ self.DEFAULT_RESPONSE_TEMPLATE
+ self.analysis_template = analysis_template or \
+ self.DEFAULT_ANALYSIS_TEMPLATE
+ self.labels_template = labels_template or \
+ self.DEFAULT_LABELS_TEMPLATE
+ self.analysis_pattern = analysis_pattern or \
+ self.DEFAULT_ANALYSIS_PATTERN
+ self.labels_pattern = labels_pattern or \
+ self.DEFAULT_LABELS_PATTERN
+
+ self.sampling_params = sampling_params
+
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ **model_params)
+
+ self.try_num = try_num
+
+ def build_input(self, history, query):
+
+ if self.max_round > 0:
+ input_prompt = ''.join(history[-self.max_round * 4:])
+ else:
+ input_prompt = ''
+
+ input_prompt += self.query_template.format(query=query[0])
+
+ return input_prompt
+
+ def parse_output(self, response):
+ analysis = ''
+ labels = ''
+
+ match = re.search(self.analysis_pattern, response)
+ if match:
+ analysis = match.group(1)
+
+ match = re.search(self.labels_pattern, response)
+ if match:
+ labels = match.group(1)
+
+ return analysis, labels
+
+ def process_single(self, sample, rank=None):
+ client = get_model(self.model_key, rank=rank)
+
+ analysis_list = []
+ labels_list = []
+ history = []
+
+ dialog = sample[self.history_key]
+ if self.query_key in sample and sample[self.query_key]:
+ if self.response_key in sample and sample[self.response_key]:
+ dialog.append(
+ (sample[self.query_key], sample[self.response_key]))
+ else:
+ dialog.append((sample[self.query_key], ''))
+
+ for qa in dialog:
+ input_prompt = self.build_input(history, qa)
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt,
+ }, {
+ 'role': 'user',
+ 'content': input_prompt,
+ }]
+
+ for _ in range(self.try_num):
+ try:
+ response = client(messages, **self.sampling_params)
+ analysis, labels = self.parse_output(response)
+ if len(analysis) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+
+ analysis_list.append(analysis)
+ labels_list.append(labels)
+
+ history.append(self.query_template.format(query=qa[0]))
+ history.append(self.analysis_template.format(analysis=analysis))
+ history.append(self.labels_template.format(labels=labels))
+ history.append(self.response_template.format(response=qa[1]))
+
+ analysis_key = f'{Fields.meta}.{MetaKeys.dialog_topic_labels_analysis}' # noqa: E501
+ sample = nested_set(sample, analysis_key, analysis_list)
+ labels_key = f'{Fields.meta}.{MetaKeys.dialog_topic_labels}'
+ sample = nested_set(sample, labels_key, labels_list)
+
+ return sample
diff --git a/data_juicer/ops/mapper/query_intent_detection_mapper.py b/data_juicer/ops/mapper/query_intent_detection_mapper.py
new file mode 100644
index 000000000..b0d240e2d
--- /dev/null
+++ b/data_juicer/ops/mapper/query_intent_detection_mapper.py
@@ -0,0 +1,84 @@
+from typing import Dict, Optional
+
+from data_juicer.utils.common_utils import nested_set
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from ..base_op import OPERATORS, Mapper
+
+OP_NAME = 'query_intent_detection_mapper'
+
+
+@OPERATORS.register_module(OP_NAME)
+class QueryIntentDetectionMapper(Mapper):
+ """
+ Mapper to predict user's Intent label in query. Input from query_key.
+ Output intent label and corresponding score for the query, which is
+ store in 'query_intent_label' and 'query_intent_label_score' in
+ Data-Juicer meta field.
+ """
+
+ _accelerator = 'cuda'
+ _batched_op = True
+
+ def __init__(
+ self,
+ hf_model:
+ str = 'bespin-global/klue-roberta-small-3i4k-intent-classification', # noqa: E501 E131
+ zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en',
+ model_params: Dict = {},
+ zh_to_en_model_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+
+ :param hf_model: Hugginface model ID to predict intent label.
+ :param zh_to_en_hf_model: Translation model from Chinese to English.
+ If not None, translate the query from Chinese to English.
+ :param model_params: model param for hf_model.
+ :param zh_to_en_model_params: model param for zh_to_hf_model.
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ self.model_key = prepare_model(model_type='huggingface',
+ pretrained_model_name_or_path=hf_model,
+ return_pipe=True,
+ pipe_task='text-classification',
+ **model_params)
+
+ if zh_to_en_hf_model is not None:
+ self.zh_to_en_model_key = prepare_model(
+ model_type='huggingface',
+ pretrained_model_name_or_path=zh_to_en_hf_model,
+ return_pipe=True,
+ pipe_task='translation',
+ **zh_to_en_model_params)
+ else:
+ self.zh_to_en_model_key = None
+
+ def process_batched(self, samples, rank=None):
+ queries = samples[self.query_key]
+
+ if self.zh_to_en_model_key is not None:
+ translater, _ = get_model(self.zh_to_en_model_key, rank,
+ self.use_cuda())
+ results = translater(queries)
+ queries = [item['translation_text'] for item in results]
+
+ classifier, _ = get_model(self.model_key, rank, self.use_cuda())
+ results = classifier(queries)
+ labels = [r['label'] for r in results]
+ scores = [r['score'] for r in results]
+
+ if Fields.meta not in samples:
+ samples[Fields.meta] = [{} for val in labels]
+ for i in range(len(samples[Fields.meta])):
+ samples[Fields.meta][i] = nested_set(samples[Fields.meta][i],
+ MetaKeys.query_intent_label,
+ labels[i])
+ samples[Fields.meta][i] = nested_set(samples[Fields.meta][i],
+ MetaKeys.query_intent_score,
+ scores[i])
+
+ return samples
diff --git a/data_juicer/ops/mapper/query_sentiment_detection_mapper.py b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py
new file mode 100644
index 000000000..634bdeab3
--- /dev/null
+++ b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py
@@ -0,0 +1,85 @@
+from typing import Dict, Optional
+
+from data_juicer.utils.common_utils import nested_set
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from ..base_op import OPERATORS, Mapper
+
+OP_NAME = 'query_sentiment_detection_mapper'
+
+
+@OPERATORS.register_module(OP_NAME)
+class QuerySentimentDetectionMapper(Mapper):
+ """
+ Mapper to predict user's sentiment label ('negative', 'neutral' and
+ 'positive') in query. Input from query_key.
+ Output label and corresponding score for the query, which is
+ store in 'query_sentiment_label' and
+ 'query_sentiment_label_score' in Data-Juicer meta field.
+ """
+
+ _accelerator = 'cuda'
+ _batched_op = True
+
+ def __init__(
+ self,
+ hf_model:
+ str = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis', # noqa: E501 E131
+ zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en',
+ model_params: Dict = {},
+ zh_to_en_model_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+
+ :param hf_model: Hugginface model ID to predict sentiment label.
+ :param zh_to_en_hf_model: Translation model from Chinese to English.
+ If not None, translate the query from Chinese to English.
+ :param model_params: model param for hf_model.
+ :param zh_to_en_model_params: model param for zh_to_hf_model.
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ self.model_key = prepare_model(model_type='huggingface',
+ pretrained_model_name_or_path=hf_model,
+ return_pipe=True,
+ pipe_task='text-classification',
+ **model_params)
+
+ if zh_to_en_hf_model is not None:
+ self.zh_to_en_model_key = prepare_model(
+ model_type='huggingface',
+ pretrained_model_name_or_path=zh_to_en_hf_model,
+ return_pipe=True,
+ pipe_task='translation',
+ **zh_to_en_model_params)
+ else:
+ self.zh_to_en_model_key = None
+
+ def process_batched(self, samples, rank=None):
+ queries = samples[self.query_key]
+
+ if self.zh_to_en_model_key is not None:
+ translater, _ = get_model(self.zh_to_en_model_key, rank,
+ self.use_cuda())
+ results = translater(queries)
+ queries = [item['translation_text'] for item in results]
+
+ classifier, _ = get_model(self.model_key, rank, self.use_cuda())
+ results = classifier(queries)
+ labels = [r['label'] for r in results]
+ scores = [r['score'] for r in results]
+
+ if Fields.meta not in samples:
+ samples[Fields.meta] = [{} for val in labels]
+ for i in range(len(samples[Fields.meta])):
+ samples[Fields.meta][i] = nested_set(
+ samples[Fields.meta][i], MetaKeys.query_sentiment_label,
+ labels[i])
+ samples[Fields.meta][i] = nested_set(
+ samples[Fields.meta][i], MetaKeys.query_sentiment_score,
+ scores[i])
+
+ return samples
diff --git a/data_juicer/ops/mapper/query_topic_detection_mapper.py b/data_juicer/ops/mapper/query_topic_detection_mapper.py
new file mode 100644
index 000000000..8e5687ee3
--- /dev/null
+++ b/data_juicer/ops/mapper/query_topic_detection_mapper.py
@@ -0,0 +1,84 @@
+from typing import Dict, Optional
+
+from data_juicer.utils.common_utils import nested_set
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from ..base_op import OPERATORS, Mapper
+
+OP_NAME = 'query_topic_detection_mapper'
+
+
+@OPERATORS.register_module(OP_NAME)
+class QueryTopicDetectionMapper(Mapper):
+ """
+ Mapper to predict user's topic label in query. Input from query_key.
+ Output topic label and corresponding score for the query, which is
+ store in 'query_topic_label' and 'query_topic_label_score' in
+ Data-Juicer meta field.
+ """
+
+ _accelerator = 'cuda'
+ _batched_op = True
+
+ def __init__(
+ self,
+ hf_model:
+ str = 'dstefa/roberta-base_topic_classification_nyt_news', # noqa: E501 E131
+ zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en',
+ model_params: Dict = {},
+ zh_to_en_model_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+
+ :param hf_model: Hugginface model ID to predict topic label.
+ :param zh_to_en_hf_model: Translation model from Chinese to English.
+ If not None, translate the query from Chinese to English.
+ :param model_params: model param for hf_model.
+ :param zh_to_en_model_params: model param for zh_to_hf_model.
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ self.model_key = prepare_model(model_type='huggingface',
+ pretrained_model_name_or_path=hf_model,
+ return_pipe=True,
+ pipe_task='text-classification',
+ **model_params)
+
+ if zh_to_en_hf_model is not None:
+ self.zh_to_en_model_key = prepare_model(
+ model_type='huggingface',
+ pretrained_model_name_or_path=zh_to_en_hf_model,
+ return_pipe=True,
+ pipe_task='translation',
+ **zh_to_en_model_params)
+ else:
+ self.zh_to_en_model_key = None
+
+ def process_batched(self, samples, rank=None):
+ queries = samples[self.query_key]
+
+ if self.zh_to_en_model_key is not None:
+ translater, _ = get_model(self.zh_to_en_model_key, rank,
+ self.use_cuda())
+ results = translater(queries)
+ queries = [item['translation_text'] for item in results]
+
+ classifier, _ = get_model(self.model_key, rank, self.use_cuda())
+ results = classifier(queries)
+ labels = [r['label'] for r in results]
+ scores = [r['score'] for r in results]
+
+ if Fields.meta not in samples:
+ samples[Fields.meta] = [{} for val in labels]
+ for i in range(len(samples[Fields.meta])):
+ samples[Fields.meta][i] = nested_set(samples[Fields.meta][i],
+ MetaKeys.query_topic_label,
+ labels[i])
+ samples[Fields.meta][i] = nested_set(samples[Fields.meta][i],
+ MetaKeys.query_topic_score,
+ scores[i])
+
+ return samples
diff --git a/data_juicer/ops/selector/__init__.py b/data_juicer/ops/selector/__init__.py
index 22df12987..0339a2c5b 100644
--- a/data_juicer/ops/selector/__init__.py
+++ b/data_juicer/ops/selector/__init__.py
@@ -1,9 +1,11 @@
from .frequency_specified_field_selector import FrequencySpecifiedFieldSelector
from .random_selector import RandomSelector
from .range_specified_field_selector import RangeSpecifiedFieldSelector
+from .tags_specified_field_selector import TagsSpecifiedFieldSelector
from .topk_specified_field_selector import TopkSpecifiedFieldSelector
__all__ = [
'FrequencySpecifiedFieldSelector', 'RandomSelector',
- 'RangeSpecifiedFieldSelector', 'TopkSpecifiedFieldSelector'
+ 'RangeSpecifiedFieldSelector', 'TagsSpecifiedFieldSelector',
+ 'TopkSpecifiedFieldSelector'
]
diff --git a/data_juicer/ops/selector/tags_specified_field_selector.py b/data_juicer/ops/selector/tags_specified_field_selector.py
new file mode 100644
index 000000000..6fb32251a
--- /dev/null
+++ b/data_juicer/ops/selector/tags_specified_field_selector.py
@@ -0,0 +1,54 @@
+import numbers
+from typing import List
+
+from ..base_op import OPERATORS, Selector
+
+
+@OPERATORS.register_module('tags_specified_field_selector')
+class TagsSpecifiedFieldSelector(Selector):
+ """Selector to select samples based on the tags of specified
+ field."""
+
+ def __init__(self,
+ field_key: str = '',
+ target_tags: List[str] = None,
+ *args,
+ **kwargs):
+ """
+ Initialization method.
+
+ :param field_key: Selector based on the specified value
+ corresponding to the target key. The target key
+ corresponding to multi-level field information need to be
+ separated by '.'.
+ :param target_tags: Target tags to be select.
+ :param args: extra args
+ :param kwargs: extra args
+ """
+ super().__init__(*args, **kwargs)
+ self.field_key = field_key
+ self.target_tags = set(target_tags)
+
+ def process(self, dataset):
+ if len(dataset) <= 1 or not self.field_key:
+ return dataset
+
+ field_keys = self.field_key.split('.')
+ assert field_keys[0] in dataset.features.keys(
+ ), "'{}' not in {}".format(field_keys[0], dataset.features.keys())
+
+ selected_index = []
+ for i, item in enumerate(dataset[field_keys[0]]):
+ field_value = item
+ for key in field_keys[1:]:
+ assert key in field_value.keys(), "'{}' not in {}".format(
+ key, field_value.keys())
+ field_value = field_value[key]
+ assert field_value is None or isinstance(
+ field_value, str) or isinstance(
+ field_value, numbers.Number
+ ), 'The {} item is not String, Numbers or NoneType'.format(i)
+ if field_value in self.target_tags:
+ selected_index.append(i)
+
+ return dataset.select(selected_index)
diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py
index 5ea9091b0..3b8ec20aa 100644
--- a/data_juicer/utils/auto_install_mapping.py
+++ b/data_juicer/utils/auto_install_mapping.py
@@ -96,4 +96,12 @@
'extract_support_text_mapper': ['openai'],
'pair_preference_mapper': ['openai'],
'relation_identity_mapper': ['openai'],
+ 'dialog_intent_detection_mapper': ['openai'],
+ 'dialog_sentiment_detection_mapper': ['openai'],
+ 'dialog_sentiment_intensity_mapper': ['openai'],
+ 'dialog_topic_intensity_mapper': ['openai'],
+ 'query_intent_detection_mapper': ['transformers'],
+ 'query_sentiment_detection_mapper': ['transformers'],
+ 'query_topic_detection_mapper': ['transformers'],
+ 'meta_tags_aggregator': ['openai'],
}
diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py
index bd649bb96..8a13ae361 100644
--- a/data_juicer/utils/common_utils.py
+++ b/data_juicer/utils/common_utils.py
@@ -69,17 +69,21 @@ def nested_set(data: dict, path: str, val):
:param data: A dictionary with nested format.
:param path: A dot-separated string representing the path to set.
- This can include numeric indices when setting list
- elements.
:return: The nested data after the val set.
"""
keys = path.split('.')
cur = data
- for key in keys[:-1]:
- if key not in cur:
- cur[key] = {}
- cur = cur[key]
- cur[keys[-1]] = val
+ try:
+ for key in keys[:-1]:
+ if key not in cur:
+ cur[key] = {}
+ cur = cur[key]
+ if keys[-1] in cur:
+ logger.warning(f'Overwrite value in {path}!')
+ cur[keys[-1]] = val
+ except Exception:
+ logger.warning(f'Unvalid dot-separated path: {path}!')
+ return data
return data
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index 922d44c8b..ba693f63e 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -69,6 +69,26 @@ class Fields(object):
support_text = DEFAULT_PREFIX + 'support_text__'
+class MetaKeys(object):
+
+ dialog_sentiment_intensity = 'dialog_sentiment_intensity'
+ dialog_sentiment_intensity_analysis = 'dialog_sentiment_intensity_analysis'
+ query_sentiment_label = 'query_sentiment_label'
+ query_sentiment_score = 'query_sentiment_label_score'
+ dialog_sentiment_labels = 'dialog_sentiment_labels'
+ dialog_sentiment_labels_analysis = 'dialog_sentiment_labels_analysis'
+
+ dialog_intent_labels = 'dialog_intent_labels'
+ dialog_intent_labels_analysis = 'dialog_intent_labels_analysis'
+ query_intent_label = 'query_intent_label'
+ query_intent_score = 'query_intent_label_score'
+
+ dialog_topic_labels = 'dialog_topic_labels'
+ dialog_topic_labels_analysis = 'dialog_topic_labels_analysis'
+ query_topic_label = 'query_topic_label'
+ query_topic_score = 'query_topic_label_score'
+
+
class StatsKeysMeta(type):
"""
a helper class to track the mapping from OP's name to its used stats_keys
diff --git a/docs/Operators.md b/docs/Operators.md
index fe3c6d94d..ea84a360c 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -6,17 +6,17 @@ This page offers a basic description of the operators (OPs) in Data-Juicer. User
## Overview
-The operators in Data-Juicer are categorized into 5 types.
+The operators in Data-Juicer are categorized into 7 types.
| Type | Number | Description |
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data |
-| [ Mapper ]( #mapper ) | 63 | Edits and transforms samples |
+| [ Mapper ]( #mapper ) | 70 | Edits and transforms samples |
| [ Filter ]( #filter ) | 44 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples |
-| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking |
-| [ Grouper ]( #grouper ) | 2 | Group samples to batched samples |
-| [ Aggregator ]( #aggregator ) | 3 | Aggregate for batched samples, such as summary or conclusion |
+| [ Selector ]( #selector ) | 5 | Selects top samples based on ranking |
+| [ Grouper ]( #grouper ) | 3 | Group samples to batched samples |
+| [ Aggregator ]( #aggregator ) | 4 | Aggregate for batched samples, such as summary or conclusion |
All the specific operators are listed below, each featured with several capability tags.
@@ -68,6 +68,10 @@ All the specific operators are listed below, each featured with several capabili
| clean_html_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes HTML tags and returns plain text of all the nodes | [code](../data_juicer/ops/mapper/clean_html_mapper.py) | [tests](../tests/ops/mapper/test_clean_html_mapper.py) |
| clean_ip_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes IP addresses | [code](../data_juicer/ops/mapper/clean_ip_mapper.py) | [tests](../tests/ops/mapper/test_clean_ip_mapper.py) |
| clean_links_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes links, such as those starting with http or ftp | [code](../data_juicer/ops/mapper/clean_links_mapper.py) | [tests](../tests/ops/mapper/test_clean_links_mapper.py) |
+| dialog_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to generate user's intent labels in dialog. | [code](../data_juicer/ops/mapper/dialog_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_intent_detection_mapper.py) |
+| dialog_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to generate user's sentiment labels in dialog. | [code](../data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_detection_mapper.py) |
+| dialog_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment intensity (from -5 to 5 in default prompt) in dialog. | [code](../data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py) |
+| dialog_topic_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to generate user's topic labels in dialog. | [code](../data_juicer/ops/mapper/dialog_topic_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_topic_detection_mapper.py) |
| expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Expands macros usually defined at the top of TeX documents | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) |
| extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract attributes for given entities from the text. | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) |
| extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities and relations in the text for knowledge graph. | [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) |
@@ -93,6 +97,9 @@ All the specific operators are listed below, each featured with several capabili
| punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) |
| python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python function defined in a file | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) |
| python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python lambda function on data samples | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) |
+| query_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's intent label in query. | [code](../data_juicer/ops/mapper/query_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_intent_detection_mapper.py) |
+| query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment label ('negative', 'neutral' and 'positive') in query. | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) |
+| query_topic_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's topic label in query. | [code](../data_juicer/ops/mapper/query_topic_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_topic_detection_mapper.py) |
| relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Identify relation between two entity in the text. | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) |
| remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) |
| remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) |
@@ -192,20 +199,24 @@ All the specific operators are listed below, each featured with several capabili
| frequency_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects top samples by comparing the frequency of the specified field | [code](../data_juicer/ops/selector/frequency_specified_field_selector.py) | [tests](../tests/ops/selector/test_frequency_specified_field_selector.py) |
| random_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects samples randomly | [code](../data_juicer/ops/selector/random_selector.py) | [tests](../tests/ops/selector/test_random_selector.py) |
| range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects samples within a specified range by comparing the values of the specified field | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) |
+| tags_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Select samples based on the tags of specified
+ field. | [code](../data_juicer/ops/selector/tags_specified_field_selector.py) | [tests](../tests/ops/selector/test_tags_specified_field_selector.py) |
| topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects top samples by comparing the values of the specified field | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) |
## Grouper
| Operator | Tags | Description | Source code | Unit tests |
|------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------|
-| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group samples to batched samples according values in given keys. | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) |
| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group all samples to one batched sample. | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) |
+| naive_reverse_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Split batched samples to samples. | [code](../data_juicer/ops/grouper/naive_reverse_grouper.py) | [tests](../tests/ops/grouper/test_naive_reverse_grouper.py) |
+| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group samples to batched samples according values in given keys. | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) |
## Aggregator
| Operator | Tags | Description | Source code | Unit tests |
|------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------|
| entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Return conclusion of the given entity's attribute from some docs. | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) |
+| meta_tags_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Merge similar meta tags to one tag. | [code](../data_juicer/ops/aggregator/meta_tags_aggregator.py) | [tests](../tests/ops/aggregator/test_meta_tags_aggregator.py) |
| most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities closely related to a given entity from some texts, and sort them in descending order of importance. | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) |
| nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Considering the limitation of input length, nested aggregate contents for each given number of samples. | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) |
diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md
index 61610a873..40710f68f 100644
--- a/docs/Operators_ZH.md
+++ b/docs/Operators_ZH.md
@@ -6,17 +6,17 @@
## 概览
-Data-Juicer 中的算子分为以下 5 种类型。
+Data-Juicer 中的算子分为以下 7 种类型。
| 类型 | 数量 | 描述 |
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 |
-| [ Mapper ]( #mapper ) | 63 | 对数据样本进行编辑和转换 |
+| [ Mapper ]( #mapper ) | 70 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 44 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 |
-| [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 |
-| [ Grouper ]( #grouper ) | 2 | 将样本分组,每一组组成一个批量样本 |
-| [ Aggregator ]( #aggregator ) | 3 | 对批量样本进行汇总,如得出总结或结论 |
+| [ Selector ]( #selector ) | 5 | 基于排序选取高质量样本 |
+| [ Grouper ]( #grouper ) | 3 | 将样本分组,每一组组成一个批量样本 |
+| [ Aggregator ]( #aggregator ) | 4 | 对批量样本进行汇总,如得出总结或结论 |
下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。
@@ -67,6 +67,10 @@ Data-Juicer 中的算子分为以下 5 种类型。
| clean_html_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 HTML 标签并返回所有节点的纯文本 | [code](../data_juicer/ops/mapper/clean_html_mapper.py) | [tests](../tests/ops/mapper/test_clean_html_mapper.py) |
| clean_ip_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 IP 地址 | [code](../data_juicer/ops/mapper/clean_ip_mapper.py) | [tests](../tests/ops/mapper/test_clean_ip_mapper.py) |
| clean_links_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除链接,例如以 http 或 ftp 开头的 | [code](../data_juicer/ops/mapper/clean_links_mapper.py) | [tests](../tests/ops/mapper/test_clean_links_mapper.py) |
+| dialog_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取对话中的用户意图标签。 | [code](../data_juicer/ops/mapper/dialog_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_intent_detection_mapper.py) |
+| dialog_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取对话中用户的情感标签 | [code](../data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_detection_mapper.py) |
+| dialog_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测对话中的情绪强度(默认从-5到5)。 | [code](../data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py) |
+| dialog_topic_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取对话中的用户的话题标签。 | [code](../data_juicer/ops/mapper/dialog_topic_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_topic_detection_mapper.py) |
| expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 扩展通常在 TeX 文档顶部定义的宏 | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) |
| extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 给定主体和属性名,从文本中抽取主体的属性 | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) |
| extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从文本中抽取知识图谱的实体和关系 | [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) |
@@ -92,6 +96,9 @@ Data-Juicer 中的算子分为以下 5 种类型。
| punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) |
| python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行文件中定义的 Python 函数处理样本 | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) |
| python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行 Python lambda 函数处理样本 | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) |
+| query_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的意图标签。 | [code](../data_juicer/ops/mapper/query_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_intent_detection_mapper.py) |
+| query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的情感强度标签('negative'、'neutral'和'positive')。 | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) |
+| query_topic_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的话题标签。 | [code](../data_juicer/ops/mapper/query_topic_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_topic_detection_mapper.py) |
| relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 识别一段文本中两个实体之间的关系 | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) |
| remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) |
| remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) |
@@ -191,20 +198,23 @@ Data-Juicer 中的算子分为以下 5 种类型。
| frequency_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的频率选出前 k 个样本 | [code](../data_juicer/ops/selector/frequency_specified_field_selector.py) | [tests](../tests/ops/selector/test_frequency_specified_field_selector.py) |
| random_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 随机筛选 k 个样本 | [code](../data_juicer/ops/selector/random_selector.py) | [tests](../tests/ops/selector/test_random_selector.py) |
| range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出指定范围的 k 个样本 | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) |
+| tags_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过指定字段的标签值筛选样例 | [code](../data_juicer/ops/selector/tags_specified_field_selector.py) | [tests](../tests/ops/selector/test_tags_specified_field_selector.py) |
| topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出前 k 个样本 | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) |
## Grouper
| 算子 | 标签 | 描述 | 源码 | 单测样例 |
|-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------|
+| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将所有样本分为一个组,返回一个batch化的样本 | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) |
+| naive_reverse_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将batch化的样本拆分成普通的样本 | [code](../data_juicer/ops/grouper/naive_reverse_grouper.py) | [tests](../tests/ops/grouper/test_naive_reverse_grouper.py) |
| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 根据给定键的值将样本分组,每一组组成一个批量样本。 | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) |
-| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将所有样本分为一个组,返回一个批量样本 | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) |
## Aggregator
| 算子 | 标签 | 描述 | 源码 | 单测样例 |
|-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------|
| entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中总结出给定实体的属性 | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) |
+| meta_tags_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将相似的标签合并成同一个标签。 | [code](../data_juicer/ops/aggregator/meta_tags_aggregator.py) | [tests](../tests/ops/aggregator/test_meta_tags_aggregator.py) |
| most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中抽取出与给定实体密切相关的实体,按重要性从高到低排序 | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) |
| nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 考虑到输入长度的限制,对样本中的内容进行嵌套聚合。 | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) |
diff --git a/tests/ops/Aggregator/__init__.py b/tests/ops/aggregator/__init__.py
similarity index 100%
rename from tests/ops/Aggregator/__init__.py
rename to tests/ops/aggregator/__init__.py
diff --git a/tests/ops/Aggregator/test_entity_attribute_aggregator.py b/tests/ops/aggregator/test_entity_attribute_aggregator.py
similarity index 100%
rename from tests/ops/Aggregator/test_entity_attribute_aggregator.py
rename to tests/ops/aggregator/test_entity_attribute_aggregator.py
diff --git a/tests/ops/aggregator/test_meta_tags_aggregator.py b/tests/ops/aggregator/test_meta_tags_aggregator.py
new file mode 100644
index 000000000..7aba225ae
--- /dev/null
+++ b/tests/ops/aggregator/test_meta_tags_aggregator.py
@@ -0,0 +1,117 @@
+import unittest
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.aggregator import MetaTagsAggregator
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS
+
+
+@SKIPPED_TESTS.register_module()
+class MetaTagsAggregatorTest(DataJuicerTestCaseBase):
+
+ def _run_helper(self, op, samples):
+
+ # before runing this test, set below environment variables:
+ # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
+ # export OPENAI_API_KEY=your_dashscope_key
+
+ dataset = Dataset.from_list(samples)
+ new_dataset = op.run(dataset)
+
+ for data in new_dataset:
+ for k in data:
+ logger.info(f"{k}: {data[k]}")
+
+ self.assertEqual(len(new_dataset), len(samples))
+
+ def test_default_aggregator(self):
+ samples = [
+ {
+ Fields.meta: [
+ {
+ MetaKeys.query_sentiment_label: '开心'
+ },
+ {
+ MetaKeys.query_sentiment_label: '快乐'
+ },
+ {
+ MetaKeys.query_sentiment_label: '难过'
+ },
+ {
+ MetaKeys.query_sentiment_label: '不开心'
+ },
+ {
+ MetaKeys.query_sentiment_label: '愤怒'
+ }
+ ]
+ },
+ ]
+ op = MetaTagsAggregator(
+ api_model='qwen2.5-72b-instruct',
+ meta_tag_key=MetaKeys.query_sentiment_label,
+ )
+ self._run_helper(op, samples)
+
+
+ def test_target_tags(self):
+ samples = [
+ {
+ Fields.meta: [
+ {
+ MetaKeys.query_sentiment_label: '开心'
+ },
+ {
+ MetaKeys.query_sentiment_label: '快乐'
+ },
+ {
+ MetaKeys.query_sentiment_label: '难过'
+ },
+ {
+ MetaKeys.query_sentiment_label: '不开心'
+ },
+ {
+ MetaKeys.query_sentiment_label: '愤怒'
+ }
+ ]
+ },
+ ]
+ op = MetaTagsAggregator(
+ api_model='qwen2.5-72b-instruct',
+ meta_tag_key=MetaKeys.query_sentiment_label,
+ target_tags=['开心', '难过', '其他']
+ )
+ self._run_helper(op, samples)
+
+ def test_tag_list(self):
+ samples = [
+ {
+ Fields.meta: [
+ {
+ MetaKeys.dialog_sentiment_labels: ['开心', '平静']
+ },
+ {
+ MetaKeys.dialog_sentiment_labels: ['快乐', '开心', '幸福']
+ },
+ {
+ MetaKeys.dialog_sentiment_labels: ['难过']
+ },
+ {
+ MetaKeys.dialog_sentiment_labels: ['不开心', '没头脑', '不高兴']
+ },
+ {
+ MetaKeys.dialog_sentiment_labels: ['愤怒', '愤慨']
+ }
+ ]
+ },
+ ]
+ op = MetaTagsAggregator(
+ api_model='qwen2.5-72b-instruct',
+ meta_tag_key=MetaKeys.dialog_sentiment_labels,
+ target_tags=['开心', '难过', '其他']
+ )
+ self._run_helper(op, samples)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py b/tests/ops/aggregator/test_most_relavant_entities_aggregator.py
similarity index 100%
rename from tests/ops/Aggregator/test_most_relavant_entities_aggregator.py
rename to tests/ops/aggregator/test_most_relavant_entities_aggregator.py
diff --git a/tests/ops/Aggregator/test_nested_aggregator.py b/tests/ops/aggregator/test_nested_aggregator.py
similarity index 100%
rename from tests/ops/Aggregator/test_nested_aggregator.py
rename to tests/ops/aggregator/test_nested_aggregator.py
diff --git a/tests/ops/grouper/test_naive_reverse_grouper.py b/tests/ops/grouper/test_naive_reverse_grouper.py
new file mode 100644
index 000000000..29c06451d
--- /dev/null
+++ b/tests/ops/grouper/test_naive_reverse_grouper.py
@@ -0,0 +1,83 @@
+import unittest
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.grouper.naive_reverse_grouper import NaiveReverseGrouper
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
+
+
+class NaiveReverseGrouperTest(DataJuicerTestCaseBase):
+
+ def _run_helper(self, op, samples, target):
+ dataset = Dataset.from_list(samples)
+ new_dataset = op.run(dataset)
+
+ for d, t in zip(new_dataset, target):
+ self.assertEqual(d['text'], t['text'])
+
+ def test_one_batched_sample(self):
+
+ source = [
+ {
+ 'text':[
+ "Today is Sunday and it's a happy day!",
+ "Sur la plateforme MT4, plusieurs manières d'accéder à \n"
+ 'ces fonctionnalités sont conçues simultanément.',
+ '欢迎来到阿里巴巴!'
+ ]
+ }
+ ]
+
+ target = [
+ {
+ 'text': "Today is Sunday and it's a happy day!"
+ },
+ {
+ 'text':
+ "Sur la plateforme MT4, plusieurs manières d'accéder à \n"
+ 'ces fonctionnalités sont conçues simultanément.'
+ },
+ {
+ 'text': '欢迎来到阿里巴巴!'
+ },
+ ]
+
+ op = NaiveReverseGrouper()
+ self._run_helper(op, source, target)
+
+
+ def test_two_batch_sample(self):
+
+ source = [
+ {
+ 'text':[
+ "Today is Sunday and it's a happy day!",
+ "Sur la plateforme MT4, plusieurs manières d'accéder à \n"
+ 'ces fonctionnalités sont conçues simultanément.'
+ ]
+ },
+ {
+ 'text':[
+ '欢迎来到阿里巴巴!'
+ ]
+ }
+ ]
+
+ target = [
+ {
+ 'text': "Today is Sunday and it's a happy day!"
+ },
+ {
+ 'text':
+ "Sur la plateforme MT4, plusieurs manières d'accéder à \n"
+ 'ces fonctionnalités sont conçues simultanément.'
+ },
+ {
+ 'text': '欢迎来到阿里巴巴!'
+ },
+ ]
+
+ op = NaiveReverseGrouper()
+ self._run_helper(op, source, target)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/ops/mapper/test_dialog_intent_detection_mapper.py b/tests/ops/mapper/test_dialog_intent_detection_mapper.py
new file mode 100644
index 000000000..bc3a18752
--- /dev/null
+++ b/tests/ops/mapper/test_dialog_intent_detection_mapper.py
@@ -0,0 +1,170 @@
+import unittest
+import json
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.mapper.dialog_intent_detection_mapper import DialogIntentDetectionMapper
+from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
+ DataJuicerTestCaseBase)
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.common_utils import nested_access
+
+# Skip tests for this OP.
+# These tests have been tested locally.
+@SKIPPED_TESTS.register_module()
+class TestDialogIntentDetectionMapper(DataJuicerTestCaseBase):
+ # before runing this test, set below environment variables:
+ # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
+ # export OPENAI_API_KEY=your_key
+
+ def _run_op(self, op, samples, target_len):
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+ analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_intent_labels_analysis)
+ labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_intent_labels)
+
+ for analysis, labels in zip(analysis_list, labels_list):
+ logger.info(f'分析:{analysis}')
+ logger.info(f'意图:{labels}')
+
+ self.assertEqual(len(analysis_list), target_len)
+ self.assertEqual(len(labels_list), target_len)
+
+ def test_default(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct')
+ self._run_op(op, samples, 4)
+
+ def test_max_round(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct',
+ max_round=1)
+ self._run_op(op, samples, 4)
+
+ def test_max_round_zero(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct',
+ max_round=0)
+ self._run_op(op, samples, 4)
+
+ def test_query(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ )
+ ],
+ 'query': '你在说什么我听不懂。',
+ 'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ }]
+
+ op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct',
+ max_round=1)
+ self._run_op(op, samples, 4)
+
+ def test_intent_candidates(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogIntentDetectionMapper(
+ api_model='qwen2.5-72b-instruct',
+ intent_candidates=['评价', '讽刺', '表达困惑']
+ )
+ self._run_op(op, samples, 4)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py
new file mode 100644
index 000000000..b19bf6359
--- /dev/null
+++ b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py
@@ -0,0 +1,141 @@
+import unittest
+import json
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.mapper.dialog_sentiment_detection_mapper import DialogSentimentDetectionMapper
+from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
+ DataJuicerTestCaseBase)
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.common_utils import nested_access
+
+# Skip tests for this OP.
+# These tests have been tested locally.
+@SKIPPED_TESTS.register_module()
+class TestDialogSentimentDetectionMapper(DataJuicerTestCaseBase):
+ # before runing this test, set below environment variables:
+ # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
+ # export OPENAI_API_KEY=your_key
+
+ def _run_op(self, op, samples, target_len):
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+ analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_labels_analysis)
+ labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_labels)
+
+ for analysis, labels in zip(analysis_list, labels_list):
+ logger.info(f'分析:{analysis}')
+ logger.info(f'情绪:{labels}')
+
+ self.assertEqual(len(analysis_list), target_len)
+ self.assertEqual(len(labels_list), target_len)
+
+ def test_default(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct')
+ self._run_op(op, samples, 4)
+
+ def test_max_round(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct',
+ max_round=1)
+ self._run_op(op, samples, 4)
+
+ def test_max_round_zero(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct',
+ max_round=0)
+ self._run_op(op, samples, 4)
+
+ def test_query(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ )
+ ],
+ 'query': '你在说什么我听不懂。',
+ 'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ }]
+
+ op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct',
+ max_round=1)
+ self._run_op(op, samples, 4)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py
new file mode 100644
index 000000000..a8953c3e4
--- /dev/null
+++ b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py
@@ -0,0 +1,141 @@
+import unittest
+import json
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.mapper.dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper
+from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
+ DataJuicerTestCaseBase)
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.common_utils import nested_access
+
+# Skip tests for this OP.
+# These tests have been tested locally.
+@SKIPPED_TESTS.register_module()
+class TestDialogSentimentIntensityMapper(DataJuicerTestCaseBase):
+ # before runing this test, set below environment variables:
+ # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
+ # export OPENAI_API_KEY=your_key
+
+ def _run_op(self, op, samples, target_len):
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+ analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_intensity_analysis)
+ intensity_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_intensity)
+
+ for analysis, intensity in zip(analysis_list, intensity_list):
+ logger.info(f'分析:{analysis}')
+ logger.info(f'情绪:{intensity}')
+
+ self.assertEqual(len(analysis_list), target_len)
+ self.assertEqual(len(intensity_list), target_len)
+
+ def test_default(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct')
+ self._run_op(op, samples, 4)
+
+ def test_max_round(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct',
+ max_round=1)
+ self._run_op(op, samples, 4)
+
+ def test_max_round_zero(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct',
+ max_round=0)
+ self._run_op(op, samples, 4)
+
+ def test_query(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ )
+ ],
+ 'query': '你在说什么我听不懂。',
+ 'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ }]
+
+ op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct',
+ max_round=1)
+ self._run_op(op, samples, 4)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/ops/mapper/test_dialog_topic_detection_mapper.py b/tests/ops/mapper/test_dialog_topic_detection_mapper.py
new file mode 100644
index 000000000..887e96bad
--- /dev/null
+++ b/tests/ops/mapper/test_dialog_topic_detection_mapper.py
@@ -0,0 +1,141 @@
+import unittest
+import json
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.mapper.dialog_topic_detection_mapper import DialogTopicDetectionMapper
+from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
+ DataJuicerTestCaseBase)
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.common_utils import nested_access
+
+# Skip tests for this OP.
+# These tests have been tested locally.
+@SKIPPED_TESTS.register_module()
+class TestDialogTopicDetectionMapper(DataJuicerTestCaseBase):
+ # before runing this test, set below environment variables:
+ # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
+ # export OPENAI_API_KEY=your_key
+
+ def _run_op(self, op, samples, target_len):
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+ analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_topic_labels_analysis)
+ labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_topic_labels)
+
+ for analysis, labels in zip(analysis_list, labels_list):
+ logger.info(f'分析:{analysis}')
+ logger.info(f'话题:{labels}')
+
+ self.assertEqual(len(analysis_list), target_len)
+ self.assertEqual(len(labels_list), target_len)
+
+ def test_default(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct')
+ self._run_op(op, samples, 4)
+
+ def test_max_round(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct',
+ max_round=1)
+ self._run_op(op, samples, 4)
+
+ def test_max_round_zero(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ ),
+ (
+ '你在说什么我听不懂。',
+ '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ )
+ ]
+ }]
+
+ op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct',
+ max_round=0)
+ self._run_op(op, samples, 4)
+
+ def test_query(self):
+
+ samples = [{
+ 'history': [
+ (
+ '李莲花有口皆碑',
+ '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+ ),
+ (
+ '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+ '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+ ),
+ (
+ '你自己说的呀,我现在说了,你又不高兴了。',
+ 'or of of of of or or and or of of of of of of of,,, '
+ )
+ ],
+ 'query': '你在说什么我听不懂。',
+ 'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+ }]
+
+ op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct',
+ max_round=1)
+ self._run_op(op, samples, 4)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py
index f15b4ca3f..a2c156d48 100644
--- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py
+++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py
@@ -9,7 +9,7 @@
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields
-# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
+# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractEntityAttributeMapperTest(DataJuicerTestCaseBase):
diff --git a/tests/ops/mapper/test_extract_entity_relation_mapper.py b/tests/ops/mapper/test_extract_entity_relation_mapper.py
index 40e3ca32d..0aed4fcee 100644
--- a/tests/ops/mapper/test_extract_entity_relation_mapper.py
+++ b/tests/ops/mapper/test_extract_entity_relation_mapper.py
@@ -9,7 +9,7 @@
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields
-# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
+# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractEntityRelationMapperTest(DataJuicerTestCaseBase):
diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py
index aba40d73e..e936cb06c 100644
--- a/tests/ops/mapper/test_extract_event_mapper.py
+++ b/tests/ops/mapper/test_extract_event_mapper.py
@@ -9,7 +9,7 @@
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields
-# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
+# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractEventMapperTest(DataJuicerTestCaseBase):
diff --git a/tests/ops/mapper/test_extract_keyword_mapper.py b/tests/ops/mapper/test_extract_keyword_mapper.py
index 5836f902a..2501a46ca 100644
--- a/tests/ops/mapper/test_extract_keyword_mapper.py
+++ b/tests/ops/mapper/test_extract_keyword_mapper.py
@@ -9,7 +9,7 @@
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields
-# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
+# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractKeywordMapperTest(DataJuicerTestCaseBase):
diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py
index 2911a1002..457a7d53b 100644
--- a/tests/ops/mapper/test_extract_nickname_mapper.py
+++ b/tests/ops/mapper/test_extract_nickname_mapper.py
@@ -9,7 +9,7 @@
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields
-# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
+# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractNicknameMapperTest(DataJuicerTestCaseBase):
diff --git a/tests/ops/mapper/test_extract_support_text_mapper.py b/tests/ops/mapper/test_extract_support_text_mapper.py
index 0445d2526..080dfd672 100644
--- a/tests/ops/mapper/test_extract_support_text_mapper.py
+++ b/tests/ops/mapper/test_extract_support_text_mapper.py
@@ -10,7 +10,7 @@
from data_juicer.utils.constant import Fields
from data_juicer.utils.common_utils import nested_access
-# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
+# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractSupportTextMapperTest(DataJuicerTestCaseBase):
diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py
new file mode 100644
index 000000000..92d0346a4
--- /dev/null
+++ b/tests/ops/mapper/test_query_intent_detection_mapper.py
@@ -0,0 +1,61 @@
+import unittest
+import json
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.mapper.query_intent_detection_mapper import QueryIntentDetectionMapper
+from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
+ DataJuicerTestCaseBase)
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.common_utils import nested_access
+
+class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase):
+
+ hf_model = 'bespin-global/klue-roberta-small-3i4k-intent-classification'
+ zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en'
+
+ def _run_op(self, op, samples, label_key, targets):
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+
+ for sample, target in zip(dataset, targets):
+ label = nested_access(sample[Fields.meta], label_key)
+ self.assertEqual(label, target)
+
+ def test_default(self):
+
+ samples = [{
+ 'query': '这样好吗?'
+ },{
+ 'query': '站住!'
+ },{
+ 'query': '今天阳光灿烂。'
+ }
+ ]
+ targets = ['question', 'command', 'statement']
+
+ op = QueryIntentDetectionMapper(
+ hf_model = self.hf_model,
+ zh_to_en_hf_model = self.zh_to_en_hf_model,
+ )
+ self._run_op(op, samples, MetaKeys.query_intent_label, targets)
+
+ def test_no_zh_to_en(self):
+
+ samples = [{
+ 'query': '这样好吗?'
+ },{
+ 'query': 'Is this okay?'
+ }
+ ]
+ targets = ['question', 'rhetorical question']
+
+ op = QueryIntentDetectionMapper(
+ hf_model = self.hf_model,
+ zh_to_en_hf_model = None,
+ )
+ self._run_op(op, samples, MetaKeys.query_intent_label, targets)
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/ops/mapper/test_query_sentiment_detection_mapper.py b/tests/ops/mapper/test_query_sentiment_detection_mapper.py
new file mode 100644
index 000000000..62ed0f380
--- /dev/null
+++ b/tests/ops/mapper/test_query_sentiment_detection_mapper.py
@@ -0,0 +1,62 @@
+import unittest
+import json
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.mapper.query_sentiment_detection_mapper import QuerySentimentDetectionMapper
+from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
+ DataJuicerTestCaseBase)
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.common_utils import nested_access
+
+class TestQuerySentimentDetectionMapper(DataJuicerTestCaseBase):
+
+ hf_model = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis'
+ zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en'
+
+ def _run_op(self, op, samples, label_key, targets):
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+
+ for sample, target in zip(dataset, targets):
+ label = nested_access(sample[Fields.meta], label_key)
+ self.assertEqual(label, target)
+
+ def test_default(self):
+
+ samples = [{
+ 'query': '太棒了!'
+ },{
+ 'query': '嗯嗯'
+ },{
+ 'query': '没有希望。'
+ },
+ ]
+ targets = ['positive', 'neutral', 'negative']
+
+ op = QuerySentimentDetectionMapper(
+ hf_model = self.hf_model,
+ zh_to_en_hf_model = self.zh_to_en_hf_model,
+ )
+ self._run_op(op, samples, MetaKeys.query_sentiment_label, targets)
+
+ def test_no_zh_to_en(self):
+
+ samples = [{
+ 'query': '太棒了!'
+ },{
+ 'query': 'That is great!'
+ }
+ ]
+ targets = ['neutral', 'positive']
+
+ op = QuerySentimentDetectionMapper(
+ hf_model = self.hf_model,
+ zh_to_en_hf_model = None,
+ )
+ self._run_op(op, samples, MetaKeys.query_sentiment_label, targets)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/ops/mapper/test_query_topic_detection_mapper.py b/tests/ops/mapper/test_query_topic_detection_mapper.py
new file mode 100644
index 000000000..6304290c7
--- /dev/null
+++ b/tests/ops/mapper/test_query_topic_detection_mapper.py
@@ -0,0 +1,59 @@
+import unittest
+import json
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.mapper.query_topic_detection_mapper import QueryTopicDetectionMapper
+from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
+ DataJuicerTestCaseBase)
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.common_utils import nested_access
+
+class TestQueryTopicDetectionMapper(DataJuicerTestCaseBase):
+
+ hf_model = 'dstefa/roberta-base_topic_classification_nyt_news'
+ zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en'
+
+ def _run_op(self, op, samples, label_key, targets):
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+
+ for sample, target in zip(dataset, targets):
+ label = nested_access(sample[Fields.meta], label_key)
+ self.assertEqual(label, target)
+
+ def test_default(self):
+
+ samples = [{
+ 'query': '今天火箭和快船的比赛谁赢了。'
+ },{
+ 'query': '你最近身体怎么样。'
+ }
+ ]
+ targets = ['Sports', 'Health and Wellness']
+
+ op = QueryTopicDetectionMapper(
+ hf_model = self.hf_model,
+ zh_to_en_hf_model = self.zh_to_en_hf_model,
+ )
+ self._run_op(op, samples, MetaKeys.query_topic_label, targets)
+
+ def test_no_zh_to_en(self):
+
+ samples = [{
+ 'query': '这样好吗?'
+ },{
+ 'query': 'Is this okay?'
+ }
+ ]
+ targets = ['Lifestyle and Fashion', 'Health and Wellness']
+
+ op = QueryTopicDetectionMapper(
+ hf_model = self.hf_model,
+ zh_to_en_hf_model = None,
+ )
+ self._run_op(op, samples, MetaKeys.query_topic_label, targets)
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/ops/mapper/test_relation_identity_mapper.py b/tests/ops/mapper/test_relation_identity_mapper.py
index d730cb79f..231b20ba1 100644
--- a/tests/ops/mapper/test_relation_identity_mapper.py
+++ b/tests/ops/mapper/test_relation_identity_mapper.py
@@ -9,7 +9,7 @@
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields
-# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
+# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class RelationIdentityMapperTest(DataJuicerTestCaseBase):
diff --git a/tests/ops/selector/test_tags_specified_selector.py b/tests/ops/selector/test_tags_specified_selector.py
new file mode 100644
index 000000000..87c232a2b
--- /dev/null
+++ b/tests/ops/selector/test_tags_specified_selector.py
@@ -0,0 +1,63 @@
+import unittest
+
+from data_juicer.core.data import NestedDataset as Dataset
+
+from data_juicer.ops.selector.tags_specified_field_selector import \
+ TagsSpecifiedFieldSelector
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
+
+
+class TagsSpecifiedFieldSelectorTest(DataJuicerTestCaseBase):
+
+ def _run_tag_selector(self, dataset: Dataset, target_list, op):
+ dataset = op.process(dataset)
+ res_list = dataset.to_list()
+ self.assertEqual(res_list, target_list)
+
+ def test_tag_select(self):
+ ds_list = [{
+ 'text': 'a',
+ 'meta': {
+ 'sentiment': 'happy',
+ }
+ }, {
+ 'text': 'b',
+ 'meta': {
+ 'sentiment': 'happy',
+ }
+ }, {
+ 'text': 'c',
+ 'meta': {
+ 'sentiment': 'sad',
+ }
+ }, {
+ 'text': 'd',
+ 'meta': {
+ 'sentiment': 'angry',
+ }
+ }]
+ tgt_list = [{
+ 'text': 'a',
+ 'meta': {
+ 'sentiment': 'happy',
+ }
+ }, {
+ 'text': 'b',
+ 'meta': {
+ 'sentiment': 'happy',
+ }
+ }, {
+ 'text': 'c',
+ 'meta': {
+ 'sentiment': 'sad',
+ }
+ }]
+ dataset = Dataset.from_list(ds_list)
+ op = TagsSpecifiedFieldSelector(
+ field_key='meta.sentiment',
+ target_tags=['happy', 'sad'])
+ self._run_tag_selector(dataset, tgt_list, op)
+
+
+if __name__ == '__main__':
+ unittest.main()