From 4f0f16c129d5d0f16d62d1a8a53e4443005b1ca4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ce=20Ge=20=28=E6=88=88=E7=AD=96=29?= Date: Thu, 28 Nov 2024 10:11:10 +0800 Subject: [PATCH] Add DPO data OP (#491) * add DPO OP * fix args * refine * refine DPO op * add docs --- configs/config_all.yaml | 12 ++ data_juicer/ops/mapper/__init__.py | 28 ++-- data_juicer/ops/mapper/calibrate_qa_mapper.py | 5 +- .../ops/mapper/calibrate_query_mapper.py | 3 +- .../ops/mapper/calibrate_response_mapper.py | 3 +- .../mapper/extract_entity_attribute_mapper.py | 5 +- .../mapper/extract_entity_relation_mapper.py | 5 +- .../ops/mapper/extract_event_mapper.py | 5 +- .../ops/mapper/extract_keyword_mapper.py | 5 +- .../ops/mapper/extract_nickname_mapper.py | 5 +- .../ops/mapper/pair_preference_mapper.py | 131 ++++++++++++++++++ docs/Operators.md | 9 +- docs/Operators_ZH.md | 3 +- tests/ops/mapper/test_calibrate_qa_mapper.py | 2 +- .../ops/mapper/test_calibrate_query_mapper.py | 4 +- .../mapper/test_calibrate_response_mapper.py | 4 +- .../mapper/test_extract_nickname_mapper.py | 4 +- .../ops/mapper/test_pair_preference_mapper.py | 57 ++++++++ 18 files changed, 243 insertions(+), 47 deletions(-) create mode 100644 data_juicer/ops/mapper/pair_preference_mapper.py create mode 100644 tests/ops/mapper/test_pair_preference_mapper.py diff --git a/configs/config_all.yaml b/configs/config_all.yaml index eeb1ba1b2..ea10be519 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -244,6 +244,18 @@ process: sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} - optimize_query_mapper: # optimize query in question-answer pairs. - optimize_response_mapper: # optimize response in question-answer pairs. + - pair_preference_mapper: # construct paired preference samples. + api_model: 'gpt-4o' # API model name. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # System prompt for guiding the generation task. + input_template: null # Template for building the model input. + output_pattern: null # Regular expression for parsing model output. + rejected_key: 'rejected_response' # The field name in the sample to store the generated rejected response. + reason_key: 'reason' # The field name in the sample to store the reason for generating the response. + try_num: 3 # The number of retries for the API call in case of response parsing failure. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. - punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations. - remove_bibliography_mapper: # remove bibliography from Latex text. - remove_comments_mapper: # remove comments from Latex text, code, etc. diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 41bf092a3..db4f54e10 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -28,6 +28,7 @@ from .optimize_qa_mapper import OptimizeQAMapper from .optimize_query_mapper import OptimizeQueryMapper from .optimize_response_mapper import OptimizeResponseMapper +from .pair_preference_mapper import PairPreferenceMapper from .punctuation_normalization_mapper import PunctuationNormalizationMapper from .remove_bibliography_mapper import RemoveBibliographyMapper from .remove_comments_mapper import RemoveCommentsMapper @@ -73,17 +74,18 @@ 'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper', 'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', - 'PunctuationNormalizationMapper', 'RemoveBibliographyMapper', - 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', - 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', - 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', - 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', - 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper', - 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', - 'VideoCaptioningFromVideoMapper', 'VideoFFmpegWrappedMapper', - 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', - 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper', - 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', - 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper', - 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper' + 'PairPreferenceMapper', 'PunctuationNormalizationMapper', + 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', + 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', + 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper', + 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper', + 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper', + 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', + 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper', + 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', + 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper', + 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper', + 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper', + 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper', + 'WhitespaceNormalizationMapper' ] diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index 69b860e33..8480ee899 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -4,14 +4,13 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.model_utils import get_model, prepare_model OP_NAME = 'calibrate_qa_mapper' # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class CalibrateQAMapper(Mapper): """ @@ -107,7 +106,7 @@ def process_single(self, sample, rank=None): 'content': self.build_input(sample) }] parsed_q, parsed_a = None, None - for i in range(self.try_num): + for _ in range(self.try_num): try: output = client(messages, **self.sampling_params) parsed_q, parsed_a = self.parse_output(output) diff --git a/data_juicer/ops/mapper/calibrate_query_mapper.py b/data_juicer/ops/mapper/calibrate_query_mapper.py index 88098d7f8..48ae0c4f7 100644 --- a/data_juicer/ops/mapper/calibrate_query_mapper.py +++ b/data_juicer/ops/mapper/calibrate_query_mapper.py @@ -1,11 +1,10 @@ -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE +from data_juicer.ops.base_op import OPERATORS from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper OP_NAME = 'calibrate_query_mapper' # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class CalibrateQueryMapper(CalibrateQAMapper): """ diff --git a/data_juicer/ops/mapper/calibrate_response_mapper.py b/data_juicer/ops/mapper/calibrate_response_mapper.py index db56af317..1d6456c2b 100644 --- a/data_juicer/ops/mapper/calibrate_response_mapper.py +++ b/data_juicer/ops/mapper/calibrate_response_mapper.py @@ -1,11 +1,10 @@ -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE +from data_juicer.ops.base_op import OPERATORS from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper OP_NAME = 'calibrate_response_mapper' # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class CalibrateResponseMapper(CalibrateQAMapper): """ diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py index 1fab935f9..fd93cfe03 100644 --- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -5,7 +5,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -13,7 +13,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractEntityAttributeMapper(Mapper): """ @@ -154,7 +153,7 @@ def _process_single_sample(self, text='', rank=None): }] desc, demos = '', [] - for i in range(self.try_num): + for _ in range(self.try_num): try: output = client(messages, **self.sampling_params) desc, demos = self.parse_output(output, attribute) diff --git a/data_juicer/ops/mapper/extract_entity_relation_mapper.py b/data_juicer/ops/mapper/extract_entity_relation_mapper.py index 4b026f2a4..6350101ac 100644 --- a/data_juicer/ops/mapper/extract_entity_relation_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_relation_mapper.py @@ -9,7 +9,7 @@ from loguru import logger from pydantic import NonNegativeInt, PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.common_utils import is_float from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -20,7 +20,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractEntityRelationMapper(Mapper): """ @@ -319,7 +318,7 @@ def process_single(self, sample, rank=None): messages = [{'role': 'user', 'content': input_prompt}] entities, relations = [], [] - for i in range(self.try_num): + for _ in range(self.try_num): try: result = self.light_rag_extraction(messages, rank=rank) entities, relations = self.parse_output(result) diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py index 208684b2c..fddf4fed1 100644 --- a/data_juicer/ops/mapper/extract_event_mapper.py +++ b/data_juicer/ops/mapper/extract_event_mapper.py @@ -5,7 +5,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -15,7 +15,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractEventMapper(Mapper): """ @@ -134,7 +133,7 @@ def _process_single_sample(self, text='', rank=None): }] event_list, character_list = [], [] - for i in range(self.try_num): + for _ in range(self.try_num): try: output = client(messages, **self.sampling_params) event_list, character_list = self.parse_output(output) diff --git a/data_juicer/ops/mapper/extract_keyword_mapper.py b/data_juicer/ops/mapper/extract_keyword_mapper.py index cb1814768..24e3e127e 100644 --- a/data_juicer/ops/mapper/extract_keyword_mapper.py +++ b/data_juicer/ops/mapper/extract_keyword_mapper.py @@ -6,7 +6,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -16,7 +16,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractKeywordMapper(Mapper): """ @@ -173,7 +172,7 @@ def process_single(self, sample, rank=None): messages = [{'role': 'user', 'content': input_prompt}] keywords = [] - for i in range(self.try_num): + for _ in range(self.try_num): try: result = client(messages, **self.sampling_params) keywords = self.parse_output(result) diff --git a/data_juicer/ops/mapper/extract_nickname_mapper.py b/data_juicer/ops/mapper/extract_nickname_mapper.py index b11cbab57..20aeb94db 100644 --- a/data_juicer/ops/mapper/extract_nickname_mapper.py +++ b/data_juicer/ops/mapper/extract_nickname_mapper.py @@ -4,7 +4,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -12,7 +12,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractNicknameMapper(Mapper): """ @@ -143,7 +142,7 @@ def process_single(self, sample, rank=None): 'content': input_prompt }] nickname_relations = [] - for i in range(self.try_num): + for _ in range(self.try_num): try: output = client(messages, **self.sampling_params) nickname_relations = self.parse_output(output) diff --git a/data_juicer/ops/mapper/pair_preference_mapper.py b/data_juicer/ops/mapper/pair_preference_mapper.py new file mode 100644 index 000000000..f839fb5d3 --- /dev/null +++ b/data_juicer/ops/mapper/pair_preference_mapper.py @@ -0,0 +1,131 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'pair_preference_mapper' + + +# TODO: Extend LLM-based OPs into API-based implementation. +@OPERATORS.register_module(OP_NAME) +class PairPreferenceMapper(Mapper): + """ + Mapper to construct paired preference samples. + """ + + # avoid leading whitespace + DEFAULT_SYSTEM_PROMPT = ( + '你的任务是根据参考信息修改问答对中的回答,在语言风格、事实性、人物身份、立场等任一方面与原回答相反。' + '必须按照以下标记格式输出,不要输出其他多余内容。\n' + '【回答】\n' + '生成的新回答\n' + '【原因】\n' + '生成该回答的原因') + DEFAULT_INPUT_TEMPLATE = ('【参考信息】\n' + '{reference}\n' + '\n' + '以下是原始问答对:\n' + '【问题】\n' + '{query}\n' + '【回答】\n' + '{response}') + DEFAULT_OUTPUT_PATTERN = r'.*?【回答】\s*(.*?)\s*【原因】\s*(.*)' + + def __init__(self, + api_model: str = 'gpt-4o', + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + input_template: Optional[str] = None, + output_pattern: Optional[str] = None, + rejected_key: str = 'rejected_response', + reason_key: str = 'reason', + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param api_model: API model name. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: System prompt for guiding the generation task. + :param input_template: Template for building the model input. It must + contain placeholders '{query}' and '{reponse}', and can optionally + include '{reference}'. + :param output_pattern: Regular expression for parsing model output. + :param rejected_key: The field name in the sample to store the + generated rejected response. Defaults to 'rejected_response'. + :param reason_key: The field name in the sample to store the reason for + generating the response. Defaults to 'reason'. + :param try_num: The number of retries for the API call in case of + response parsing failure. Defaults to 3. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN + + self.rejected_key = rejected_key + self.reason_key = reason_key + + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + self.try_num = try_num + self.sampling_params = sampling_params + + def build_input(self, sample): + mapping = { + 'query': sample[self.query_key], + 'response': sample[self.response_key], + 'reference': sample.get(self.text_key, '') + } + return self.input_template.format_map(mapping) + + def parse_output(self, raw_output): + logger.debug(raw_output) + match = re.match(self.output_pattern, raw_output, re.DOTALL) + if match: + return match.group(1).strip(), match.group(2).strip() + else: + return ('', '') + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': self.build_input(sample) + }] + + parsed_rejected, parsed_reason = '', '' + for _ in range(self.try_num): + try: + output = client(messages, **self.sampling_params) + parsed_rejected, parsed_reason = self.parse_output(output) + if parsed_rejected and parsed_reason: + break + except Exception as e: + logger.warning(f'Exception: {e}') + sample[self.rejected_key] = parsed_rejected + sample[self.reason_key] = parsed_reason + + return sample diff --git a/docs/Operators.md b/docs/Operators.md index 7717ba434..f24523dc5 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. | Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 58 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 59 | Edits and transforms samples | | [ Filter ]( #filter ) | 44 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -57,9 +57,9 @@ All the specific operators are listed below, each featured with several capabili | Operator | Tags | Description | Source code | Unit tests | |------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------|------------------------------------------------------------------------------------| | audio_ffmpeg_wrapped_mapper | ![Audio](https://img.shields.io/badge/Audio-0DA64F?style=plastic) | Simple wrapper to run a FFmpeg audio filter | [code](../data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py) | [tests](../tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py) | -| calibrate_qa_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Calibrate question-answer pairs based on reference text | [code](../data_juicer/ops/mapper/calibrate_qa_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_qa_mapper.py) | -| calibrate_query_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Calibrate query in question-answer pairs based on reference text | [code](../data_juicer/ops/mapper/calibrate_query_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_query_mapper.py) | -| calibrate_response_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Calibrate response in question-answer pairs based on reference text | [code](../data_juicer/ops/mapper/calibrate_response_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_response_mapper.py) | +| calibrate_qa_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Calibrate question-answer pairs based on reference text | [code](../data_juicer/ops/mapper/calibrate_qa_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_qa_mapper.py) | +| calibrate_query_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Calibrate query in question-answer pairs based on reference text | [code](../data_juicer/ops/mapper/calibrate_query_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_query_mapper.py) | +| calibrate_response_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Calibrate response in question-answer pairs based on reference text | [code](../data_juicer/ops/mapper/calibrate_response_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_response_mapper.py) | | chinese_convert_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Converts Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) | [code](../data_juicer/ops/mapper/chinese_convert_mapper.py) | [tests](../tests/ops/mapper/test_chinese_convert_mapper.py) | | clean_copyright_mapper | ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes copyright notice at the beginning of code files (must contain the word *copyright*) | [code](../data_juicer/ops/mapper/clean_copyright_mapper.py) | [tests](../tests/ops/mapper/test_clean_copyright_mapper.py) | | clean_email_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes email information | [code](../data_juicer/ops/mapper/clean_email_mapper.py) | [tests](../tests/ops/mapper/test_clean_email_mapper.py) | @@ -86,6 +86,7 @@ All the specific operators are listed below, each featured with several capabili | optimize_qa_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Optimize both the query and response in question-answering samples. | [code](../data_juicer/ops/mapper/optimize_qa_mapper.py) | [tests](../tests/ops/mapper/test_optimize_qa_mapper.py) | | optimize_query_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Optimize the query in question-answering samples. | [code](../data_juicer/ops/mapper/optimize_query_mapper.py) | [tests](../tests/ops/mapper/test_optimize_query_mapper.py) | | optimize_response_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Optimize the response in question-answering samples. | [code](../data_juicer/ops/mapper/optimize_response_mapper.py) | [tests](../tests/ops/mapper/test_optimize_response_mapper.py) | +| pair_preference_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Construct paired preference samples. | [code](../data_juicer/ops/mapper/pair_preference_mapper.py) | [tests](../tests/ops/mapper/test_pair_preference_mapper.py) | | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 81aee2149..c771a30e9 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 58 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 59 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 44 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -85,6 +85,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | optimize_qa_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 指令优化,优化问题和答案 | [code](../data_juicer/ops/mapper/optimize_qa_mapper.py) | [tests](../tests/ops/mapper/test_optimize_qa_mapper.py) | | optimize_query_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 指令优化,优化 query | [code](../data_juicer/ops/mapper/optimize_query_mapper.py) | [tests](../tests/ops/mapper/test_optimize_query_mapper.py) | | optimize_response_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 指令优化,优化 response | [code](../data_juicer/ops/mapper/optimize_response_mapper.py) | [tests](../tests/ops/mapper/test_optimize_response_mapper.py) | +| pair_preference_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 构造配对的偏好样本 | [code](../data_juicer/ops/mapper/pair_preference_mapper.py) | [tests](../tests/ops/mapper/test_pair_preference_mapper.py) | | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | diff --git a/tests/ops/mapper/test_calibrate_qa_mapper.py b/tests/ops/mapper/test_calibrate_qa_mapper.py index ea237093b..5755ed2b1 100644 --- a/tests/ops/mapper/test_calibrate_qa_mapper.py +++ b/tests/ops/mapper/test_calibrate_qa_mapper.py @@ -76,7 +76,7 @@ def test(self): def test_args(self): op = CalibrateQAMapper( api_model='qwen2.5-72b-instruct', - api_url= + api_endpoint= 'https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions', response_path='choices.0.message.content') self._run_op(op) diff --git a/tests/ops/mapper/test_calibrate_query_mapper.py b/tests/ops/mapper/test_calibrate_query_mapper.py index 8229c10ed..f95b6c5dc 100644 --- a/tests/ops/mapper/test_calibrate_query_mapper.py +++ b/tests/ops/mapper/test_calibrate_query_mapper.py @@ -69,8 +69,8 @@ def _run_op(self, api_model, response_path=None): def test(self): # before runing this test, set below environment variables: - # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions - # export DJ_API_KEY=your_key + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key self._run_op('qwen2.5-72b-instruct') diff --git a/tests/ops/mapper/test_calibrate_response_mapper.py b/tests/ops/mapper/test_calibrate_response_mapper.py index e092d4c48..4a9ddbe11 100644 --- a/tests/ops/mapper/test_calibrate_response_mapper.py +++ b/tests/ops/mapper/test_calibrate_response_mapper.py @@ -70,8 +70,8 @@ def _run_op(self, api_model, response_path=None): def test(self): # before runing this test, set below environment variables: - # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions - # export DJ_API_KEY=your_key + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key self._run_op('qwen2.5-72b-instruct') diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py index 635801155..2911a1002 100644 --- a/tests/ops/mapper/test_extract_nickname_mapper.py +++ b/tests/ops/mapper/test_extract_nickname_mapper.py @@ -49,8 +49,8 @@ def _run_op(self, api_model, response_path=None): def test(self): # before runing this test, set below environment variables: - # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions - # export DJ_API_KEY=your_key + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key self._run_op('qwen2.5-72b-instruct') diff --git a/tests/ops/mapper/test_pair_preference_mapper.py b/tests/ops/mapper/test_pair_preference_mapper.py new file mode 100644 index 000000000..93cd4d877 --- /dev/null +++ b/tests/ops/mapper/test_pair_preference_mapper.py @@ -0,0 +1,57 @@ +import unittest + +from loguru import logger + +from data_juicer.ops.mapper.pair_preference_mapper import PairPreferenceMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + + +# Skip tests for this OP because the API call is not configured yet. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class PairPreferenceMapperTest(DataJuicerTestCaseBase): + + def _run_op(self, op, samples): + for sample in samples: + result = op.process(sample) + logger.info(f'Output results: {result}') + self.assertNotEqual(result['rejected_response'], '') + self.assertNotEqual(result['reason'], '') + + def test(self): + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key + + reference = '王八十娘:小远城王八十的娘亲,李莲花刚到小远城时被方多病偷掉钱袋找小乞丐问路时,刚好发现王八十娘被另一个小乞丐撞到便将她扶起,结识了王八十。\n朴二黄:灵山派管家,方多病小厮旺福的父亲。真实身份是金鸳盟的奔雷手辛雷,离开金鸳盟后,用假名朴二黄在灵山派当管家。因害怕王青山看穿他的身份,设计杀死了灵山派的王青山。被捕后识破了李莲花的真实身份,最后在攻击李莲花的时候被方多病情急之下杀死。' # noqa: E501 + samples = [{ + 'text': reference, + 'query': '李莲花,你认识方多病吗?', + 'response': '方多病啊,那可是我的好友。' + }] + op = PairPreferenceMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op, samples) + + def test_no_reference(self): + samples = [{'query': '李莲花,你认识方多病吗?', 'response': '方多病啊,那可是我的好友。'}] + system_prompt = ('修改问答对中的回答,在语言风格、事实性、人物身份、立场等任一方面与原回答相反。' + '必须按照以下标记格式输出,不要输出其他多余内容。\n' + '【回答】\n' + '生成的新回答\n' + '【原因】\n' + '生成该回答的原因') + input_template = ('以下是原始问答对:\n' + '【问题】\n' + '{query}\n' + '【回答】\n' + '{response}') + + op = PairPreferenceMapper(api_model='qwen2.5-72b-instruct', + system_prompt=system_prompt, + input_template=input_template) + self._run_op(op, samples) + + +if __name__ == '__main__': + unittest.main()