From 63d430a7c9b785aa18c659416ff50eea5f8a6a73 Mon Sep 17 00:00:00 2001 From: null <3213204+drcege@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:32:32 +0800 Subject: [PATCH 001/118] add api call --- .../filter/image_pair_similarity_filter.py | 2 +- data_juicer/utils/model_utils.py | 91 +++++++++++++++++++ 2 files changed, 92 insertions(+), 1 deletion(-) diff --git a/data_juicer/ops/filter/image_pair_similarity_filter.py b/data_juicer/ops/filter/image_pair_similarity_filter.py index 3299f9ad2..de576f07e 100644 --- a/data_juicer/ops/filter/image_pair_similarity_filter.py +++ b/data_juicer/ops/filter/image_pair_similarity_filter.py @@ -30,7 +30,7 @@ def __init__(self, *args, **kwargs): """ - Initialization method. + Initialization method. :param hf_clip: clip model name on huggingface to compute the similarity between image and text. diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index cda046b81..08a0c37c2 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -5,6 +5,7 @@ from typing import Optional, Union import multiprocess as mp +import requests import wget from loguru import logger @@ -589,6 +590,95 @@ def prepare_opencv_classifier(model_path): return model +class APIModel: + + def __init__(self, + *, + api_url=None, + api_key=None, + response_path='choices.0.message.content'): + if api_url is None: + api_url = os.getenv('DJ_API_URL') + if api_url is None: + base_url = os.getenv('OPENAI_BASE_URL', + 'https://api.openai.com/v1') + api_url = base_url.rstrip('/') + '/chat/completions' + self.api_url = api_url + + if api_key is None: + api_key = os.getenv('DJ_API_KEY') or os.getenv('OPENAI_API_KEY') + self.api_key = api_key + + self.headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}' + } + self.response_path = response_path + + def __call__(self, *, messages, model, **kwargs): + """Sends messages to the configured API model and returns the parsed response. + + :param messages: The messages to send to the API. + :param model: The model to be used for generating responses. + :param kwargs: Additional parameters for the API request. + + :return: The parsed response from the API, or None if an error occurs. + """ + payload = { + 'model': model, + 'messages': messages, + **kwargs, + } + try: + response = requests.post(self.api_url, + json=payload, + headers=self.headers) + response.raise_for_status() + result = response.json() + return self.nested_access(result, self.response_path) + except Exception as e: + logger.exception(e) + return None + + @staticmethod + def nested_access(data, path): + """Access nested data using a dot-separated path. + + :param data: The data structure to access. + :param path: A dot-separated string representing the path to access. + :return: The value at the specified path, if it exists. + """ + keys = path.split('.') + for key in keys: + # Convert string keys to integers if they are numeric + key = int(key) if key.isdigit() else key + data = data[key] + return data + + +def prepare_api_model(*, + api_url=None, + api_key=None, + response_path='choices.0.message.content'): + """Creates a callable API model for interacting with the OpenAI-compatible API. + + This callable object supports custom result parsing and is suitable for use + with incompatible proxy servers. + + :param api_url: The URL of the API. If not provided, it will fallback + to the environment variable or a default OpenAI URL. + :param api_key: The API key for authorization. 
If not provided, it will + fallback to the environment variable. + :param response_path: The path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :return: A callable API model object that can be used to send messages + and receive responses. + """ + return APIModel(api_url=api_url, + api_key=api_key, + response_path=response_path) + + MODEL_FUNCTION_MAPPING = { 'fasttext': prepare_fasttext_model, 'sentencepiece': prepare_sentencepiece_for_lang, @@ -602,6 +692,7 @@ def prepare_opencv_classifier(model_path): 'recognizeAnything': prepare_recognizeAnything_model, 'vllm': prepare_vllm_model, 'opencv_classifier': prepare_opencv_classifier, + 'api': prepare_api_model, } From 6720da42991a264013ec52a895e23b17e518e081 Mon Sep 17 00:00:00 2001 From: null <3213204+drcege@users.noreply.github.com> Date: Thu, 24 Oct 2024 17:59:53 +0800 Subject: [PATCH 002/118] add call_api ops --- data_juicer/ops/mapper/__init__.py | 31 +++-- data_juicer/ops/mapper/calibrate_qa_mapper.py | 108 ++++++++++++++++++ .../ops/mapper/calibrate_query_mapper.py | 19 +++ .../ops/mapper/calibrate_response_mapper.py | 19 +++ data_juicer/utils/model_utils.py | 19 +-- 5 files changed, 176 insertions(+), 20 deletions(-) create mode 100644 data_juicer/ops/mapper/calibrate_qa_mapper.py create mode 100644 data_juicer/ops/mapper/calibrate_query_mapper.py create mode 100644 data_juicer/ops/mapper/calibrate_response_mapper.py diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index eb814b374..95d4901e2 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -1,16 +1,17 @@ # yapf: disable -from . import (audio_ffmpeg_wrapped_mapper, chinese_convert_mapper, - clean_copyright_mapper, clean_email_mapper, clean_html_mapper, - clean_ip_mapper, clean_links_mapper, expand_macro_mapper, - extract_qa_mapper, fix_unicode_mapper, - generate_instruction_mapper, image_blur_mapper, - image_captioning_from_gpt4v_mapper, image_captioning_mapper, - image_diffusion_mapper, image_face_blur_mapper, - image_tagging_mapper, nlpaug_en_mapper, nlpcda_zh_mapper, - optimize_instruction_mapper, punctuation_normalization_mapper, - remove_bibliography_mapper, remove_comments_mapper, - remove_header_mapper, remove_long_words_mapper, - remove_non_chinese_character_mapper, +from . 
import (audio_ffmpeg_wrapped_mapper, calibrate_qa_mapper, + calibrate_query_mapper, calibrate_response_mapper, + chinese_convert_mapper, clean_copyright_mapper, + clean_email_mapper, clean_html_mapper, clean_ip_mapper, + clean_links_mapper, expand_macro_mapper, extract_qa_mapper, + fix_unicode_mapper, generate_instruction_mapper, + image_blur_mapper, image_captioning_from_gpt4v_mapper, + image_captioning_mapper, image_diffusion_mapper, + image_face_blur_mapper, image_tagging_mapper, nlpaug_en_mapper, + nlpcda_zh_mapper, optimize_instruction_mapper, + punctuation_normalization_mapper, remove_bibliography_mapper, + remove_comments_mapper, remove_header_mapper, + remove_long_words_mapper, remove_non_chinese_character_mapper, remove_repeat_sentences_mapper, remove_specific_chars_mapper, remove_table_text_mapper, remove_words_with_incorrect_substrings_mapper, @@ -27,6 +28,9 @@ video_tagging_from_frames_mapper, whitespace_normalization_mapper) from .audio_ffmpeg_wrapped_mapper import AudioFFmpegWrappedMapper +from .calibrate_qa_mapper import CalibrateQAMapper +from .calibrate_query_mapper import CalibrateQueryMapper +from .calibrate_response_mapper import CalibrateResponseMapper from .chinese_convert_mapper import ChineseConvertMapper from .clean_copyright_mapper import CleanCopyrightMapper from .clean_email_mapper import CleanEmailMapper @@ -79,6 +83,9 @@ from .whitespace_normalization_mapper import WhitespaceNormalizationMapper __all__ = [ + 'CalibrateQAMapper', + 'CalibrateQueryMapper', + 'CalibrateResponseMapper', 'VideoCaptioningFromAudioMapper', 'VideoTaggingFromAudioMapper', 'ImageCaptioningFromGPT4VMapper', diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py new file mode 100644 index 000000000..5bb2a1e45 --- /dev/null +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -0,0 +1,108 @@ +import re +from typing import Dict, Optional + +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'calibrate_qa_mapper' + + +# TODO: LLM-based inference. +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class CalibrateQAMapper(Mapper): + """ + Mapper to calibrate question-answer pairs. + """ + + # avoid leading whitespace + DEFAULT_SYSTEM_PROMPT = ('请根据提供的【参考信息】对【问题】和【回答】进行校准,使其更加详细、准确。\n' + '按照以下格式输出:\n' + '【问题】\n' + '优化后的问题\n' + '【回答】\n' + '优化后的回答') + DEFAULT_INPUT_TEMPLATE = '{reference}\n{qa_pair}' + DEFAULT_REFERENCE_TEMPLATE = '【参考信息】\n{}' + DEFAULT_QA_PAIR_TEMPLATE = '【问题】\n{}\n【回答】\n{}' + DEFAULT_OUTPUT_PATTERN = r'【问题】\s*(.*?)\s*【回答】\s*(.*)' + + def __init__(self, + *, + api_model: str, + api_url: Optional[str] = None, + api_key: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + input_template: Optional[str] = None, + reference_template: Optional[str] = None, + qa_pair_template: Optional[str] = None, + output_pattern: Optional[str] = None, + api_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param api_model: API model name. + :param api_url: API URL. Defaults to DJ_API_URL environment variable. + :param api_key: API key. Defaults to DJ_API_KEY environment variable. + :param response_path: Path to extract content from the API response. + :param system_prompt: System prompt for the calibration task. + :param input_template: Template for building the model input. + :param reference_template: Template for formatting the reference text. 
+ :param qa_pair_template: Template for formatting question-answer pairs. + :param output_pattern: Pattern for parsing model output. + :param api_params: Extra API parameters. + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + self.reference_template = reference_template or \ + self.DEFAULT_REFERENCE_TEMPLATE + self.qa_pair_template = qa_pair_template or \ + self.DEFAULT_QA_PAIR_TEMPLATE + self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN + + self.api_params = api_params + self.model_key = prepare_model(model_type='api', + api_model=api_model, + api_url=api_url, + api_key=api_key, + response_path=response_path) + + def build_input(self, sample): + reference = self.reference_template.format(sample[self.text_key]) + qa_pair = self.qa_pair_template.format(sample[self.query_key], + sample[self.response_key]) + input_prompt = self.input_template.format(reference=reference, + qa_pair=qa_pair) + return input_prompt + + def parse_output(self, raw_output): + match = re.match(self.output_pattern, raw_output) + if match: + return match.group(1).strip(), match.group(2).strip() + else: + return None, None + + def process_single(self, sample=None, rank=None): + client = get_model(self.model_key, rank=rank) + + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': self.build_input(sample) + }] + output = client(messages, **self.api_params) + + parsed_q, parsed_a = self.parse_output(output) + if parsed_q: + sample[self.query_key] = parsed_q + if parsed_a: + sample[self.response_key] = parsed_a + + return sample diff --git a/data_juicer/ops/mapper/calibrate_query_mapper.py b/data_juicer/ops/mapper/calibrate_query_mapper.py new file mode 100644 index 000000000..d573a2c76 --- /dev/null +++ b/data_juicer/ops/mapper/calibrate_query_mapper.py @@ -0,0 +1,19 @@ +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE +from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper + +OP_NAME = 'calibrate_query_mapper' + + +# TODO: LLM-based inference. +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class CalibrateQueryMapper(CalibrateQAMapper): + """ + Mapper to calibrate only query in question-answer pairs. + """ + + DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【问题】进行校准,\ + 使其更加详细、准确,且仍可以由原答案回答。只输出优化后的问题,不要输出多余内容。' + + def parse_output(self, raw_output): + return raw_output.strip(), None diff --git a/data_juicer/ops/mapper/calibrate_response_mapper.py b/data_juicer/ops/mapper/calibrate_response_mapper.py new file mode 100644 index 000000000..f2bcd0afa --- /dev/null +++ b/data_juicer/ops/mapper/calibrate_response_mapper.py @@ -0,0 +1,19 @@ +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE +from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper + +OP_NAME = 'calibrate_response_mapper' + + +# TODO: LLM-based inference. +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class CalibrateResponseMapper(CalibrateQAMapper): + """ + Mapper to calibrate only response in question-answer pairs. 
+ """ + + DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【回答】进行校准,\ + 使其更加详细、准确,且仍可以回答原问题。只输出优化后的回答,不要输出多余内容。' + + def parse_output(self, raw_output): + return None, raw_output.strip() diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 08a0c37c2..bd491311b 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -594,9 +594,12 @@ class APIModel: def __init__(self, *, + api_model, api_url=None, api_key=None, - response_path='choices.0.message.content'): + response_path=None): + self.api_model = api_model + if api_url is None: api_url = os.getenv('DJ_API_URL') if api_url is None: @@ -609,13 +612,16 @@ def __init__(self, api_key = os.getenv('DJ_API_KEY') or os.getenv('OPENAI_API_KEY') self.api_key = api_key + if response_path is None: + response_path = 'choices.0.message.content' + self.response_path = response_path + self.headers = { 'Content-Type': 'application/json', 'Authorization': f'Bearer {self.api_key}' } - self.response_path = response_path - def __call__(self, *, messages, model, **kwargs): + def __call__(self, *, messages, **kwargs): """Sends messages to the configured API model and returns the parsed response. :param messages: The messages to send to the API. @@ -625,7 +631,7 @@ def __call__(self, *, messages, model, **kwargs): :return: The parsed response from the API, or None if an error occurs. """ payload = { - 'model': model, + 'model': self.model, 'messages': messages, **kwargs, } @@ -656,10 +662,7 @@ def nested_access(data, path): return data -def prepare_api_model(*, - api_url=None, - api_key=None, - response_path='choices.0.message.content'): +def prepare_api_model(*, api_url=None, api_key=None, response_path=None): """Creates a callable API model for interacting with the OpenAI-compatible API. This callable object supports custom result parsing and is suitable for use From 8daa6e167044e0dffabb77d8d305afb12dfaeafd Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Tue, 29 Oct 2024 03:07:12 +0000 Subject: [PATCH 003/118] clean --- data_juicer/ops/deduplicator/__init__.py | 11 +-- data_juicer/ops/filter/__init__.py | 83 ++++-------------- data_juicer/ops/mapper/__init__.py | 102 +++++------------------ data_juicer/ops/selector/__init__.py | 2 - 4 files changed, 41 insertions(+), 157 deletions(-) diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index 69f73b361..56aec0e10 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -1,7 +1,3 @@ -from . 
import (document_deduplicator, document_minhash_deduplicator, - document_simhash_deduplicator, image_deduplicator, - ray_document_deduplicator, ray_image_deduplicator, - ray_video_deduplicator, video_deduplicator) from .document_deduplicator import DocumentDeduplicator from .document_minhash_deduplicator import DocumentMinhashDeduplicator from .document_simhash_deduplicator import DocumentSimhashDeduplicator @@ -13,7 +9,8 @@ from .video_deduplicator import VideoDeduplicator __all__ = [ - 'VideoDeduplicator', 'RayBasicDeduplicator', 'DocumentMinhashDeduplicator', - 'RayImageDeduplicator', 'RayDocumentDeduplicator', 'DocumentDeduplicator', - 'ImageDeduplicator', 'DocumentSimhashDeduplicator', 'RayVideoDeduplicator' + 'DocumentDeduplicator', 'DocumentMinhashDeduplicator', + 'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator', + 'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator', + 'VideoDeduplicator' ] diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index f21c81546..718f06cd3 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -1,25 +1,3 @@ -# yapf: disable -from . import (alphanumeric_filter, audio_duration_filter, - audio_nmf_snr_filter, audio_size_filter, - average_line_length_filter, character_repetition_filter, - flagged_words_filter, image_aesthetics_filter, - image_aspect_ratio_filter, image_face_count_filter, - image_face_ratio_filter, image_nsfw_filter, - image_pair_similarity_filter, image_shape_filter, - image_size_filter, image_text_matching_filter, - image_text_similarity_filter, image_watermark_filter, - language_id_score_filter, maximum_line_length_filter, - perplexity_filter, phrase_grounding_recall_filter, - special_characters_filter, specified_field_filter, - specified_numeric_field_filter, stopwords_filter, suffix_filter, - text_action_filter, text_entity_dependency_filter, - text_length_filter, token_num_filter, video_aesthetics_filter, - video_aspect_ratio_filter, video_duration_filter, - video_frames_text_similarity_filter, video_motion_score_filter, - video_nsfw_filter, video_ocr_area_ratio_filter, - video_resolution_filter, video_tagging_from_frames_filter, - video_watermark_filter, word_repetition_filter, - words_num_filter) from .alphanumeric_filter import AlphanumericFilter from .audio_duration_filter import AudioDurationFilter from .audio_nmf_snr_filter import AudioNMFSNRFilter @@ -66,49 +44,20 @@ from .words_num_filter import WordsNumFilter __all__ = [ - 'ImageTextSimilarityFilter', - 'VideoAspectRatioFilter', - 'ImageTextMatchingFilter', - 'ImageNSFWFilter', - 'TokenNumFilter', - 'TextLengthFilter', - 'SpecifiedNumericFieldFilter', - 'AudioNMFSNRFilter', - 'VideoAestheticsFilter', - 'PerplexityFilter', - 'PhraseGroundingRecallFilter', - 'MaximumLineLengthFilter', - 'AverageLineLengthFilter', - 'SpecifiedFieldFilter', - 'VideoTaggingFromFramesFilter', - 'TextEntityDependencyFilter', - 'VideoResolutionFilter', - 'AlphanumericFilter', - 'ImageWatermarkFilter', - 'ImageAestheticsFilter', - 'AudioSizeFilter', - 'StopWordsFilter', - 'CharacterRepetitionFilter', - 'ImageShapeFilter', - 'VideoDurationFilter', - 'TextActionFilter', - 'VideoOcrAreaRatioFilter', - 'VideoNSFWFilter', - 'SpecialCharactersFilter', - 'VideoFramesTextSimilarityFilter', - 'ImageAspectRatioFilter', - 'AudioDurationFilter', - 'LanguageIDScoreFilter', - 'SuffixFilter', - 'ImageSizeFilter', - 'VideoWatermarkFilter', - 'WordsNumFilter', - 'ImageFaceCountFilter', - 
'ImageFaceRatioFilter', - 'FlaggedWordFilter', - 'WordRepetitionFilter', - 'VideoMotionScoreFilter', - 'ImagePairSimilarityFilter' + 'AlphanumericFilter', 'AudioDurationFilter', 'AudioNMFSNRFilter', + 'AudioSizeFilter', 'AverageLineLengthFilter', 'CharacterRepetitionFilter', + 'FlaggedWordFilter', 'ImageAestheticsFilter', 'ImageAspectRatioFilter', + 'ImageFaceCountFilter', 'ImageFaceRatioFilter', 'ImageNSFWFilter', + 'ImagePairSimilarityFilter', 'ImageShapeFilter', 'ImageSizeFilter', + 'ImageTextMatchingFilter', 'ImageTextSimilarityFilter', + 'ImageWatermarkFilter', 'LanguageIDScoreFilter', 'MaximumLineLengthFilter', + 'PerplexityFilter', 'PhraseGroundingRecallFilter', + 'SpecialCharactersFilter', 'SpecifiedFieldFilter', + 'SpecifiedNumericFieldFilter', 'StopWordsFilter', 'SuffixFilter', + 'TextActionFilter', 'TextEntityDependencyFilter', 'TextLengthFilter', + 'TokenNumFilter', 'VideoAestheticsFilter', 'VideoAspectRatioFilter', + 'VideoDurationFilter', 'VideoFramesTextSimilarityFilter', + 'VideoMotionScoreFilter', 'VideoNSFWFilter', 'VideoOcrAreaRatioFilter', + 'VideoResolutionFilter', 'VideoTaggingFromFramesFilter', + 'VideoWatermarkFilter', 'WordRepetitionFilter', 'WordsNumFilter' ] - -# yapf: enable diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 95d4901e2..d54a13b77 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -1,32 +1,3 @@ -# yapf: disable -from . import (audio_ffmpeg_wrapped_mapper, calibrate_qa_mapper, - calibrate_query_mapper, calibrate_response_mapper, - chinese_convert_mapper, clean_copyright_mapper, - clean_email_mapper, clean_html_mapper, clean_ip_mapper, - clean_links_mapper, expand_macro_mapper, extract_qa_mapper, - fix_unicode_mapper, generate_instruction_mapper, - image_blur_mapper, image_captioning_from_gpt4v_mapper, - image_captioning_mapper, image_diffusion_mapper, - image_face_blur_mapper, image_tagging_mapper, nlpaug_en_mapper, - nlpcda_zh_mapper, optimize_instruction_mapper, - punctuation_normalization_mapper, remove_bibliography_mapper, - remove_comments_mapper, remove_header_mapper, - remove_long_words_mapper, remove_non_chinese_character_mapper, - remove_repeat_sentences_mapper, remove_specific_chars_mapper, - remove_table_text_mapper, - remove_words_with_incorrect_substrings_mapper, - replace_content_mapper, sentence_split_mapper, - video_captioning_from_audio_mapper, - video_captioning_from_frames_mapper, - video_captioning_from_summarizer_mapper, - video_captioning_from_video_mapper, video_face_blur_mapper, - video_ffmpeg_wrapped_mapper, video_remove_watermark_mapper, - video_resize_aspect_ratio_mapper, - video_resize_resolution_mapper, video_split_by_duration_mapper, - video_split_by_key_frame_mapper, video_split_by_scene_mapper, - video_tagging_from_audio_mapper, - video_tagging_from_frames_mapper, - whitespace_normalization_mapper) from .audio_ffmpeg_wrapped_mapper import AudioFFmpegWrappedMapper from .calibrate_qa_mapper import CalibrateQAMapper from .calibrate_query_mapper import CalibrateQueryMapper @@ -83,56 +54,25 @@ from .whitespace_normalization_mapper import WhitespaceNormalizationMapper __all__ = [ - 'CalibrateQAMapper', - 'CalibrateQueryMapper', - 'CalibrateResponseMapper', - 'VideoCaptioningFromAudioMapper', - 'VideoTaggingFromAudioMapper', - 'ImageCaptioningFromGPT4VMapper', - 'PunctuationNormalizationMapper', - 'RemoveBibliographyMapper', - 'SentenceSplitMapper', - 'VideoSplitBySceneMapper', - 'CleanIpMapper', - 'CleanLinksMapper', - 'RemoveHeaderMapper', - 
'RemoveTableTextMapper', - 'VideoRemoveWatermarkMapper', - 'RemoveRepeatSentencesMapper', - 'ImageDiffusionMapper', - 'ImageFaceBlurMapper', - 'VideoFFmpegWrappedMapper', - 'ChineseConvertMapper', - 'NlpcdaZhMapper', - 'OptimizeInstructionMapper', - 'ImageBlurMapper', - 'CleanCopyrightMapper', - 'RemoveNonChineseCharacterlMapper', - 'VideoSplitByKeyFrameMapper', - 'RemoveSpecificCharsMapper', - 'VideoResizeAspectRatioMapper', - 'CleanHtmlMapper', - 'WhitespaceNormalizationMapper', - 'VideoTaggingFromFramesMapper', - 'RemoveCommentsMapper', - 'ExpandMacroMapper', - 'ExtractQAMapper', - 'ImageCaptioningMapper', - 'RemoveWordsWithIncorrectSubstringsMapper', - 'VideoCaptioningFromVideoMapper', - 'VideoCaptioningFromSummarizerMapper', - 'GenerateInstructionMapper', - 'FixUnicodeMapper', - 'NlpaugEnMapper', - 'VideoCaptioningFromFramesMapper', - 'RemoveLongWordsMapper', - 'VideoResizeResolutionMapper', - 'CleanEmailMapper', - 'ReplaceContentMapper', - 'AudioFFmpegWrappedMapper', - 'VideoSplitByDurationMapper', - 'VideoFaceBlurMapper', - 'ImageTaggingMapper', + 'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper', + 'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper', + 'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper', + 'ExpandMacroMapper', 'ExtractQAMapper', 'FixUnicodeMapper', + 'GenerateInstructionMapper', 'ImageBlurMapper', + 'ImageCaptioningFromGPT4VMapper', 'ImageCaptioningMapper', + 'ImageDiffusionMapper', 'ImageFaceBlurMapper', 'ImageTaggingMapper', + 'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeInstructionMapper', + 'PunctuationNormalizationMapper', 'RemoveBibliographyMapper', + 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', + 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', + 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', + 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', + 'SentenceSplitMapper', 'VideoCaptioningFromAudioMapper', + 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', + 'VideoCaptioningFromVideoMapper', 'VideoFFmpegWrappedMapper', + 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', + 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper', + 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', + 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper', + 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper' ] - -# yapf: enable diff --git a/data_juicer/ops/selector/__init__.py b/data_juicer/ops/selector/__init__.py index a90f6db8e..22df12987 100644 --- a/data_juicer/ops/selector/__init__.py +++ b/data_juicer/ops/selector/__init__.py @@ -1,5 +1,3 @@ -from . 
import (frequency_specified_field_selector, random_selector, - range_specified_field_selector, topk_specified_field_selector) from .frequency_specified_field_selector import FrequencySpecifiedFieldSelector from .random_selector import RandomSelector from .range_specified_field_selector import RangeSpecifiedFieldSelector From ef11951b82161b78c56d7193599978801ffa2733 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Tue, 29 Oct 2024 03:12:38 +0000 Subject: [PATCH 004/118] minor update --- data_juicer/ops/base_op.py | 5 + data_juicer/ops/mapper/calibrate_qa_mapper.py | 9 +- .../ops/mapper/calibrate_query_mapper.py | 2 +- .../ops/mapper/calibrate_response_mapper.py | 2 +- data_juicer/utils/model_utils.py | 191 +++++++++--------- 5 files changed, 110 insertions(+), 99 deletions(-) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 690193175..918831504 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -133,6 +133,11 @@ def __init__(self, *args, **kwargs): self.image_key = kwargs.get('image_key', 'images') self.audio_key = kwargs.get('audio_key', 'audios') self.video_key = kwargs.get('video_key', 'videos') + + self.query_key = kwargs.get('query_key', 'query') + self.response_key = kwargs.get('response_key', 'response') + self.history_key = kwargs.get('history_key', 'history') + self.batch_size = kwargs.get('batch_size', 1000) # whether the model can be accelerated using cuda diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index 5bb2a1e45..7a3ed5836 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -28,8 +28,8 @@ class CalibrateQAMapper(Mapper): DEFAULT_OUTPUT_PATTERN = r'【问题】\s*(.*?)\s*【回答】\s*(.*)' def __init__(self, - *, api_model: str, + *, api_url: Optional[str] = None, api_key: Optional[str] = None, response_path: Optional[str] = None, @@ -38,7 +38,7 @@ def __init__(self, reference_template: Optional[str] = None, qa_pair_template: Optional[str] = None, output_pattern: Optional[str] = None, - api_params: Dict = {}, + api_params: Optional[Dict] = None, **kwargs): """ Initialization method. @@ -47,11 +47,12 @@ def __init__(self, :param api_url: API URL. Defaults to DJ_API_URL environment variable. :param api_key: API key. Defaults to DJ_API_KEY environment variable. :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. :param system_prompt: System prompt for the calibration task. :param input_template: Template for building the model input. :param reference_template: Template for formatting the reference text. :param qa_pair_template: Template for formatting question-answer pairs. - :param output_pattern: Pattern for parsing model output. + :param output_pattern: Regular expression for parsing model output. :param api_params: Extra API parameters. :param kwargs: Extra keyword arguments. 
""" @@ -65,7 +66,7 @@ def __init__(self, self.DEFAULT_QA_PAIR_TEMPLATE self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN - self.api_params = api_params + self.api_params = api_params or {} self.model_key = prepare_model(model_type='api', api_model=api_model, api_url=api_url, diff --git a/data_juicer/ops/mapper/calibrate_query_mapper.py b/data_juicer/ops/mapper/calibrate_query_mapper.py index d573a2c76..ca2828d70 100644 --- a/data_juicer/ops/mapper/calibrate_query_mapper.py +++ b/data_juicer/ops/mapper/calibrate_query_mapper.py @@ -9,7 +9,7 @@ @OPERATORS.register_module(OP_NAME) class CalibrateQueryMapper(CalibrateQAMapper): """ - Mapper to calibrate only query in question-answer pairs. + Mapper to calibrate query in question-answer pairs. """ DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【问题】进行校准,\ diff --git a/data_juicer/ops/mapper/calibrate_response_mapper.py b/data_juicer/ops/mapper/calibrate_response_mapper.py index f2bcd0afa..0030833a4 100644 --- a/data_juicer/ops/mapper/calibrate_response_mapper.py +++ b/data_juicer/ops/mapper/calibrate_response_mapper.py @@ -9,7 +9,7 @@ @OPERATORS.register_module(OP_NAME) class CalibrateResponseMapper(CalibrateQAMapper): """ - Mapper to calibrate only response in question-answer pairs. + Mapper to calibrate response in question-answer pairs. """ DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【回答】进行校准,\ diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index bd491311b..479c16eef 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -106,6 +106,103 @@ def check_model(model_name, force=False): return cached_model_path +class APIModel: + + def __init__(self, + api_model, + *, + api_url=None, + api_key=None, + response_path=None): + self.api_model = api_model + + if api_url is None: + api_url = os.getenv('DJ_API_URL') + if api_url is None: + base_url = os.getenv('OPENAI_BASE_URL', + 'https://api.openai.com/v1') + api_url = base_url.rstrip('/') + '/chat/completions' + self.api_url = api_url + + if api_key is None: + api_key = os.getenv('DJ_API_KEY') or os.getenv('OPENAI_API_KEY') + self.api_key = api_key + + if response_path is None: + response_path = 'choices.0.message.content' + self.response_path = response_path + + self.headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}' + } + + def __call__(self, messages, **kwargs): + """Sends messages to the configured API model and returns the parsed response. + + :param messages: The messages to send to the API. + :param model: The model to be used for generating responses. + :param kwargs: Additional parameters for the API request. + + :return: The parsed response from the API, or None if an error occurs. + """ + payload = { + 'model': self.api_model, + 'messages': messages, + **kwargs, + } + try: + response = requests.post(self.api_url, + json=payload, + headers=self.headers) + response.raise_for_status() + result = response.json() + return self.nested_access(result, self.response_path) + except Exception as e: + logger.exception(e) + return None + + @staticmethod + def nested_access(data, path): + """Access nested data using a dot-separated path. + + :param data: The data structure to access. + :param path: A dot-separated string representing the path to access. + :return: The value at the specified path, if it exists. 
+        """
+        keys = path.split('.')
+        for key in keys:
+            # Convert string keys to integers if they are numeric
+            key = int(key) if key.isdigit() else key
+            data = data[key]
+        return data
+
+
+def prepare_api_model(*,
+                      api_model,
+                      api_url=None,
+                      api_key=None,
+                      response_path=None):
+    """Creates a callable API model for interacting with an OpenAI-compatible
+    API.
+
+    This callable object supports custom result parsing, so it is also usable
+    with proxy servers that are not strictly OpenAI-compatible.
+
+    :param api_model: The name of the model to be queried through the API.
+    :param api_url: The URL of the API.
If not provided, it will fallback - to the environment variable or a default OpenAI URL. - :param api_key: The API key for authorization. If not provided, it will - fallback to the environment variable. - :param response_path: The path to extract content from the API response. - Defaults to 'choices.0.message.content'. - :return: A callable API model object that can be used to send messages - and receive responses. - """ - return APIModel(api_url=api_url, - api_key=api_key, - response_path=response_path) - - MODEL_FUNCTION_MAPPING = { + 'api': prepare_api_model, 'fasttext': prepare_fasttext_model, 'sentencepiece': prepare_sentencepiece_for_lang, 'kenlm': prepare_kenlm_model, @@ -695,7 +701,6 @@ def prepare_api_model(*, api_url=None, api_key=None, response_path=None): 'recognizeAnything': prepare_recognizeAnything_model, 'vllm': prepare_vllm_model, 'opencv_classifier': prepare_opencv_classifier, - 'api': prepare_api_model, } From 5597d5c30f7c4102499af77dd59df85d4c236a39 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Tue, 29 Oct 2024 03:44:11 +0000 Subject: [PATCH 005/118] more tests --- tests/ops/mapper/test_calibrate_qa_mapper.py | 75 ++++++++++++++++++ .../ops/mapper/test_calibrate_query_mapper.py | 75 ++++++++++++++++++ .../mapper/test_calibrate_response_mapper.py | 76 +++++++++++++++++++ 3 files changed, 226 insertions(+) create mode 100644 tests/ops/mapper/test_calibrate_qa_mapper.py create mode 100644 tests/ops/mapper/test_calibrate_query_mapper.py create mode 100644 tests/ops/mapper/test_calibrate_response_mapper.py diff --git a/tests/ops/mapper/test_calibrate_qa_mapper.py b/tests/ops/mapper/test_calibrate_qa_mapper.py new file mode 100644 index 000000000..c2c69a7d2 --- /dev/null +++ b/tests/ops/mapper/test_calibrate_qa_mapper.py @@ -0,0 +1,75 @@ +import unittest + +from loguru import logger + +from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + + +# Skip tests for this OP in the GitHub actions due to disk space limitation. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class OptimizeQAMapperTest(DataJuicerTestCaseBase): + + def _run_op(self, api_model, response_path): + + op = CalibrateQAMapper(api_model=api_model, + response_path=response_path) + + reference = """# 角色语言风格 +1. 下面是李莲花的问答样例,你必须贴合他的语言风格: + +问题:你是谁? +李莲花:在下李莲花,不才略有一点神医之名,有礼。 + +问题:你就是个假神医! +李莲花:此言差矣,我从未说过我是神医,又何来假神医之说。 + +问题:李相夷是江湖传奇,失去了李相夷,这个江湖也没意思了! +李莲花:幼芋生成,新木长生。这个江湖熙来攘往,总会有新的传奇出现的。 + +问题:你恨不恨云彼丘,他给你下的碧茶之毒? +李莲花:若我是李相夷,当然是会恨他的。可李相夷已经死了,死去的人怎么还会一直恨呢,往事如烟,既然是往事,早就该忘记了。 + +问题:你不喜欢石水吗?她好像喜欢你呢。 +李莲花:石水啊,确实是个好姑娘,外冷内热,聪明伶俐。但我只把她当成我的妹妹,更无半点男女私情。 + +问题:你不觉得笛飞声有瞒着你的地方吗?为什么不一探究竟呢。 +李莲花:人生在世,谁都有不想说的秘密,给别人留余地,就等于是给自己留余地。 + +问题:你不觉得自己一生的遗憾太多了了吗? +李莲花:人生嘛,本处处都是遗憾,没有什么放不下的,更没有什么解不开的结,人总得学会放过自己。 + +2. 下面是剧本中李莲花的部分台词,用于语言风格上的参考: + +李莲花:没事,就是有些好奇,我见展护卫武功高强,并非池中物,不知是何机缘会在天机山庄做护卫? +李莲花:如此花哨的玉佩,这邢自如虽长得糙,想不到也是一爱美之人啊。 +李莲花:讨个吉利,还没开工就打打杀杀,这可不是好兆头。咱们来发财的,先办大事要紧,其他以后再算不迟。来人来人,快将丁元子带走止血治伤。 +李莲花:在下已牢记在心,大师放心去吧。 +李莲花:放心吧,该看到的,都看到了。 +李莲花:在下李莲花,有礼。 +李莲花:你小厮被害很难过,我理解,可也不必把罪名栽给我吧? +李莲花:不过是受了些机关里的毒邪,方才我已服过天机堂的避毒丹了,无碍。 +李莲花:我不知道,也不愿知道。我所说的只是个故事,当故事听就好,是真是假、你自己判断. 
+李莲花:不必紧张,这毒我中了许久,早就习惯了,没那么严重的。 +李莲花:等我有天想起你的时候,我发现我忘了为什么要恨你,觉得过去那些已不重要。 +""" + samples = [{ + 'text': reference, + 'query': '你还喜欢乔婉娩吗?', + 'response': '不喜欢。' + }] + + for sample in samples: + result = op.process(sample) + logger.info(f'Output results: {result}') + self.assertNotEqual(result['query'], '') + self.assertNotEqual(result['response'], '') + + def test(self): + self._run_op('gpt-4o', 'data.response.choices.0.message.content') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_calibrate_query_mapper.py b/tests/ops/mapper/test_calibrate_query_mapper.py new file mode 100644 index 000000000..a8b6cb836 --- /dev/null +++ b/tests/ops/mapper/test_calibrate_query_mapper.py @@ -0,0 +1,75 @@ +import unittest + +from loguru import logger + +from data_juicer.ops.mapper.calibrate_query_mapper import CalibrateQueryMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + + +# Skip tests for this OP because the API call is not configured yet. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class CalibrateQueryMapperTest(DataJuicerTestCaseBase): + + def _run_op(self, api_model, response_path): + + op = CalibrateQueryMapper(api_model=api_model, + response_path=response_path) + + reference = """# 角色语言风格 +1. 下面是李莲花的问答样例,你必须贴合他的语言风格: + +问题:你是谁? +李莲花:在下李莲花,不才略有一点神医之名,有礼。 + +问题:你就是个假神医! +李莲花:此言差矣,我从未说过我是神医,又何来假神医之说。 + +问题:李相夷是江湖传奇,失去了李相夷,这个江湖也没意思了! +李莲花:幼芋生成,新木长生。这个江湖熙来攘往,总会有新的传奇出现的。 + +问题:你恨不恨云彼丘,他给你下的碧茶之毒? +李莲花:若我是李相夷,当然是会恨他的。可李相夷已经死了,死去的人怎么还会一直恨呢,往事如烟,既然是往事,早就该忘记了。 + +问题:你不喜欢石水吗?她好像喜欢你呢。 +李莲花:石水啊,确实是个好姑娘,外冷内热,聪明伶俐。但我只把她当成我的妹妹,更无半点男女私情。 + +问题:你不觉得笛飞声有瞒着你的地方吗?为什么不一探究竟呢。 +李莲花:人生在世,谁都有不想说的秘密,给别人留余地,就等于是给自己留余地。 + +问题:你不觉得自己一生的遗憾太多了了吗? +李莲花:人生嘛,本处处都是遗憾,没有什么放不下的,更没有什么解不开的结,人总得学会放过自己。 + +2. 下面是剧本中李莲花的部分台词,用于语言风格上的参考: + +李莲花:没事,就是有些好奇,我见展护卫武功高强,并非池中物,不知是何机缘会在天机山庄做护卫? +李莲花:如此花哨的玉佩,这邢自如虽长得糙,想不到也是一爱美之人啊。 +李莲花:讨个吉利,还没开工就打打杀杀,这可不是好兆头。咱们来发财的,先办大事要紧,其他以后再算不迟。来人来人,快将丁元子带走止血治伤。 +李莲花:在下已牢记在心,大师放心去吧。 +李莲花:放心吧,该看到的,都看到了。 +李莲花:在下李莲花,有礼。 +李莲花:你小厮被害很难过,我理解,可也不必把罪名栽给我吧? +李莲花:不过是受了些机关里的毒邪,方才我已服过天机堂的避毒丹了,无碍。 +李莲花:我不知道,也不愿知道。我所说的只是个故事,当故事听就好,是真是假、你自己判断. +李莲花:不必紧张,这毒我中了许久,早就习惯了,没那么严重的。 +李莲花:等我有天想起你的时候,我发现我忘了为什么要恨你,觉得过去那些已不重要。 +""" + samples = [{ + 'text': reference, + 'query': '你还喜欢乔婉娩吗?', + 'response': '不喜欢。' + }] + + for sample in samples: + result = op.process(sample) + logger.info(f'Output results: {result}') + self.assertNotEqual(result['query'], '') + self.assertNotEqual(result['response'], '') + + def test(self): + self._run_op('gpt-4o', 'data.response.choices.0.message.content') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_calibrate_response_mapper.py b/tests/ops/mapper/test_calibrate_response_mapper.py new file mode 100644 index 000000000..bcbeffbcb --- /dev/null +++ b/tests/ops/mapper/test_calibrate_response_mapper.py @@ -0,0 +1,76 @@ +import unittest + +from loguru import logger + +from data_juicer.ops.mapper.calibrate_response_mapper import \ + CalibrateResponseMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + + +# Skip tests for this OP because the API call is not configured yet. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class CalibrateResponseMapperTest(DataJuicerTestCaseBase): + + def _run_op(self, api_model, response_path): + + op = CalibrateResponseMapper(api_model=api_model, + response_path=response_path) + + reference = """# 角色语言风格 +1. 下面是李莲花的问答样例,你必须贴合他的语言风格: + +问题:你是谁? 
+李莲花:在下李莲花,不才略有一点神医之名,有礼。 + +问题:你就是个假神医! +李莲花:此言差矣,我从未说过我是神医,又何来假神医之说。 + +问题:李相夷是江湖传奇,失去了李相夷,这个江湖也没意思了! +李莲花:幼芋生成,新木长生。这个江湖熙来攘往,总会有新的传奇出现的。 + +问题:你恨不恨云彼丘,他给你下的碧茶之毒? +李莲花:若我是李相夷,当然是会恨他的。可李相夷已经死了,死去的人怎么还会一直恨呢,往事如烟,既然是往事,早就该忘记了。 + +问题:你不喜欢石水吗?她好像喜欢你呢。 +李莲花:石水啊,确实是个好姑娘,外冷内热,聪明伶俐。但我只把她当成我的妹妹,更无半点男女私情。 + +问题:你不觉得笛飞声有瞒着你的地方吗?为什么不一探究竟呢。 +李莲花:人生在世,谁都有不想说的秘密,给别人留余地,就等于是给自己留余地。 + +问题:你不觉得自己一生的遗憾太多了了吗? +李莲花:人生嘛,本处处都是遗憾,没有什么放不下的,更没有什么解不开的结,人总得学会放过自己。 + +2. 下面是剧本中李莲花的部分台词,用于语言风格上的参考: + +李莲花:没事,就是有些好奇,我见展护卫武功高强,并非池中物,不知是何机缘会在天机山庄做护卫? +李莲花:如此花哨的玉佩,这邢自如虽长得糙,想不到也是一爱美之人啊。 +李莲花:讨个吉利,还没开工就打打杀杀,这可不是好兆头。咱们来发财的,先办大事要紧,其他以后再算不迟。来人来人,快将丁元子带走止血治伤。 +李莲花:在下已牢记在心,大师放心去吧。 +李莲花:放心吧,该看到的,都看到了。 +李莲花:在下李莲花,有礼。 +李莲花:你小厮被害很难过,我理解,可也不必把罪名栽给我吧? +李莲花:不过是受了些机关里的毒邪,方才我已服过天机堂的避毒丹了,无碍。 +李莲花:我不知道,也不愿知道。我所说的只是个故事,当故事听就好,是真是假、你自己判断. +李莲花:不必紧张,这毒我中了许久,早就习惯了,没那么严重的。 +李莲花:等我有天想起你的时候,我发现我忘了为什么要恨你,觉得过去那些已不重要。 +""" + samples = [{ + 'text': reference, + 'query': '你还喜欢乔婉娩吗?', + 'response': '不喜欢。' + }] + + for sample in samples: + result = op.process(sample) + logger.info(f'Output results: {result}') + self.assertNotEqual(result['query'], '') + self.assertNotEqual(result['response'], '') + + def test(self): + self._run_op('gpt-4o', 'data.response.choices.0.message.content') + + +if __name__ == '__main__': + unittest.main() From 4b6e76919b6df536f94aa34b36c0957ea00513d4 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Tue, 29 Oct 2024 06:44:39 +0000 Subject: [PATCH 006/118] update tests --- tests/ops/mapper/test_calibrate_qa_mapper.py | 8 +++++--- tests/ops/mapper/test_calibrate_query_mapper.py | 7 +++++-- tests/ops/mapper/test_calibrate_response_mapper.py | 7 +++++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/ops/mapper/test_calibrate_qa_mapper.py b/tests/ops/mapper/test_calibrate_qa_mapper.py index c2c69a7d2..7c31b56fc 100644 --- a/tests/ops/mapper/test_calibrate_qa_mapper.py +++ b/tests/ops/mapper/test_calibrate_qa_mapper.py @@ -7,12 +7,12 @@ DataJuicerTestCaseBase) -# Skip tests for this OP in the GitHub actions due to disk space limitation. +# Skip tests for this OP because the API call is not configured yet. # These tests have been tested locally. 
@SKIPPED_TESTS.register_module() class OptimizeQAMapperTest(DataJuicerTestCaseBase): - def _run_op(self, api_model, response_path): + def _run_op(self, api_model, response_path=None): op = CalibrateQAMapper(api_model=api_model, response_path=response_path) @@ -68,7 +68,9 @@ def _run_op(self, api_model, response_path): self.assertNotEqual(result['response'], '') def test(self): - self._run_op('gpt-4o', 'data.response.choices.0.message.content') + # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions + # export DJ_API_KEY=your_key + self._run_op('qwen2.5-72b-instruct') if __name__ == '__main__': diff --git a/tests/ops/mapper/test_calibrate_query_mapper.py b/tests/ops/mapper/test_calibrate_query_mapper.py index a8b6cb836..8229c10ed 100644 --- a/tests/ops/mapper/test_calibrate_query_mapper.py +++ b/tests/ops/mapper/test_calibrate_query_mapper.py @@ -12,7 +12,7 @@ @SKIPPED_TESTS.register_module() class CalibrateQueryMapperTest(DataJuicerTestCaseBase): - def _run_op(self, api_model, response_path): + def _run_op(self, api_model, response_path=None): op = CalibrateQueryMapper(api_model=api_model, response_path=response_path) @@ -68,7 +68,10 @@ def _run_op(self, api_model, response_path): self.assertNotEqual(result['response'], '') def test(self): - self._run_op('gpt-4o', 'data.response.choices.0.message.content') + # before runing this test, set below environment variables: + # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions + # export DJ_API_KEY=your_key + self._run_op('qwen2.5-72b-instruct') if __name__ == '__main__': diff --git a/tests/ops/mapper/test_calibrate_response_mapper.py b/tests/ops/mapper/test_calibrate_response_mapper.py index bcbeffbcb..e092d4c48 100644 --- a/tests/ops/mapper/test_calibrate_response_mapper.py +++ b/tests/ops/mapper/test_calibrate_response_mapper.py @@ -13,7 +13,7 @@ @SKIPPED_TESTS.register_module() class CalibrateResponseMapperTest(DataJuicerTestCaseBase): - def _run_op(self, api_model, response_path): + def _run_op(self, api_model, response_path=None): op = CalibrateResponseMapper(api_model=api_model, response_path=response_path) @@ -69,7 +69,10 @@ def _run_op(self, api_model, response_path): self.assertNotEqual(result['response'], '') def test(self): - self._run_op('gpt-4o', 'data.response.choices.0.message.content') + # before runing this test, set below environment variables: + # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions + # export DJ_API_KEY=your_key + self._run_op('qwen2.5-72b-instruct') if __name__ == '__main__': From 325a753c5d34da9faafc456506b6bed2a391bbf9 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Tue, 29 Oct 2024 06:54:41 +0000 Subject: [PATCH 007/118] update prompts --- data_juicer/ops/mapper/calibrate_qa_mapper.py | 4 ++-- data_juicer/ops/mapper/calibrate_query_mapper.py | 2 +- data_juicer/ops/mapper/calibrate_response_mapper.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index 7a3ed5836..23b167c47 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -19,9 +19,9 @@ class CalibrateQAMapper(Mapper): DEFAULT_SYSTEM_PROMPT = ('请根据提供的【参考信息】对【问题】和【回答】进行校准,使其更加详细、准确。\n' '按照以下格式输出:\n' '【问题】\n' - '优化后的问题\n' + '校准后的问题\n' '【回答】\n' - '优化后的回答') + '校准后的回答') DEFAULT_INPUT_TEMPLATE = '{reference}\n{qa_pair}' DEFAULT_REFERENCE_TEMPLATE = '【参考信息】\n{}' DEFAULT_QA_PAIR_TEMPLATE = 
'【问题】\n{}\n【回答】\n{}' diff --git a/data_juicer/ops/mapper/calibrate_query_mapper.py b/data_juicer/ops/mapper/calibrate_query_mapper.py index ca2828d70..00fa23000 100644 --- a/data_juicer/ops/mapper/calibrate_query_mapper.py +++ b/data_juicer/ops/mapper/calibrate_query_mapper.py @@ -13,7 +13,7 @@ class CalibrateQueryMapper(CalibrateQAMapper): """ DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【问题】进行校准,\ - 使其更加详细、准确,且仍可以由原答案回答。只输出优化后的问题,不要输出多余内容。' + 使其更加详细、准确,且仍可以由原答案回答。只输出校准后的问题,不要输出多余内容。' def parse_output(self, raw_output): return raw_output.strip(), None diff --git a/data_juicer/ops/mapper/calibrate_response_mapper.py b/data_juicer/ops/mapper/calibrate_response_mapper.py index 0030833a4..ef47d6a70 100644 --- a/data_juicer/ops/mapper/calibrate_response_mapper.py +++ b/data_juicer/ops/mapper/calibrate_response_mapper.py @@ -13,7 +13,7 @@ class CalibrateResponseMapper(CalibrateQAMapper): """ DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【回答】进行校准,\ - 使其更加详细、准确,且仍可以回答原问题。只输出优化后的回答,不要输出多余内容。' + 使其更加详细、准确,且仍可以回答原问题。只输出校准后的回答,不要输出多余内容。' def parse_output(self, raw_output): return None, raw_output.strip() From 4f04bdd2e2efaad824b3f75a7867efc1129f32f2 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Wed, 30 Oct 2024 03:12:52 +0000 Subject: [PATCH 008/118] fix unittest --- data_juicer/ops/mapper/calibrate_qa_mapper.py | 2 +- tests/config/test_config_funcs.py | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index 23b167c47..c3686581d 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -28,7 +28,7 @@ class CalibrateQAMapper(Mapper): DEFAULT_OUTPUT_PATTERN = r'【问题】\s*(.*?)\s*【回答】\s*(.*)' def __init__(self, - api_model: str, + api_model: str = 'gpt-4o', *, api_url: Optional[str] = None, api_key: Optional[str] = None, diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py index c024ceb0f..ff425d2d7 100644 --- a/tests/config/test_config_funcs.py +++ b/tests/config/test_config_funcs.py @@ -45,6 +45,9 @@ def test_yaml_cfg_file(self): 'image_key': 'images', 'audio_key': 'audios', 'video_key': 'videos', + 'query_key': 'query', + 'response_key': 'response', + 'history_key': 'history', 'accelerator': None, 'num_proc': 4, 'cpu_required': 1, @@ -62,6 +65,9 @@ def test_yaml_cfg_file(self): 'image_key': 'images', 'audio_key': 'audios', 'video_key': 'videos', + 'query_key': 'query', + 'response_key': 'response', + 'history_key': 'history', 'accelerator': None, 'num_proc': 4, 'stats_export_path': None, @@ -128,6 +134,9 @@ def test_mixture_cfg(self): 'image_key': 'images', 'audio_key': 'audios', 'video_key': 'videos', + 'query_key': 'query', + 'response_key': 'response', + 'history_key': 'history', 'accelerator': None, 'num_proc': 4, 'stats_export_path': None, @@ -146,6 +155,9 @@ def test_mixture_cfg(self): 'image_key': 'images', 'audio_key': 'audios', 'video_key': 'videos', + 'query_key': 'query', + 'response_key': 'response', + 'history_key': 'history', 'accelerator': None, 'num_proc': 4, 'stats_export_path': None, @@ -164,6 +176,9 @@ def test_mixture_cfg(self): 'image_key': 'images', 'audio_key': 'audios', 'video_key': 'videos', + 'query_key': 'query', + 'response_key': 'response', + 'history_key': 'history', 'accelerator': None, 'num_proc': 4, 'stats_export_path': None, @@ -182,6 +197,9 @@ def test_mixture_cfg(self): 'image_key': 'images', 'audio_key': 'audios', 'video_key': 'videos', + 'query_key': 
'query', + 'response_key': 'response', + 'history_key': 'history', 'accelerator': None, 'num_proc': 4, 'stats_export_path': None, @@ -200,6 +218,9 @@ def test_mixture_cfg(self): 'image_key': 'images', 'audio_key': 'audios', 'video_key': 'videos', + 'query_key': 'query', + 'response_key': 'response', + 'history_key': 'history', 'accelerator': None, 'num_proc': 4, 'stats_export_path': None, From 0adbdcda15fd9e644bebf5f163ff16e7c883d888 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Wed, 30 Oct 2024 03:21:24 +0000 Subject: [PATCH 009/118] update tests --- tests/ops/mapper/test_calibrate_qa_mapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ops/mapper/test_calibrate_qa_mapper.py b/tests/ops/mapper/test_calibrate_qa_mapper.py index 7c31b56fc..0f42b8b84 100644 --- a/tests/ops/mapper/test_calibrate_qa_mapper.py +++ b/tests/ops/mapper/test_calibrate_qa_mapper.py @@ -68,6 +68,7 @@ def _run_op(self, api_model, response_path=None): self.assertNotEqual(result['response'], '') def test(self): + # before runing this test, set below environment variables: # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions # export DJ_API_KEY=your_key self._run_op('qwen2.5-72b-instruct') From 0aa406956fd47075bbacbf8ad0250212b2472e31 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Fri, 1 Nov 2024 03:38:40 +0000 Subject: [PATCH 010/118] add docs --- configs/config_all.yaml | 12 ++++++++++++ data_juicer/ops/mapper/calibrate_qa_mapper.py | 4 ++-- data_juicer/ops/mapper/calibrate_query_mapper.py | 2 +- data_juicer/ops/mapper/calibrate_response_mapper.py | 4 ++-- docs/Operators.md | 7 +++++-- docs/Operators_ZH.md | 11 +++++++---- 6 files changed, 29 insertions(+), 11 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 5b817c4d0..e530c4afc 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -53,6 +53,18 @@ hpo_config: null # path to a configur process: # Mapper ops. Most of these ops need no arguments. - audio_ffmpeg_wrapped_mapper: # simple wrapper for FFmpeg audio filters + - calibrate_qa_mapper: # calibrate question-answer pairs based on reference text. + api_url: # API URL. Defaults to DJ_API_URL environment variable. + api_key: # API key. Defaults to DJ_API_KEY environment variable. + response_path: # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: # System prompt for the calibration task. + input_template: # Template for building the model input. + reference_template: # Template for formatting the reference text. + qa_pair_template: # Template for formatting question-answer pairs. + output_pattern: # Regular expression for parsing model output. + api_params: # Extra parameters passed to the API call. + - calibrate_query_mapper: # calibrate query in question-answer pairs based on reference text. + - calibrate_response_mapper: # calibrate response in question-answer pairs based on reference text. - chinese_convert_mapper: # convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji. mode: 's2t' # choose the mode to convert Chinese: ['s2t', 't2s', 's2tw', 'tw2s', 's2hk', 'hk2s', 's2twp', 'tw2sp', 't2tw', 'tw2t', 'hk2t', 't2hk', 't2jp', 'jp2t'] - clean_email_mapper: # remove emails from text. 
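For reference, a minimal usage sketch of the op documented above (illustrative
only; it assumes the DJ_API_URL/DJ_API_KEY environment variables are exported,
as in the unit tests):

    from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper

    op = CalibrateQAMapper(api_model='qwen2.5-72b-instruct')
    sample = {
        'text': '<reference text>',  # illustrative placeholder
        'query': '你还喜欢乔婉娩吗?',
        'response': '不喜欢。',
    }
    # Fills sample['query']/sample['response'] with the calibrated text.
    result = op.process(sample)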
diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index c3686581d..dcbeb654a 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -12,7 +12,7 @@ @OPERATORS.register_module(OP_NAME) class CalibrateQAMapper(Mapper): """ - Mapper to calibrate question-answer pairs. + Mapper to calibrate question-answer pairs based on reference text. """ # avoid leading whitespace @@ -53,7 +53,7 @@ def __init__(self, :param reference_template: Template for formatting the reference text. :param qa_pair_template: Template for formatting question-answer pairs. :param output_pattern: Regular expression for parsing model output. - :param api_params: Extra API parameters. + :param api_params: Extra parameters passed to the API call. :param kwargs: Extra keyword arguments. """ super().__init__(**kwargs) diff --git a/data_juicer/ops/mapper/calibrate_query_mapper.py b/data_juicer/ops/mapper/calibrate_query_mapper.py index 00fa23000..88098d7f8 100644 --- a/data_juicer/ops/mapper/calibrate_query_mapper.py +++ b/data_juicer/ops/mapper/calibrate_query_mapper.py @@ -9,7 +9,7 @@ @OPERATORS.register_module(OP_NAME) class CalibrateQueryMapper(CalibrateQAMapper): """ - Mapper to calibrate query in question-answer pairs. + Mapper to calibrate query in question-answer pairs based on reference text. """ DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【问题】进行校准,\ diff --git a/data_juicer/ops/mapper/calibrate_response_mapper.py b/data_juicer/ops/mapper/calibrate_response_mapper.py index ef47d6a70..db56af317 100644 --- a/data_juicer/ops/mapper/calibrate_response_mapper.py +++ b/data_juicer/ops/mapper/calibrate_response_mapper.py @@ -9,8 +9,8 @@ @OPERATORS.register_module(OP_NAME) class CalibrateResponseMapper(CalibrateQAMapper): """ - Mapper to calibrate response in question-answer pairs. - """ + Mapper to calibrate response in question-answer pairs based on reference text. + """ # noqa: E501 DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【回答】进行校准,\ 使其更加详细、准确,且仍可以回答原问题。只输出校准后的回答,不要输出多余内容。' diff --git a/docs/Operators.md b/docs/Operators.md index f84cc25a0..3f74e4499 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. 
| Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 47 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 50 | Edits and transforms samples | | [ Filter ]( #filter ) | 43 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -51,6 +51,9 @@ All the specific operators are listed below, each featured with several capabili | Operator | Domain | Lang | Description | |-----------------------------------------------------|--------------------|--------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | audio_ffmpeg_wrapped_mapper | Audio | - | Simple wrapper to run a FFmpeg audio filter | +| calibrate_qa_mapper | General | en, zh | Calibrate question-answer pairs based on reference text | +| calibrate_query_mapper | General | en, zh | Calibrate query in question-answer pairs based on reference text | +| calibrate_response_mapper | General | en, zh | Calibrate response in question-answer pairs based on reference text | | chinese_convert_mapper | General | zh | Converts Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) | | clean_copyright_mapper | Code | en, zh | Removes copyright notice at the beginning of code files (must contain the word *copyright*) | | clean_email_mapper | General | en, zh | Removes email information | @@ -66,7 +69,7 @@ All the specific operators are listed below, each featured with several capabili | image_captioning_mapper | Multimodal | - | generate samples whose captions are generated based on another model (such as blip2) and the figure within the original sample | | image_diffusion_mapper | Multimodal | - | Generate and augment images by stable diffusion model | | image_face_blur_mapper | Image | - | Blur faces detected in images | -| image_tagging_mapper | Multimodal | - | Mapper to generate image tags from the input images. | +| image_tagging_mapper | Multimodal | - | Mapper to generate image tags from the input images. | | nlpaug_en_mapper | General | en | Simply augments texts in English based on the `nlpaug` library | | nlpcda_zh_mapper | General | zh | Simply augments texts in Chinese based on the `nlpcda` library | | optimize_instruction_mapper | General | en, zh | Optimize instruction text samples. 
| diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 5a7c8ddda..c8c83e5b3 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 47 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 50 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 43 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -47,9 +47,12 @@ Data-Juicer 中的算子分为以下 5 种类型。 ## Mapper -| 算子 | 场景 | 语言 | 描述 | +| 算子 | 场景 | 语言 | 描述 | |----------------------------------------------------|-----------------------|-----------|------------------------------------------------------------------------| | audio_ffmpeg_wrapped_mapper | Audio | - | 运行 FFmpeg 语音过滤器的简单封装 | +| calibrate_qa_mapper | General | en, zh | 根据参考文本校准问答对 | +| calibrate_query_mapper | General | en, zh | 根据参考文本校准问答对中的问题 | +| calibrate_response_mapper | General | en, zh | 根据参考文本校准问答对中的回答 | | chinese_convert_mapper | General | zh | 用于在繁体中文、简体中文和日文汉字之间进行转换(借助 [opencc](https://github.com/BYVoid/OpenCC)) | | clean_copyright_mapper | Code | en, zh | 删除代码文件开头的版权声明 (必须包含单词 *copyright*) | | clean_email_mapper | General | en, zh | 删除邮箱信息 | @@ -62,10 +65,10 @@ Data-Juicer 中的算子分为以下 5 种类型。 | generate_instruction_mapper | General | en, zh | 指令扩充,根据种子数据,生成新的样本。 | | image_blur_mapper | Image | - | 对图像进行模糊处理 | | image_captioning_from_gpt4v_mapper | Multimodal | - | 基于gpt-4-vision和图像生成文本 | -| image_captioning_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(例如 blip2)和原始样本中的图形生成的。 | +| image_captioning_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(例如 blip2)和原始样本中的图形生成的。 | | image_diffusion_mapper | Multimodal | - | 用stable diffusion生成图像,对图像进行增强 | | image_face_blur_mapper | Image | - | 对图像中的人脸进行模糊处理 | -| image_tagging_mapper | Multimodal | - | 从输入图片中生成图片标签 | +| image_tagging_mapper | Multimodal | - | 从输入图片中生成图片标签 | | nlpaug_en_mapper | General | en | 使用`nlpaug`库对英语文本进行简单增强 | | nlpcda_zh_mapper | General | zh | 使用`nlpcda`库对中文文本进行简单增强 | | optimize_instruction_mapper | General | en, zh | 指令优化,优化prompt。 | From f0075321bee487f00eba38df995e821dff57581d Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Fri, 1 Nov 2024 09:00:15 +0000 Subject: [PATCH 011/118] minor fix --- configs/config_all.yaml | 1 + data_juicer/utils/model_utils.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index e530c4afc..3bdd9f680 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -54,6 +54,7 @@ process: # Mapper ops. Most of these ops need no arguments. - audio_ffmpeg_wrapped_mapper: # simple wrapper for FFmpeg audio filters - calibrate_qa_mapper: # calibrate question-answer pairs based on reference text. + api_model: # API model name. api_url: # API URL. Defaults to DJ_API_URL environment variable. api_key: # API key. Defaults to DJ_API_KEY environment variable. response_path: # Path to extract content from the API response. Defaults to 'choices.0.message.content'. 
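
The `response_path` setting above is a dot-separated lookup into the JSON body returned by the API. A minimal sketch of how such a path resolves, assuming an OpenAI-style response layout (the dict and helper below are fabricated for illustration, not part of the patch):

    # Numeric segments index into lists; all other segments index into dicts.
    def follow_path(data, path):
        for key in path.split('.'):
            data = data[int(key) if key.isdigit() else key]
        return data

    response = {'choices': [{'message': {'content': 'calibrated answer'}}]}
    assert follow_path(response, 'choices.0.message.content') == 'calibrated answer'

So the default path picks the text of the first choice, and a proxy server that nests its payload differently only needs a different path string rather than custom parsing code.
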
diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 479c16eef..98db05927 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -160,7 +160,7 @@ def __call__(self, messages, **kwargs): return self.nested_access(result, self.response_path) except Exception as e: logger.exception(e) - return None + return '' @staticmethod def nested_access(data, path): From ee4f4619b19570cf42c676858c6fd393e8e276e3 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Tue, 5 Nov 2024 07:57:35 +0000 Subject: [PATCH 012/118] add API processor --- data_juicer/utils/model_utils.py | 62 ++++++++++++++++++++++++++------ 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 354b3b3a8..5d97a3c06 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -177,11 +177,13 @@ def nested_access(data, path): return data -def prepare_api_model(*, - api_model, +def prepare_api_model(api_model, + *, api_url=None, api_key=None, - response_path=None): + response_path=None, + return_processor=False, + processor_name=None): """Creates a callable API model for interacting with OpenAI-compatible API. This callable object supports custom result parsing and is suitable for use @@ -192,14 +194,54 @@ def prepare_api_model(*, :param api_key: The API key for authorization. If not provided, it will fallback to the environment variables DJ_API_KEY or OPENAI_API_KEY. :param response_path: The path to extract content from the API response. - Defaults to 'choices.0.message.content'. - :return: A callable API model object that can be used to send messages and - receive responses. + Defaults to 'choices.0.message.content'. This can be customized + based on the API's response structure. + :param return_processor: A boolean flag indicating whether to return a + processor along with the model. The processor is used for tasks like + tokenization or encoding. Defaults to False. + :param processor_name: The name of a specific processor from Hugging Face + to be used. This is only necessary if a custom processor is required. + :return: A tuple containing the callable API model object and optionally a + processor if `return_processor` is True. """ - return APIModel(api_model=api_model, - api_url=api_url, - api_key=api_key, - response_path=response_path) + model = APIModel(api_model=api_model, + api_url=api_url, + api_key=api_key, + response_path=response_path) + if not return_processor: + return model + + def get_processor(): + try: + import tiktoken + return tiktoken.encoding_for_model(api_model) + except Exception: + pass + + try: + import dashscope + return dashscope.get_tokenizer(api_model) + except Exception: + pass + + try: + return transformers.AutoProcessor.from_pretrained( + api_model, trust_remote_code=True) + except Exception: + raise ValueError( + 'Failed to initialize the processor. Please check the following:\n' # noqa: E501 + "- For OpenAI models: Install 'tiktoken' via `pip install tiktoken`.\n" # noqa: E501 + "- For DashScope models: Install 'dashscope' via `pip install dashscope`.\n" # noqa: E501 + "- For custom models: Provide a valid Hugging Face name in 'processor_name'.\n" # noqa: E501 + 'If the issue persists, verify the passed `api_model` parameter.' 
# noqa: E501 + ) + + if processor_name is not None: + processor = transformers.AutoProcessor.from_pretrained( + processor_name, trust_remote_code=True) + else: + processor = get_processor() + return (model, processor) def prepare_diffusion_model(pretrained_model_name_or_path, diffusion_type, From b00b18243836d65d3ccedccfa15c05f80318bcd8 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Tue, 5 Nov 2024 08:30:09 +0000 Subject: [PATCH 013/118] refine API processor --- data_juicer/utils/model_utils.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 5d97a3c06..3a33ac8b4 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -225,20 +225,17 @@ def get_processor(): pass try: - return transformers.AutoProcessor.from_pretrained( - api_model, trust_remote_code=True) + return transformers.AutoProcessor.from_pretrained(api_model) except Exception: raise ValueError( 'Failed to initialize the processor. Please check the following:\n' # noqa: E501 "- For OpenAI models: Install 'tiktoken' via `pip install tiktoken`.\n" # noqa: E501 - "- For DashScope models: Install 'dashscope' via `pip install dashscope`.\n" # noqa: E501 - "- For custom models: Provide a valid Hugging Face name in 'processor_name'.\n" # noqa: E501 - 'If the issue persists, verify the passed `api_model` parameter.' # noqa: E501 - ) + "- For DashScope models: Install both 'dashscope' and 'tiktoken' via `pip install dashscope tiktoken`.\n" # noqa: E501 + "- For custom models: Provide a valid Hugging Face name via the 'processor_name' parameter.\n" # noqa: E501 + 'If the issue persists, check the provided `api_model`.') if processor_name is not None: - processor = transformers.AutoProcessor.from_pretrained( - processor_name, trust_remote_code=True) + processor = transformers.AutoProcessor.from_pretrained(processor_name) else: processor = get_processor() return (model, processor) From b718de714f2392a79f8fe5ebd5af9564b963d66f Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Tue, 5 Nov 2024 08:48:09 +0000 Subject: [PATCH 014/118] refine --- configs/config_all.yaml | 3 ++- data_juicer/ops/mapper/calibrate_qa_mapper.py | 14 +++++++++----- data_juicer/utils/model_utils.py | 12 +++++++++--- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index dbbe462be..48c9f172e 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -63,7 +63,8 @@ process: reference_template: null # Template for formatting the reference text. qa_pair_template: null # Template for formatting question-answer pairs. output_pattern: null # Regular expression for parsing model output. - api_params: null # Extra parameters in dict passed to the API call. + model_params: null # Parameters for initializing the model. + sampling_params: null # Extra parameters passed to the API call. - calibrate_query_mapper: # calibrate query in question-answer pairs based on reference text. - calibrate_response_mapper: # calibrate response in question-answer pairs based on reference text. - chinese_convert_mapper: # convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji. 
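
The config split above separates construction-time arguments (`model_params`) from per-request arguments (`sampling_params`). A minimal sketch of the intended call flow, assuming `tiktoken` is installed and DJ_API_URL/DJ_API_KEY are exported; the model name is only an example:

    from data_juicer.utils.model_utils import get_model, prepare_model

    # With return_processor=True, a tokenizer-like object is also returned:
    # a tiktoken encoding for OpenAI models, a DashScope tokenizer, or a
    # Hugging Face processor as fallbacks.
    model_key = prepare_model(model_type='api',
                              api_model='gpt-4o',
                              return_processor=True)
    client, processor = get_model(model_key)

    prompt = 'Calibrate the answer against the reference text.'
    print(len(processor.encode(prompt)))  # budget check before the request
    print(client([{'role': 'user', 'content': prompt}], temperature=0.9))
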
diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py
index dcbeb654a..a13c6dce4 100644
--- a/data_juicer/ops/mapper/calibrate_qa_mapper.py
+++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py
@@ -38,7 +38,8 @@ def __init__(self,
                  reference_template: Optional[str] = None,
                  qa_pair_template: Optional[str] = None,
                  output_pattern: Optional[str] = None,
-                 api_params: Optional[Dict] = None,
+                 model_params: Optional[Dict] = None,
+                 sampling_params: Optional[Dict] = None,
                  **kwargs):
         """
         Initialization method.
@@ -53,7 +54,8 @@ def __init__(self,
         :param reference_template: Template for formatting the reference text.
         :param qa_pair_template: Template for formatting question-answer pairs.
         :param output_pattern: Regular expression for parsing model output.
-        :param api_params: Extra parameters passed to the API call.
+        :param model_params: Parameters for initializing the model.
+        :param sampling_params: Extra parameters passed to the API call.
         :param kwargs: Extra keyword arguments.
         """
         super().__init__(**kwargs)
@@ -66,12 +68,14 @@ def __init__(self,
             self.DEFAULT_QA_PAIR_TEMPLATE
         self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
 
-        self.api_params = api_params or {}
+        self.model_params = model_params or {}
+        self.sampling_params = sampling_params or {}
         self.model_key = prepare_model(model_type='api',
                                        api_model=api_model,
                                        api_url=api_url,
                                        api_key=api_key,
-                                       response_path=response_path)
+                                       response_path=response_path,
+                                       **self.model_params)
 
     def build_input(self, sample):
         reference = self.reference_template.format(sample[self.text_key])
@@ -98,7 +102,7 @@ def process_single(self, sample=None, rank=None):
             'role': 'user',
             'content': self.build_input(sample)
         }]
-        output = client(messages, **self.api_params)
+        output = client(messages, **self.sampling_params)
 
         parsed_q, parsed_a = self.parse_output(output)
         if parsed_q:
diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py
index 3a33ac8b4..80c69b823 100644
--- a/data_juicer/utils/model_utils.py
+++ b/data_juicer/utils/model_utils.py
@@ -183,7 +183,8 @@ def prepare_api_model(api_model,
                       api_key=None,
                       response_path=None,
                       return_processor=False,
-                      processor_name=None):
+                      processor_name=None,
+                      **model_params):
     """Creates a callable API model for interacting with OpenAI-compatible API.
 
     This callable object supports custom result parsing and is suitable for use
@@ -201,9 +202,12 @@ def prepare_api_model(api_model,
         tokenization or encoding. Defaults to False.
     :param processor_name: The name of a specific processor from Hugging Face
         to be used. This is only necessary if a custom processor is required.
+    :param model_params: Extra parameters to be passed to the processor.
     :return: A tuple containing the callable API model object and optionally a
         processor if `return_processor` is True.
     """
+    model_params = model_params or {}
+
     model = APIModel(api_model=api_model,
                      api_url=api_url,
                      api_key=api_key,
@@ -225,7 +229,8 @@ def get_processor():
             pass
 
         try:
-            return transformers.AutoProcessor.from_pretrained(api_model)
+            return transformers.AutoProcessor.from_pretrained(
+                api_model, **model_params)
         except Exception:
             raise ValueError(
                 'Failed to initialize the processor. 
Please check the following:\n' # noqa: E501 @@ -235,7 +240,8 @@ def get_processor(): 'If the issue persists, check the provided `api_model`.') if processor_name is not None: - processor = transformers.AutoProcessor.from_pretrained(processor_name) + processor = transformers.AutoProcessor.from_pretrained( + processor_name, **model_params) else: processor = get_processor() return (model, processor) From 6d1d433f9db416760f0e3062bdab82fc4688c90d Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 6 Nov 2024 10:15:46 +0800 Subject: [PATCH 015/118] chunk and extract events --- .coveragerc | 11 + data_juicer/ops/mapper/__init__.py | 4 + .../ops/mapper/extract_event_mapper.py | 133 +++++++++ .../ops/mapper/remove_table_text_mapper.py | 4 +- data_juicer/ops/mapper/text_chunk_mapper.py | 134 +++++++++ data_juicer/utils/model_utils.py | 3 +- data_juicer/utils/unittest_utils.py | 15 +- environments/dev_requires.txt | 1 + tests/core/test_monitor.py | 2 + .../filter/test_image_aesthetics_filter.py | 1 - .../filter/test_image_face_count_filter.py | 1 - .../filter/test_image_face_ratio_filter.py | 1 - tests/ops/filter/test_image_nsfw_filter.py | 1 - .../filter/test_image_text_matching_filter.py | 1 - .../ops/filter/test_image_watermark_filter.py | 1 - .../test_phrase_grounding_recall_filter.py | 1 - .../filter/test_video_aesthetics_filter.py | 1 - tests/ops/filter/test_video_nsfw_filter.py | 1 - .../test_video_ocr_area_ratio_filter.py | 1 - .../test_video_tagging_from_frames_filter.py | 1 - .../ops/filter/test_video_watermark_filter.py | 1 - tests/ops/mapper/test_extract_event_mapper.py | 66 +++++ tests/ops/mapper/test_extract_qa_mapper.py | 2 +- .../test_generate_instruction_mapper.py | 2 +- .../mapper/test_image_captioning_mapper.py | 2 +- .../ops/mapper/test_image_diffusion_mapper.py | 2 +- tests/ops/mapper/test_image_tagging_mapper.py | 2 + tests/ops/mapper/test_nlpcda_zh_mapper.py | 5 +- .../test_optimize_instruction_mapper.py | 2 +- tests/ops/mapper/test_text_chunk_mapper.py | 268 ++++++++++++++++++ ...test_video_captioning_from_audio_mapper.py | 2 +- ...est_video_captioning_from_frames_mapper.py | 2 +- ...video_captioning_from_summarizer_mapper.py | 2 +- ...test_video_captioning_from_video_mapper.py | 2 +- .../test_video_remove_watermark_mapper.py | 1 - .../test_video_split_by_scene_mapper.py | 1 - .../test_video_tagging_from_frames_mapper.py | 2 + tests/ops/test_op_fusion.py | 1 - tests/run.py | 17 +- 39 files changed, 670 insertions(+), 30 deletions(-) create mode 100644 .coveragerc create mode 100644 data_juicer/ops/mapper/extract_event_mapper.py create mode 100644 data_juicer/ops/mapper/text_chunk_mapper.py create mode 100644 tests/ops/mapper/test_extract_event_mapper.py create mode 100644 tests/ops/mapper/test_text_chunk_mapper.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..d4a7a6d63 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,11 @@ +[run] +omit = + # avoid measuring strange non-existing files + /workspace/config.py + /workspace/config-3.py + + # avoid measuring third-party dist packages + */dist-packages/* + + # avoid measuring code of unittest + tests/* diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 5a31f9d79..8ac4394c7 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -6,6 +6,7 @@ from .clean_ip_mapper import CleanIpMapper from .clean_links_mapper import CleanLinksMapper from .expand_macro_mapper import ExpandMacroMapper +from .extract_event_mapper import 
ExtractEventMapper
 from .extract_qa_mapper import ExtractQAMapper
 from .fix_unicode_mapper import FixUnicodeMapper
 from .generate_instruction_mapper import GenerateInstructionMapper
@@ -32,6 +33,7 @@
     RemoveWordsWithIncorrectSubstringsMapper
 from .replace_content_mapper import ReplaceContentMapper
 from .sentence_split_mapper import SentenceSplitMapper
+from .text_chunk_mapper import TextChunkMapper
 from .video_captioning_from_audio_mapper import VideoCaptioningFromAudioMapper
 from .video_captioning_from_frames_mapper import \
     VideoCaptioningFromFramesMapper
@@ -98,4 +100,6 @@
     'VideoSplitByDurationMapper',
     'VideoFaceBlurMapper',
     'ImageTaggingMapper',
+    'TextChunkMapper',
+    'ExtractEventMapper',
 ]
diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py
new file mode 100644
index 000000000..1b376102a
--- /dev/null
+++ b/data_juicer/ops/mapper/extract_event_mapper.py
@@ -0,0 +1,133 @@
+import json
+import re
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+OP_NAME = 'extract_event_mapper'
+
+
+# TODO: LLM-based inference.
+@UNFORKABLE.register_module(OP_NAME)
+@OPERATORS.register_module(OP_NAME)
+class ExtractEventMapper(Mapper):
+    """
+    Extract events and relevant characters in the text.
+    """
+
+    DEFAULT_SYSTEM_PROMPT = ('给定一段文本,对文本的情节进行分点总结,并抽取与情节相关的人物。\n'
+                             '要求:\n'
+                             '- 尽量不要遗漏内容,不要添加文本中没有的情节,符合原文事实\n'
+                             '- 联系上下文说明前因后果,但仍然需要符合事实\n'
+                             '- 不要包含主观看法\n'
+                             '- 注意要尽可能保留文本的专有名词\n'
+                             '- 注意相关人物需要在对应情节中出现\n'
+                             '- 只抽取情节中的主要人物,不要遗漏情节的主要人物\n'
+                             '- 总结格式如下:\n'
+                             '### 情节1:\n'
+                             '- **情节描述**: ...\n'
+                             '- **相关人物**:人物1,人物2,人物3,...\n'
+                             '### 情节2:\n'
+                             '- **情节描述**: ...\n'
+                             '- **相关人物**:人物1,人物2,...\n'
+                             '### 情节3:\n'
+                             '- **情节描述**: ...\n'
+                             '- **相关人物**:人物1,...\n'
+                             '...\n')
+    DEFAULT_INPUT_TEMPLATE = '文本:{text}\n'
+    DEFAULT_OUTPUT_PATTERN = r"""
+        \#\#\#\s*情节(\d+):\s*
+        -\s*\*\*情节描述\*\*\s*:\s*(.*?)\s*
+        -\s*\*\*相关人物\*\*\s*:\s*(.*?)(?=\#\#\#|\Z)
+    """
+
+    def __init__(self,
+                 api_model: str = 'gpt-4o',
+                 *,
+                 api_url: Optional[str] = None,
+                 api_key: Optional[str] = None,
+                 response_path: Optional[str] = None,
+                 system_prompt: Optional[str] = None,
+                 input_template: Optional[str] = None,
+                 output_pattern: Optional[str] = None,
+                 try_num: PositiveInt = 3,
+                 api_params: Optional[Dict] = None,
+                 **kwargs):
+        """
+        Initialization method.
+        :param api_model: API model name.
+        :param api_url: API URL. Defaults to DJ_API_URL environment variable.
+        :param api_key: API key. Defaults to DJ_API_KEY environment variable.
+        :param response_path: Path to extract content from the API response.
+            Defaults to 'choices.0.message.content'.
+        :param system_prompt: System prompt for the event extraction task.
+        :param input_template: Template for building the model input.
+        :param output_pattern: Regular expression for parsing model output.
+        :param try_num: The number of retry attempts when there is an API
+            call error or output parsing error.
+        :param api_params: Extra parameters passed to the API call.
+            e.g. {'temperature': 0.9, 'top_p': 0.95}
+        :param kwargs: Extra keyword arguments.
+        """
+        super().__init__(**kwargs)
+
+        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+        self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+        self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
+
+        self.api_params = api_params or {}
+        self.model_key = prepare_model(model_type='api',
+                                       api_model=api_model,
+                                       api_url=api_url,
+                                       api_key=api_key,
+                                       response_path=response_path)
+
+        self.try_num = try_num
+
+    def build_input(self, sample):
+        input_prompt = self.input_template.format(text=sample[self.text_key])
+        return input_prompt
+
+    def parse_output(self, raw_output):
+        pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
+        matches = pattern.findall(raw_output)
+
+        contents = []
+        for match in matches:
+            _, description, characters = match
+            description = description.strip()
+            contents.append({
+                'event_description': description,
+                'relevant_characters': characters
+            })
+
+        return contents
+
+    def process_single(self, sample=None, rank=None):
+        client = get_model(self.model_key, rank=rank)
+
+        messages = [{
+            'role': 'system',
+            'content': self.system_prompt
+        }, {
+            'role': 'user',
+            'content': self.build_input(sample)
+        }]
+
+        contents = []
+        for i in range(self.try_num):
+            try:
+                output = client(messages, **self.api_params)
+                contents = self.parse_output(output)
+                if len(contents) > 0:
+                    break
+            except Exception as e:
+                logger.warning(f'Exception: {e}')
+
+        sample[self.response_key] = json.dumps(contents)
+
+        return sample
diff --git a/data_juicer/ops/mapper/remove_table_text_mapper.py b/data_juicer/ops/mapper/remove_table_text_mapper.py
index ff2b07a4f..87b06e5d1 100644
--- a/data_juicer/ops/mapper/remove_table_text_mapper.py
+++ b/data_juicer/ops/mapper/remove_table_text_mapper.py
@@ -36,8 +36,8 @@ def __init__(self,
 
     def process_batched(self, samples):
         for idx, text in enumerate(samples[self.text_key]):
-            for idx in range(self.min_col - 1, self.max_col):
-                pattern = re.compile(self.pattern % idx)
+            for i in range(self.min_col - 1, self.max_col):
+                pattern = re.compile(self.pattern % i)
                 text = pattern.sub('', text)
 
             samples[self.text_key][idx] = text
diff --git a/data_juicer/ops/mapper/text_chunk_mapper.py b/data_juicer/ops/mapper/text_chunk_mapper.py
new file mode 100644
index 000000000..a659da7d9
--- /dev/null
+++ b/data_juicer/ops/mapper/text_chunk_mapper.py
@@ -0,0 +1,134 @@
+import re
+from itertools import chain
+from typing import Union
+
+from pydantic import NonNegativeInt, PositiveInt
+
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from ..base_op import OPERATORS, Mapper
+
+OP_NAME = 'text_chunk_mapper'
+
+
+@OPERATORS.register_module(OP_NAME)
+class TextChunkMapper(Mapper):
+    """Split input text into chunks."""
+
+    _batched_op = True
+
+    def __init__(self,
+                 max_len: Union[PositiveInt, None] = None,
+                 split_pattern: Union[str, None] = r'\n\n',
+                 overlap_len: NonNegativeInt = 0,
+                 hf_tokenizer: Union[str, None] = None,
+                 trust_remote_code: bool = False,
+                 *args,
+                 **kwargs):
+        """
+        Initialization method.
+
+        :param max_len: Split text into multiple chunks with at most this
+            length if it is not None.
+        :param split_pattern: Prefer to split at this pattern if it is not
+            None, and force a cut if the length still exceeds max_len.
+        :param overlap_len: Overlap length between adjacent chunks when the
+            text is not split at the split pattern.
+        :param hf_tokenizer: The tokenizer name of Hugging Face tokenizers.
+            The text length will be calculated as the token number if it is
+            offered. Otherwise, the text length equals the string length.
+        :param trust_remote_code: whether to trust remote code of the
+            Hugging Face tokenizer.
+        :param args: extra args
+        :param kwargs: extra args
+        """
+        super().__init__(*args, **kwargs)
+
+        if max_len is None and split_pattern is None:
+            raise ValueError('max_len and split_pattern cannot be both None')
+
+        if max_len is not None and overlap_len >= max_len:
+            raise ValueError('overlap_len must be less than max_len')
+
+        self.max_len = max_len
+        self.overlap_len = overlap_len
+        self.split_pattern = split_pattern
+        self.hf_tokenizer = hf_tokenizer
+        if hf_tokenizer is not None:
+            self.model_key = prepare_model(
+                model_type='huggingface',
+                pretrained_model_name_or_path=hf_tokenizer,
+                return_model=False,
+                trust_remote_code=trust_remote_code)
+
+    def recursively_chunk(self, text):
+        if self.hf_tokenizer is not None:
+            tokenizer = get_model(self.model_key)
+            tokens = tokenizer.tokenize(text)
+            tokens = [t.decode(encoding='UTF-8') for t in tokens]
+            total_len = len(tokens)
+            sub_text = ''.join(tokens[:self.max_len])
+        else:
+            total_len = len(text)
+            sub_text = text[:self.max_len]
+
+        if total_len <= self.max_len:
+            return [text]
+
+        matches = list(re.finditer(self.split_pattern, sub_text))
+        if not matches:
+            cur_text = sub_text
+            if self.hf_tokenizer is not None:
+                left_text = ''.join(tokens[self.max_len - self.overlap_len:])
+            else:
+                left_text = text[self.max_len - self.overlap_len:]
+        else:
+            last_match = matches[-1]
+            cur_text = sub_text[:last_match.start()]
+            left_text = text[last_match.end():]
+
+        return [cur_text] + self.recursively_chunk(left_text)
+
+    def get_text_chunks(self, text):
+
+        if self.split_pattern is not None and self.max_len is None:
+            chunks = re.split(f'({self.split_pattern})', text)
+            chunks = [t for t in chunks if t.strip()]
+        elif self.split_pattern is None and self.max_len is not None:
+            tokens = text
+            total_len = len(text)
+            if self.hf_tokenizer is not None:
+                tokenizer = get_model(self.model_key)
+                tokens = tokenizer.tokenize(text)
+                tokens = [t.decode(encoding='UTF-8') for t in tokens]
+                total_len = len(tokens)
+            if total_len <= self.max_len:
+                return [text]
+            chunks = []
+            for start in range(0, total_len, self.max_len - self.overlap_len):
+                cur = tokens[start:start + self.max_len]
+                if self.hf_tokenizer is not None:
+                    cur = ''.join(cur)
+                chunks.append(cur)
+        else:
+            chunks = self.recursively_chunk(text)
+
+        return chunks
+
+    def process_batched(self, samples):
+
+        sample_num = len(samples[self.text_key])
+
+        samples[self.text_key] = [
+            self.get_text_chunks(text) for text in samples[self.text_key]
+        ]
+
+        for key in samples:
+            if key != self.text_key:
+                samples[key] = [[samples[key][i]] *
+                                len(samples[self.text_key][i])
+                                for i in range(sample_num)]
+
+        for key in samples:
+            samples[key] = list(chain(*samples[key]))
+
+        return samples
diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py
index cda046b81..dc593bc7f 100644
--- a/data_juicer/utils/model_utils.py
+++ b/data_juicer/utils/model_utils.py
@@ -612,7 +612,8 @@ def prepare_model(model_type, **model_kwargs):
     model_func = MODEL_FUNCTION_MAPPING[model_type]
     model_key = partial(model_func, **model_kwargs)
     # always instantiate once for possible caching
-    model_key()
+    if model_type != 'vllm':
+        model_key()
     return model_key
 
 
diff --git a/data_juicer/utils/unittest_utils.py b/data_juicer/utils/unittest_utils.py
index 81033b224..1e66c55cc 100644
--- a/data_juicer/utils/unittest_utils.py
+++ b/data_juicer/utils/unittest_utils.py
@@ -16,6 +16,8 @@
 
 SKIPPED_TESTS = Registry('SkippedTests')
 
+CLEAR_MODEL = False
+
 
 def 
TEST_TAG(*tags): """Tags for test case. @@ -29,6 +31,15 @@ def decorator(func): return decorator +def set_clear_model_flag(flag): + global CLEAR_MODEL + CLEAR_MODEL = flag + if CLEAR_MODEL: + print('CLEAR DOWNLOADED MODELS AFTER UNITTESTS.') + else: + print('KEEP DOWNLOADED MODELS AFTER UNITTESTS.') + + class DataJuicerTestCaseBase(unittest.TestCase): @classmethod @@ -48,7 +59,9 @@ def tearDownClass(cls, hf_model_name=None) -> None: multiprocess.set_start_method(cls.original_mp_method, force=True) # clean the huggingface model cache files - if hf_model_name: + if not CLEAR_MODEL: + pass + elif hf_model_name: # given the hf model name, remove this model only model_dir = os.path.join( transformers.TRANSFORMERS_CACHE, diff --git a/environments/dev_requires.txt b/environments/dev_requires.txt index ff091a304..9793d5746 100644 --- a/environments/dev_requires.txt +++ b/environments/dev_requires.txt @@ -1,3 +1,4 @@ +coverage pre-commit sphinx sphinx-autobuild diff --git a/tests/core/test_monitor.py b/tests/core/test_monitor.py index 01840348d..3f7a35f21 100644 --- a/tests/core/test_monitor.py +++ b/tests/core/test_monitor.py @@ -4,6 +4,8 @@ from data_juicer.core import Monitor from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +# Skip this test due to some random resource utilization fluctuation, which may +# cause failure of this test @SKIPPED_TESTS.register_module() class MonitorTest(DataJuicerTestCaseBase): diff --git a/tests/ops/filter/test_image_aesthetics_filter.py b/tests/ops/filter/test_image_aesthetics_filter.py index e20f9d2c6..3ebb8419c 100644 --- a/tests/ops/filter/test_image_aesthetics_filter.py +++ b/tests/ops/filter/test_image_aesthetics_filter.py @@ -8,7 +8,6 @@ from data_juicer.utils.constant import Fields from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class ImageAestheticsFilterTest(DataJuicerTestCaseBase): maxDiff = None diff --git a/tests/ops/filter/test_image_face_count_filter.py b/tests/ops/filter/test_image_face_count_filter.py index dd106f6bb..becb47148 100644 --- a/tests/ops/filter/test_image_face_count_filter.py +++ b/tests/ops/filter/test_image_face_count_filter.py @@ -8,7 +8,6 @@ from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class ImageFaceCountFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/filter/test_image_face_ratio_filter.py b/tests/ops/filter/test_image_face_ratio_filter.py index d82ac2ec1..69a10a42d 100644 --- a/tests/ops/filter/test_image_face_ratio_filter.py +++ b/tests/ops/filter/test_image_face_ratio_filter.py @@ -8,7 +8,6 @@ from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class ImageFaceRatioFilterTest(DataJuicerTestCaseBase): maxDiff = None diff --git a/tests/ops/filter/test_image_nsfw_filter.py b/tests/ops/filter/test_image_nsfw_filter.py index 0a588e272..87f24e94b 100644 --- a/tests/ops/filter/test_image_nsfw_filter.py +++ b/tests/ops/filter/test_image_nsfw_filter.py @@ -10,7 +10,6 @@ from data_juicer.utils.constant import Fields from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class ImageNSFWFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git 
a/tests/ops/filter/test_image_text_matching_filter.py b/tests/ops/filter/test_image_text_matching_filter.py index 0551da254..91ed938df 100644 --- a/tests/ops/filter/test_image_text_matching_filter.py +++ b/tests/ops/filter/test_image_text_matching_filter.py @@ -11,7 +11,6 @@ from data_juicer.utils.mm_utils import SpecialTokens from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class ImageTextMatchingFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/filter/test_image_watermark_filter.py b/tests/ops/filter/test_image_watermark_filter.py index 01ed2e0dc..b5e5146f8 100644 --- a/tests/ops/filter/test_image_watermark_filter.py +++ b/tests/ops/filter/test_image_watermark_filter.py @@ -10,7 +10,6 @@ from data_juicer.utils.constant import Fields from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class ImageWatermarkFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/filter/test_phrase_grounding_recall_filter.py b/tests/ops/filter/test_phrase_grounding_recall_filter.py index e865c2f22..201e71214 100644 --- a/tests/ops/filter/test_phrase_grounding_recall_filter.py +++ b/tests/ops/filter/test_phrase_grounding_recall_filter.py @@ -11,7 +11,6 @@ from data_juicer.utils.mm_utils import SpecialTokens from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class PhraseGroundingRecallFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/filter/test_video_aesthetics_filter.py b/tests/ops/filter/test_video_aesthetics_filter.py index 551d0e721..b0681ef4d 100644 --- a/tests/ops/filter/test_video_aesthetics_filter.py +++ b/tests/ops/filter/test_video_aesthetics_filter.py @@ -9,7 +9,6 @@ from data_juicer.utils.mm_utils import SpecialTokens from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class VideoAestheticsFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/filter/test_video_nsfw_filter.py b/tests/ops/filter/test_video_nsfw_filter.py index 3c713407d..8376eb7af 100644 --- a/tests/ops/filter/test_video_nsfw_filter.py +++ b/tests/ops/filter/test_video_nsfw_filter.py @@ -10,7 +10,6 @@ from data_juicer.utils.constant import Fields from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class VideoNSFWFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/filter/test_video_ocr_area_ratio_filter.py b/tests/ops/filter/test_video_ocr_area_ratio_filter.py index 9884ab1cf..b7c7e6f50 100644 --- a/tests/ops/filter/test_video_ocr_area_ratio_filter.py +++ b/tests/ops/filter/test_video_ocr_area_ratio_filter.py @@ -8,7 +8,6 @@ from data_juicer.utils.constant import Fields from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class VideoOcrAreaRatioFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/filter/test_video_tagging_from_frames_filter.py 
b/tests/ops/filter/test_video_tagging_from_frames_filter.py index c16b07d4d..545be9748 100644 --- a/tests/ops/filter/test_video_tagging_from_frames_filter.py +++ b/tests/ops/filter/test_video_tagging_from_frames_filter.py @@ -8,7 +8,6 @@ from data_juicer.utils.mm_utils import SpecialTokens from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class VideoTaggingFromFramesFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data') diff --git a/tests/ops/filter/test_video_watermark_filter.py b/tests/ops/filter/test_video_watermark_filter.py index aca75131f..629319e3f 100644 --- a/tests/ops/filter/test_video_watermark_filter.py +++ b/tests/ops/filter/test_video_watermark_filter.py @@ -10,7 +10,6 @@ from data_juicer.utils.constant import Fields from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class VideoWatermarkFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py new file mode 100644 index 000000000..4c57fb63d --- /dev/null +++ b/tests/ops/mapper/test_extract_event_mapper.py @@ -0,0 +1,66 @@ +import unittest +import json +from data_juicer.ops.mapper.extract_event_mapper import ExtractEventMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class ExtractEventMapperTest(DataJuicerTestCaseBase): + + + def _run_op(self, api_model, response_path=None): + + op = ExtractEventMapper(api_model=api_model, + response_path=response_path) + + raw_text = """△芩婆走到中间,看着众人。 +芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。 +封磬震惊:二子?不是只有一个儿子吗? +芩婆:我和漆木山得知这个噩耗后,到处寻找李家那两个孩子的下落。只可惜等我们找他们时,李家长子李相显已经病死。 +李莲花似回忆起了什么:李相显...... +芩婆:我们只从乞丐堆里带回了年纪尚且未满四岁的李相夷,以及,(看向单孤刀)二个一直护着李相夷,与李相显年纪相仿的小乞丐...... +闪回/ +李相显将李且给他的玉佩塞给单孤刀,恳切托付:我没什么值钱的东西,这个玉佩是我唯一的家当了、送给你,我弟弟、相夷......求你照顾他一阵...... +△李相显还想再说什么已气绝而亡,小相夷唤着哥哥大哭,单孤刀愕然看着手里的玉佩有点不知所措。 +△话刚说完,哐当一声破庙门倒进来,几个其他少年乞丐进来。少年乞丐老大:这地儿不错,诶,你俩,出去! +△单孤刀把小相夷护在身后,抓住靠在墙边的木棍。单孤刀:这儿,是我,和我弟弟的。 +乞丐们要抢李相夷的馒头,小李相夷哭着死死护住自馒头不放。 +乞丐甲野蛮地抢:给我拿来! +小单孤刀:放开他! +△单孤刀用力撞向几个乞丐,救下小李相夷。乞丐甲:小子,活腻了! +△几个乞丐围攻小单孤刀,小单孤刀和众乞丐厮打到一起。突然其中一个乞丐掏出一把生锈的刀就朝单孤刀砍去、一个点燃火把棍戳他。单孤刀侧手一挡,火把棍在他手腕上烫出一道伤口,身后几根棍子打得他痛苦倒地! +/闪回结束 +△单孤刀拿着自己手里的玉佩看着,又看看自己手上的印记,不肯相信。单孤刀:胡说!全都是胡说!这些事我为何不知道?都是你在信口雌黄! +芩婆:那我问你,我们将你带回云隐山之前的事你又记得多少? +△单孤刀突然愣住,他意识到那之前的事自己竟都想不起来。 +芩婆:怎么?都想不起来了?(拽起单孤刀手腕,露出他的伤痕)你当日被你师父找到时,手腕上就受了伤,也正因为这处伤,高烧不退,醒来后便忘记了不少从前的事。 +△单孤刀呆住。 +芩婆:而相夷当年不过孩童,尚未到记事的年纪,很多事自然不知道。 +△李莲花得知真相,闭目叹息。 +△封磬震惊地看看单孤刀,又看看李莲花,终于想明白了一切,颓然、懊恼。 +封磬:自萱公主之子下落不明后,这近百年来我们整个家族都一直在不遗余力地寻找萱公主的子嗣后代,直到二十几年前终于让我寻得了线索,知道萱公主的曾孙被漆木山夫妇收为徒,但......我只知道萱公主之孙有一年约十岁的儿子,却不知......原来竟还有一幼子!我......我凭着南胤皇族的玉佩、孩子的年纪和他身上的印记来与主上相认,可没想到......这竟是一个错误!全错了! +△封磬神情复杂地看向李莲花,封磬:你,你才是我的主上...... 
+△封磬颓然地跪倒下来。
+△李莲花对眼前的一切有些意外、无措。
+笛飞声冷声:怪不得单孤刀的血对业火独毫无作用,李莲花的血才能毁掉这东西。
+△笛飞声不禁冷笑一下。
+"""
+        samples = [{
+            'text': raw_text,
+        }]
+
+        for sample in samples:
+            result = op.process(sample)
+            self.assertNotEqual(result['response'], '[]')
+
+    def test(self):
+        # before running this test, set the environment variables below:
+        # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions
+        # export DJ_API_KEY=your_key
+        self._run_op('qwen2.5-72b-instruct')
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/ops/mapper/test_extract_qa_mapper.py b/tests/ops/mapper/test_extract_qa_mapper.py
index 2e1b59a78..415efad4e 100644
--- a/tests/ops/mapper/test_extract_qa_mapper.py
+++ b/tests/ops/mapper/test_extract_qa_mapper.py
@@ -4,7 +4,7 @@
 from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
                                               DataJuicerTestCaseBase)
 
-# Skip tests for this OP in the GitHub actions due to disk space limitation.
+# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
 # These tests have been tested locally.
 @SKIPPED_TESTS.register_module()
 class ExtractQAMapperTest(DataJuicerTestCaseBase):
diff --git a/tests/ops/mapper/test_generate_instruction_mapper.py b/tests/ops/mapper/test_generate_instruction_mapper.py
index 43bd31262..a250fbcc4 100644
--- a/tests/ops/mapper/test_generate_instruction_mapper.py
+++ b/tests/ops/mapper/test_generate_instruction_mapper.py
@@ -4,7 +4,7 @@
 from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
                                               DataJuicerTestCaseBase)
 
-# Skip tests for this OP in the GitHub actions due to disk space limitation.
+# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
 # These tests have been tested locally.
 @SKIPPED_TESTS.register_module()
 class GenerateInstructionMapperTest(DataJuicerTestCaseBase):
diff --git a/tests/ops/mapper/test_image_captioning_mapper.py b/tests/ops/mapper/test_image_captioning_mapper.py
index c4c3d1e3e..2a772ab20 100644
--- a/tests/ops/mapper/test_image_captioning_mapper.py
+++ b/tests/ops/mapper/test_image_captioning_mapper.py
@@ -9,7 +9,7 @@
                                               DataJuicerTestCaseBase)
 
 
-# Skip tests for this OP in the GitHub actions due to disk space limitation.
+# Skip tests for this OP in the GitHub actions due to OOM on the current runner
 # These tests have been tested locally.
 @SKIPPED_TESTS.register_module()
 class ImageCaptioningMapperTest(DataJuicerTestCaseBase):
diff --git a/tests/ops/mapper/test_image_diffusion_mapper.py b/tests/ops/mapper/test_image_diffusion_mapper.py
index ad241732f..5883a7ff7 100644
--- a/tests/ops/mapper/test_image_diffusion_mapper.py
+++ b/tests/ops/mapper/test_image_diffusion_mapper.py
@@ -10,7 +10,7 @@
                                               DataJuicerTestCaseBase)
 
 
-# Skip tests for this OP in the GitHub actions due to disk space limitation.
+# Skip tests for this OP in the GitHub actions due to OOM on the current runner
 # These tests have been tested locally.
@SKIPPED_TESTS.register_module() class ImageTaggingMapperTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/mapper/test_nlpcda_zh_mapper.py b/tests/ops/mapper/test_nlpcda_zh_mapper.py index 3624a9c35..ce21ea55d 100644 --- a/tests/ops/mapper/test_nlpcda_zh_mapper.py +++ b/tests/ops/mapper/test_nlpcda_zh_mapper.py @@ -4,9 +4,12 @@ from data_juicer.core import NestedDataset as Dataset from data_juicer.ops.mapper.nlpcda_zh_mapper import NlpcdaZhMapper -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +# Skip tests for this OP in the GitHub actions due to unknown UnicodeEncodeError +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() class NlpaugEnMapperTest(DataJuicerTestCaseBase): def setUp(self): diff --git a/tests/ops/mapper/test_optimize_instruction_mapper.py b/tests/ops/mapper/test_optimize_instruction_mapper.py index 7c7b58b4c..4b3e4562b 100644 --- a/tests/ops/mapper/test_optimize_instruction_mapper.py +++ b/tests/ops/mapper/test_optimize_instruction_mapper.py @@ -3,7 +3,7 @@ from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) -# Skip tests for this OP in the GitHub actions due to disk space limitation. +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class OptimizeInstructionMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_text_chunk_mapper.py b/tests/ops/mapper/test_text_chunk_mapper.py new file mode 100644 index 000000000..fde3387bf --- /dev/null +++ b/tests/ops/mapper/test_text_chunk_mapper.py @@ -0,0 +1,268 @@ +import unittest + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.text_chunk_mapper import TextChunkMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class TextChunkMapperTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples, target): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for d, t in zip(dataset, target): + self.assertEqual(d['text'], t['text']) + + def test_naive_text_chunk(self): + + source = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à " + }, + { + 'text': + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + op = TextChunkMapper(split_pattern='\n') + self._run_helper(op, source, target) + + def test_max_len_text_chunk(self): + source = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à " + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text': "Today is Sunday and " + }, + { + 'text': "it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT" + }, + { + 'text': + '4, plusieurs manière' + }, + { + 'text': + "s d'accéder à ces fo" + }, + { + 'text': + 'nctionnalités sont c' + }, + { + 'text': + 'onçues simultanément' + }, + { + 'text': + '.' + }, + { + 'text': '欢迎来到阿里巴巴!' 
+            },
+        ]
+        op = TextChunkMapper(max_len=20, split_pattern=None)
+        self._run_helper(op, source, target)
+
+    def test_max_len_with_overlap_text_chunk(self):
+        source = [
+            {
+                'text': "Today is Sunday and it's a happy day!"
+            },
+            {
+                'text':
+                "Sur la plateforme MT4, plusieurs manières d'accéder à "
+                'ces fonctionnalités sont conçues simultanément.'
+            },
+            {
+                'text': '欢迎来到阿里巴巴!'
+            },
+        ]
+        target = [
+            {
+                'text': "Today is Sunday and "
+            },
+            {
+                'text': "d it's a happy day!"
+            },
+            {
+                'text': "Sur la plateforme MT"
+            },
+            {
+                'text': 'MT4, plusieurs maniè'
+            },
+            {
+                'text': "ières d'accéder à ce"
+            },
+            {
+                'text': 'ces fonctionnalités '
+            },
+            {
+                'text': 's sont conçues simul'
+            },
+            {
+                'text': 'ultanément.'
+            },
+            {
+                'text': '欢迎来到阿里巴巴!'
+            },
+        ]
+        op = TextChunkMapper(max_len=20, overlap_len=2)
+        self._run_helper(op, source, target)
+
+    def test_max_len_and_split_pattern_text_chunk(self):
+        source = [
+            {
+                'text': "Today is Sunday and it's a happy day!"
+            },
+            {
+                'text':
+                "Sur la plateforme MT4, plusieurs manières d'accéder à \n"
+                'ces fonctionnalités sont conçues simultanément.'
+            },
+            {
+                'text': '欢迎来到阿里巴巴!'
+            },
+        ]
+        target = [
+            {
+                'text': "Today is Sunday and "
+            },
+            {
+                'text': "d it's a happy day!"
+            },
+            {
+                'text': "Sur la plateforme MT"
+            },
+            {
+                'text': 'MT4, plusieurs maniè'
+            },
+            {
+                'text': "ières d'accéder à "
+            },
+            {
+                'text': 'ces fonctionnalités '
+            },
+            {
+                'text': 's sont conçues simul'
+            },
+            {
+                'text': 'ultanément.'
+            },
+            {
+                'text': '欢迎来到阿里巴巴!'
+            },
+        ]
+        op = TextChunkMapper(max_len=20, overlap_len=2, split_pattern='\n')
+        self._run_helper(op, source, target)
+
+    def test_tokenizer_text_chunk(self):
+        source = [
+            {
+                'text': "Today is Sunday and it's a happy day!"
+            },
+            {
+                'text':
+                "Sur la plateforme MT4, plusieurs manières d'accéder à "
+                'ces fonctionnalités sont conçues simultanément.'
+            },
+            {
+                'text': '欢迎来到阿里巴巴!'
+            },
+        ]
+        target = [
+            {
+                'text': "Today is Sunday and it's a happy day!"
+            },
+            {
+                'text': "Sur la plateforme MT4, plusieurs manières"
+            },
+            {
+                'text': "ières d'accéder à ces fonctionnalités"
+            },
+            {
+                'text': "ités sont conçues simultanément."
+            },
+            {
+                'text': '欢迎来到阿里巴巴!'
+            },
+        ]
+        op = TextChunkMapper(max_len=10, overlap_len=1, split_pattern=None, hf_tokenizer='Qwen/Qwen-7B-Chat', trust_remote_code=True)
+        self._run_helper(op, source, target)
+
+    def test_all_text_chunk(self):
+        source = [
+            {
+                'text': "Today is Sunday and it's a happy day!"
+            },
+            {
+                'text':
+                "Sur la plateforme MT4, plusieurs manières d'accéder à \n"
+                'ces fonctionnalités sont conçues simultanément.'
+            },
+            {
+                'text': '欢迎来到阿里巴巴!'
+            },
+        ]
+        target = [
+            {
+                'text': "Today is Sunday and it's a happy day!"
+            },
+            {
+                'text': "Sur la plateforme MT4, plusieurs manières"
+            },
+            {
+                'text': "ières d'accéder à "
+            },
+            {
+                'text': "ces fonctionnalités sont conçues simultan"
+            },
+            {
+                'text': "anément."
+            },
+            {
+                'text': '欢迎来到阿里巴巴!'
+ }, + ] + op = TextChunkMapper(max_len=10, overlap_len=1, split_pattern='\n', hf_tokenizer='Qwen/Qwen-7B-Chat', trust_remote_code=True) + self._run_helper(op, source, target) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_video_captioning_from_audio_mapper.py b/tests/ops/mapper/test_video_captioning_from_audio_mapper.py index caadeb97b..402509639 100644 --- a/tests/ops/mapper/test_video_captioning_from_audio_mapper.py +++ b/tests/ops/mapper/test_video_captioning_from_audio_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) -# Skip tests for this OP in the GitHub actions due to disk space limitation. +# Skip tests for this OP in the GitHub actions due to OOM on the current runner # These tests have been tested locally. @SKIPPED_TESTS.register_module() class VideoCaptioningFromAudioMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_video_captioning_from_frames_mapper.py b/tests/ops/mapper/test_video_captioning_from_frames_mapper.py index 71bf963f6..d9bf29724 100644 --- a/tests/ops/mapper/test_video_captioning_from_frames_mapper.py +++ b/tests/ops/mapper/test_video_captioning_from_frames_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) -# Skip tests for this OP in the GitHub actions due to disk space limitation. +# Skip tests for this OP in the GitHub actions due to OOM on the current runner # These tests have been tested locally. @SKIPPED_TESTS.register_module() class VideoCaptioningFromFramesMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py b/tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py index 79f8037b9..016a4d73b 100644 --- a/tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py +++ b/tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) -# Skip tests for this OP in the GitHub actions due to disk space limitation. +# Skip tests for this OP in the GitHub actions due to OOM on the current runner # These tests have been tested locally. @SKIPPED_TESTS.register_module() class VideoCaptioningFromSummarizerMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_video_captioning_from_video_mapper.py b/tests/ops/mapper/test_video_captioning_from_video_mapper.py index 012761af5..f3de27226 100644 --- a/tests/ops/mapper/test_video_captioning_from_video_mapper.py +++ b/tests/ops/mapper/test_video_captioning_from_video_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) -# Skip tests for this OP in the GitHub actions due to disk space limitation. +# Skip tests for this OP in the GitHub actions due to OOM on the current runner # These tests have been tested locally. 
@SKIPPED_TESTS.register_module() class VideoCaptioningFromVideoMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_video_remove_watermark_mapper.py b/tests/ops/mapper/test_video_remove_watermark_mapper.py index 0cfefa76f..96231a976 100644 --- a/tests/ops/mapper/test_video_remove_watermark_mapper.py +++ b/tests/ops/mapper/test_video_remove_watermark_mapper.py @@ -11,7 +11,6 @@ DataJuicerTestCaseBase) -@SKIPPED_TESTS.register_module() class VideoRemoveWatermarkMapperTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/mapper/test_video_split_by_scene_mapper.py b/tests/ops/mapper/test_video_split_by_scene_mapper.py index dbbc32553..6e71789e6 100644 --- a/tests/ops/mapper/test_video_split_by_scene_mapper.py +++ b/tests/ops/mapper/test_video_split_by_scene_mapper.py @@ -8,7 +8,6 @@ from data_juicer.utils.mm_utils import SpecialTokens from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -@SKIPPED_TESTS.register_module() class VideoSplitBySceneMapperTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py index b310591a4..4484df754 100644 --- a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py +++ b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py @@ -9,6 +9,8 @@ from data_juicer.utils.mm_utils import SpecialTokens from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +# Skip tests for this OP in the GitHub actions due to OOM on the current runner +# These tests have been tested locally. @SKIPPED_TESTS.register_module() class VideoTaggingFromFramesMapperTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/test_op_fusion.py b/tests/ops/test_op_fusion.py index d545e0074..271737154 100644 --- a/tests/ops/test_op_fusion.py +++ b/tests/ops/test_op_fusion.py @@ -5,7 +5,6 @@ DataJuicerTestCaseBase) -@SKIPPED_TESTS.register_module() class OpFusionTest(DataJuicerTestCaseBase): def _run_op_fusion(self, original_process_list, target_process_list): diff --git a/tests/run.py b/tests/run.py index 3b7b99736..378d6b0d6 100644 --- a/tests/run.py +++ b/tests/run.py @@ -10,10 +10,11 @@ import os import sys import unittest +import coverage from loguru import logger -from data_juicer.utils.unittest_utils import SKIPPED_TESTS +from data_juicer.utils.unittest_utils import SKIPPED_TESTS, set_clear_model_flag file_dir = os.path.join(os.path.dirname(__file__), '..') sys.path.append(file_dir) @@ -26,8 +27,14 @@ parser.add_argument('--test_dir', default='tests', help='directory to be tested') +parser.add_argument('--clear_model', + default=False, + type=bool, + help='whether to clear the downloaded models for tests. 
' + 'It\'s False in default.') args = parser.parse_args() +set_clear_model_flag(args.clear_model) class TaggedTestLoader(unittest.TestLoader): def __init__(self, tag="standalone"): @@ -66,13 +73,21 @@ def gather_test_cases(test_dir, pattern, tag): def main(): + cov = coverage.Coverage() + cov.start() + runner = unittest.TextTestRunner() test_suite = gather_test_cases(os.path.abspath(args.test_dir), args.pattern, args.tag) res = runner.run(test_suite) + + cov.stop() + if not res.wasSuccessful(): exit(1) + cov.report() + if __name__ == '__main__': main() From 4d1670fbc58ef4b7cd8a0334a4f01aa4e1a7e8da Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Wed, 6 Nov 2024 02:36:07 +0000 Subject: [PATCH 016/118] fix bugs --- data_juicer/core/data.py | 2 +- data_juicer/ops/mapper/optimize_qa_mapper.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py index fc74ef802..7e51bd1f8 100644 --- a/data_juicer/core/data.py +++ b/data_juicer/core/data.py @@ -249,7 +249,7 @@ def map(self, *args, **kargs): if callable(getattr( called_func.__self__, 'is_batched_op')) and called_func.__self__.is_batched_op( - ) or not called_func.__self__.get('turbo', False): + ) or not getattr(called_func.__self__, 'turbo', False): kargs['batched'] = True kargs['batch_size'] = kargs.pop('batch_size', 1) if hasattr( called_func.__self__, 'is_batched_op' diff --git a/data_juicer/ops/mapper/optimize_qa_mapper.py b/data_juicer/ops/mapper/optimize_qa_mapper.py index eda705e5f..cd5a0aba7 100644 --- a/data_juicer/ops/mapper/optimize_qa_mapper.py +++ b/data_juicer/ops/mapper/optimize_qa_mapper.py @@ -107,9 +107,8 @@ def build_input(self, sample): def parse_output(self, raw_output): logger.debug(raw_output) - matches = re.findall(self.output_pattern, raw_output, re.DOTALL) - if matches: - match = matches[0] + match = re.match(self.output_pattern, raw_output, re.DOTALL) + if match: return match.group(1).strip(), match.group(2).strip() else: return None, None From 9e11aa39c5618d60a88c858997374d285378ec13 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Wed, 6 Nov 2024 02:52:15 +0000 Subject: [PATCH 017/118] fix tests --- tests/ops/mapper/test_calibrate_qa_mapper.py | 2 +- tests/run.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/ops/mapper/test_calibrate_qa_mapper.py b/tests/ops/mapper/test_calibrate_qa_mapper.py index 0f42b8b84..56f26cb05 100644 --- a/tests/ops/mapper/test_calibrate_qa_mapper.py +++ b/tests/ops/mapper/test_calibrate_qa_mapper.py @@ -10,7 +10,7 @@ # Skip tests for this OP because the API call is not configured yet. # These tests have been tested locally. 
@SKIPPED_TESTS.register_module() -class OptimizeQAMapperTest(DataJuicerTestCaseBase): +class CalibrateQAMapperTest(DataJuicerTestCaseBase): def _run_op(self, api_model, response_path=None): diff --git a/tests/run.py b/tests/run.py index 378d6b0d6..376071b3b 100644 --- a/tests/run.py +++ b/tests/run.py @@ -63,6 +63,8 @@ def gather_test_cases(test_dir, pattern, tag): print('suite_discovered', suite_discovered) for test_suite in suite_discovered: print('test_suite', test_suite) + if isinstance(test_suite, unittest.loader._FailedTest): + raise test_suite._exception for test_case in test_suite: if type(test_case) in SKIPPED_TESTS.modules.values(): continue From 347bc0f678e60dd6a43f7530f4948e0192834f22 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Thu, 7 Nov 2024 05:37:28 +0000 Subject: [PATCH 018/118] refine tests --- .github/workflows/docker/docker-compose.yml | 2 ++ data_juicer/utils/mm_utils.py | 1 - tests/ops/filter/test_video_ocr_area_ratio_filter.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker/docker-compose.yml b/.github/workflows/docker/docker-compose.yml index 92a5c76c2..61c2c84a5 100644 --- a/.github/workflows/docker/docker-compose.yml +++ b/.github/workflows/docker/docker-compose.yml @@ -10,6 +10,7 @@ services: - TORCH_HOME=/data/torch - NLTK_DATA=/data/nltk - DATA_JUICER_CACHE_HOME=/data/dj + - EASYOCR_MODULE_PATH=/data/EasyOCR - RAY_ADDRESS=auto working_dir: /workspace networks: @@ -39,6 +40,7 @@ services: - TORCH_HOME=/data/torch - NLTK_DATA=/data/nltk - DATA_JUICER_CACHE_HOME=/data/dj + - EASYOCR_MODULE_PATH=/data/EasyOCR working_dir: /workspace volumes: - huggingface_cache:/data diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py index 0e4cb8e34..7aa2df833 100644 --- a/data_juicer/utils/mm_utils.py +++ b/data_juicer/utils/mm_utils.py @@ -808,7 +808,6 @@ def parse_string_to_roi(roi_string, roi_type='pixel'): 'format of "x1, y1, x2, y2", "(x1, y1, x2, y2)", or ' '"[x1, y1, x2, y2]".') return None - return None def close_video(container: av.container.InputContainer): diff --git a/tests/ops/filter/test_video_ocr_area_ratio_filter.py b/tests/ops/filter/test_video_ocr_area_ratio_filter.py index b7c7e6f50..1adfdc021 100644 --- a/tests/ops/filter/test_video_ocr_area_ratio_filter.py +++ b/tests/ops/filter/test_video_ocr_area_ratio_filter.py @@ -6,7 +6,7 @@ from data_juicer.ops.filter.video_ocr_area_ratio_filter import \ VideoOcrAreaRatioFilter from data_juicer.utils.constant import Fields -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase class VideoOcrAreaRatioFilterTest(DataJuicerTestCaseBase): From c9d505109c48dfafc03d1078fe3a0886cb905ad6 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Fri, 8 Nov 2024 10:49:30 +0800 Subject: [PATCH 019/118] extract nickname --- data_juicer/ops/mapper/__init__.py | 13 +- .../mapper/extract_entity_attribute_mapper.py | 97 ++++++----- .../ops/mapper/extract_event_mapper.py | 70 +++++--- .../ops/mapper/extract_nickname_mapper.py | 155 ++++++++++++++++++ data_juicer/ops/mapper/text_chunk_mapper.py | 25 +-- data_juicer/utils/constant.py | 16 +- .../test_extract_entity_attribute_mapper.py | 15 +- tests/ops/mapper/test_extract_event_mapper.py | 11 +- tests/ops/mapper/test_text_chunk_mapper.py | 64 +++++++- 9 files changed, 359 insertions(+), 107 deletions(-) create mode 100644 data_juicer/ops/mapper/extract_nickname_mapper.py diff --git 
a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 5f21da953..0350049a4 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -11,6 +11,7 @@ from .expand_macro_mapper import ExpandMacroMapper from .extract_entity_attribute_mapper import ExtractEntityAttributeMapper from .extract_event_mapper import ExtractEventMapper +from .extract_nickname_mapper import ExtractNicknameMapper from .fix_unicode_mapper import FixUnicodeMapper from .generate_qa_from_examples_mapper import GenerateQAFromExamplesMapper from .generate_qa_from_text_mapper import GenerateQAFromTextMapper @@ -63,12 +64,12 @@ 'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper', 'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper', 'ExpandMacroMapper', 'ExtractEntityAttributeMapper', 'ExtractEventMapper', - 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper', - 'GenerateQAFromTextMapper', 'ImageBlurMapper', - 'ImageCaptioningFromGPT4VMapper', 'ImageCaptioningMapper', - 'ImageDiffusionMapper', 'ImageFaceBlurMapper', 'ImageTaggingMapper', - 'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper', - 'OptimizeQueryMapper', 'OptimizeResponseMapper', + 'ExtractNicknameMapper', 'FixUnicodeMapper', + 'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper', + 'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper', + 'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper', + 'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper', + 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', 'PunctuationNormalizationMapper', 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py index 8d302ed46..b7fa84435 100644 --- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -1,4 +1,3 @@ -import json import re from itertools import chain from typing import Dict, List, Optional @@ -21,6 +20,8 @@ class ExtractEntityAttributeMapper(Mapper): Extract attributes for given entities from the text """ + _batched_op = True + DEFAULT_SYSTEM_PROMPT_TEMPLATE = ( '给定一段文本,从文本中总结{entity}的{attribute},并且从原文摘录最能说明该{attribute}的代表性示例。\n' '要求:\n' @@ -35,17 +36,18 @@ class ExtractEntityAttributeMapper(Mapper): '...\n') DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n' - DEFAULT_ATTR_PATTERN_TEMPLATE = r'\#\#\s*\{attribute\}:\s*(?=\#\#\#|\Z)' - DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例(\d+):\s*(?=\#\#\#|\Z)' + DEFAULT_ATTR_PATTERN_TEMPLATE = r'\#\#\s*{attribute}:\s*(.*?)(?=\#\#\#|\Z)' + DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例(\d+):\s*(.*?)(?=\#\#\#|\Z)' def __init__(self, query_entities: List[str], query_attributes: List[str], api_model: str = 'gpt-4o', *, - entity_key: str = Fields.entity, - entity_attribute_key: str = Fields.entity_attribute, - support_text_key: str = Fields.support_text, + entity_key: str = Fields.main_entity, + attribute_key: str = Fields.attribute, + attribute_desc_key: str = Fields.attribute_description, + support_text_key: str = Fields.attribute_support_text, api_url: Optional[str] = None, api_key: Optional[str] = None, response_path: Optional[str] = None, @@ -54,18 +56,22 @@ def __init__(self, attr_pattern_template: Optional[str] = None, demo_pattern: Optional[str] = None, try_num: PositiveInt = 3, + 
drop_text: bool = False, api_params: Optional[Dict] = None, **kwargs): """ Initialization method. - :param query_entities: entity list to be queried. - :param query_attributes: attribute list to be queried. + :param query_entities: Entity list to be queried. + :param query_attributes: Attribute list to be queried. :param api_model: API model name. - :param entity_key: the field name to store the entity. - It's "__dj__entity__" in default. - :param entity_attribute_key: the field name to store the attribute. - It's "__dj__entity_attribute__" in default. - :param support_text_key: the field name to store the attribute + :param entity_key: The field name to store the given main entity for + attribute extraction. It's "__dj__entity__" in default. + :param entity_attribute_key: The field name to store the given + attribute to be extracted. It's "__dj__attribute__" in default. + :param attribute_desc_key: The field name to store the extracted + attribute description. It's "__dj__attribute_description__" in + default. + :param support_text_key: The field name to store the attribute support text extracted from the raw text. It's "__dj__support_text__" in default. :param api_url: API URL. Defaults to DJ_API_URL environment variable. @@ -81,6 +87,7 @@ def __init__(self, output to support the attribute. :param try_num: The number of retry attempts when there is an API call error or output parsing error. + :param drop_text: If drop the text in the output. :param api_params: Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} :param kwargs: Extra keyword arguments. @@ -91,7 +98,8 @@ def __init__(self, self.query_attributes = query_attributes self.entity_key = entity_key - self.entity_attribute_key = entity_attribute_key + self.attribute_key = attribute_key + self.attribute_desc_key = attribute_desc_key self.support_text_key = support_text_key self.system_prompt_template = system_prompt_template \ @@ -109,40 +117,34 @@ def __init__(self, response_path=response_path) self.try_num = try_num + self.drop_text = drop_text def parse_output(self, raw_output, attribute_name): attribute_pattern = self.attr_pattern_template.format( attribute=attribute_name) - print('attribute_pattern', attribute_pattern) pattern = re.compile(attribute_pattern, re.VERBOSE | re.DOTALL) matches = pattern.findall(raw_output) if matches: - attribute = matches[0] + attribute = matches[0].strip() else: attribute = '' pattern = re.compile(self.demo_pattern, re.VERBOSE | re.DOTALL) matches = pattern.findall(raw_output) + demos = [demo.strip() for _, demo in matches if demo.strip()] - demos = [] - for match in matches: - _, demo = match - demos.append(demo) - - return {attribute_name: attribute, self.support_text_key: demos} + return attribute, demos def _process_single_sample(self, text='', rank=None): client = get_model(self.model_key, rank=rank) - results = [] + entities, attributes, descs, demo_lists = [], [], [], [] for entity in self.query_entities: for attribute in self.query_attributes: system_prompt = self.system_prompt_template.format( entity=entity, attribute=attribute) - print(system_prompt) input_prompt = self.input_template.format(text=text) - print(input_prompt) messages = [{ 'role': 'system', 'content': system_prompt @@ -151,40 +153,45 @@ def _process_single_sample(self, text='', rank=None): 'content': input_prompt }] - result = {attribute: '', self.support_text_key: []} + desc, demos = '', [] for i in range(self.try_num): try: output = client(messages, **self.api_params) - print(output) - 
result = self.parse_output(output, attribute) - if result[attribute]: + desc, demos = self.parse_output(output, attribute) + if desc and len(demos) > 0: break except Exception as e: logger.warning(f'Exception: {e}') - result = json.dumps({ - self.entity_key: entity, - self.entity_attribute_key: result - }) - print(result) + entities.append(entity) + attributes.append(attribute) + descs.append(desc) + demo_lists.append(demos) - results.append(result) - - return results + return entities, attributes, descs, demo_lists def process_batched(self, samples): sample_num = len(samples[self.text_key]) - samples[self.response_key] = [ - self._process_single_sample(text) - for text in samples[self.text_key] - ] + entities, attributes, descs, demo_lists = [], [], [], [] + for text in samples[self.text_key]: + res = self._process_single_sample(text) + cur_ents, cur_attrs, cur_descs, cur_demos = res + entities.append(cur_ents) + attributes.append(cur_attrs) + descs.append(cur_descs) + demo_lists.append(cur_demos) + + if self.drop_text: + samples.pop(self.text_key) for key in samples: - if key != self.response_key: - samples[key] = [[samples[key][i]] * - len(samples[self.response_key][i]) - for i in range(len(sample_num))] + samples[key] = [[samples[key][i]] * len(descs[i]) + for i in range(sample_num)] + samples[self.entity_key] = entities + samples[self.attribute_key] = attributes + samples[self.attribute_desc_key] = descs + samples[self.support_text_key] = demo_lists for key in samples: samples[key] = list(chain(*samples[key])) diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py index cdde5268d..d731eb746 100644 --- a/data_juicer/ops/mapper/extract_event_mapper.py +++ b/data_juicer/ops/mapper/extract_event_mapper.py @@ -1,5 +1,5 @@ -import json import re +from itertools import chain from typing import Dict, Optional from loguru import logger @@ -22,6 +22,8 @@ class ExtractEventMapper(Mapper): Extract events and relavant characters in the text """ + _batched_op = True + DEFAULT_SYSTEM_PROMPT = ('给定一段文本,对文本的情节进行分点总结,并抽取与情节相关的人物。\n' '要求:\n' '- 尽量不要遗漏内容,不要添加文本中没有的情节,符合原文事实\n' @@ -60,16 +62,17 @@ def __init__(self, input_template: Optional[str] = None, output_pattern: Optional[str] = None, try_num: PositiveInt = 3, + drop_text: bool = False, api_params: Optional[Dict] = None, **kwargs): """ Initialization method. :param api_model: API model name. - :param event_desc_key: the field name to store the event descriptions - in response. It's "__dj__event_description__" in default. - :param relavant_char_key: the field name to store the relavant - characters to the events in response. - It's "__dj__relavant_characters__" in default. + :param event_desc_key: The field name to store the event descriptions. + It's "__dj__event_description__" in default. + :param relavant_char_key: The field name to store the relavant + characters to the events. It's "__dj__relavant_characters__" in + default. :param api_url: API URL. Defaults to DJ_API_URL environment variable. :param api_key: API key. Defaults to DJ_API_KEY environment variable. :param response_path: Path to extract content from the API response. @@ -79,6 +82,7 @@ def __init__(self, :param output_pattern: Regular expression for parsing model output. :param try_num: The number of retry attempts when there is an API call error or output parsing error. + :param drop_text: If drop the text in the output. :param api_params: Extra parameters passed to the API call. 
e.g {'temperature': 0.9, 'top_p': 0.95} :param kwargs: Extra keyword arguments. @@ -100,27 +104,27 @@ def __init__(self, response_path=response_path) self.try_num = try_num + self.drop_text = drop_text def parse_output(self, raw_output): pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) matches = pattern.findall(raw_output) - contents = [] + event_list, character_list = [], [] + for match in matches: - _, description, characters = match - contents.append({ - self.event_desc_key: - description.strip(), - self.relavant_char_key: - split_text_by_punctuation(characters) - }) + _, desc, chars = match + chars = split_text_by_punctuation(chars) + if len(chars) > 0: + event_list.append(desc) + character_list.append(chars) - return contents + return event_list, character_list - def process_single(self, sample=None, rank=None): + def _process_single_sample(self, text='', rank=None): client = get_model(self.model_key, rank=rank) - input_prompt = self.input_template.format(text=sample[self.text_key]) + input_prompt = self.input_template.format(text=text) messages = [{ 'role': 'system', 'content': self.system_prompt @@ -129,16 +133,38 @@ def process_single(self, sample=None, rank=None): 'content': input_prompt }] - contents = [] + event_list, character_list = [], [] for i in range(self.try_num): try: output = client(messages, **self.api_params) - contents = self.parse_output(output) - if len(contents) > 0: + event_list, character_list = self.parse_output(output) + if len(event_list) > 0: break except Exception as e: logger.warning(f'Exception: {e}') - sample[self.response_key] = json.dumps(contents) + return event_list, character_list + + def process_batched(self, samples): + + sample_num = len(samples[self.text_key]) + + events, characters = [], [] + for text in samples[self.text_key]: + cur_events, cur_characters = self._process_single_sample(text) + events.append(cur_events) + characters.append(cur_characters) + + if self.drop_text: + samples.pop(self.text_key) + + for key in samples: + samples[key] = [[samples[key][i]] * len(events[i]) + for i in range(sample_num)] + samples[self.event_desc_key] = events + samples[self.relavant_char_key] = characters + + for key in samples: + samples[key] = list(chain(*samples[key])) - return sample + return samples diff --git a/data_juicer/ops/mapper/extract_nickname_mapper.py b/data_juicer/ops/mapper/extract_nickname_mapper.py new file mode 100644 index 000000000..4dfe2b0e7 --- /dev/null +++ b/data_juicer/ops/mapper/extract_nickname_mapper.py @@ -0,0 +1,155 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.utils.constant import Fields +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'extract_nickname_mapper' + + +# TODO: LLM-based inference. 
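+# A minimal illustration of the record format this op asks the model for
+# (matched by DEFAULT_OUTPUT_PATTERN below); the names are illustrative:
+#
+#   ### 称呼方式1
+#   - **说话人**:李莲花
+#   - **被称呼人**:方多病
+#   - **李莲花对方多病的昵称**:方小宝
+#
+# parse_output() reduces such a record to
+#   {'entity1': '李莲花', 'entity2': '方多病',
+#    'description': '方小宝', 'relation': 'nickname'}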
+@UNFORKABLE.register_module(OP_NAME)
+@OPERATORS.register_module(OP_NAME)
+class ExtractNicknameMapper(Mapper):
+    """
+    Extract nickname relationships in the text
+    """
+
+    DEFAULT_SYSTEM_PROMPT = ('给定你一段文本,你的任务是将人物之间的称呼方式(昵称)提取出来。\n'
+                             '要求:\n'
+                             '- 需要给出说话人对被称呼人的称呼,不要搞反了。\n'
+                             '- 相同的说话人和被称呼人最多给出一个最常用的称呼。\n'
+                             '- 请不要输出互相没有昵称的称呼方式。\n'
+                             '- 输出格式如下:\n'
+                             '```\n'
+                             '### 称呼方式1\n'
+                             '- **说话人**:...\n'
+                             '- **被称呼人**:...\n'
+                             '- **...对...的昵称**:...\n'
+                             '### 称呼方式2\n'
+                             '- **说话人**:...\n'
+                             '- **被称呼人**:...\n'
+                             '- **...对...的昵称**:...\n'
+                             '### 称呼方式3\n'
+                             '- **说话人**:...\n'
+                             '- **被称呼人**:...\n'
+                             '- **...对...的昵称**:...\n'
+                             '...\n'
+                             '```\n')
+    DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n'
+    DEFAULT_OUTPUT_PATTERN = r"""
+        \#\#\#\s*称呼方式(\d+)\s*
+        -\s*\*\*说话人\*\*\s*:\s*(.*?)\s*
+        -\s*\*\*被称呼人\*\*\s*:\s*(.*?)\s*
+        -\s*\*\*(.*?)对(.*?)的昵称\*\*\s*:\s*(.*?)(?=\#\#\#|\Z) # for double check
+    """
+
+    def __init__(self,
+                 api_model: str = 'gpt-4o',
+                 *,
+                 nickname_key: str = Fields.nickname,
+                 api_url: Optional[str] = None,
+                 api_key: Optional[str] = None,
+                 response_path: Optional[str] = None,
+                 system_prompt: Optional[str] = None,
+                 input_template: Optional[str] = None,
+                 output_pattern: Optional[str] = None,
+                 try_num: PositiveInt = 3,
+                 drop_text: bool = False,
+                 api_params: Optional[Dict] = None,
+                 **kwargs):
+        """
+        Initialization method.
+        :param api_model: API model name.
+        :param nickname_key: The field name to store the nickname
+            relationship. It's "__dj__nickname__" in default.
+        :param api_url: API URL. Defaults to DJ_API_URL environment variable.
+        :param api_key: API key. Defaults to DJ_API_KEY environment variable.
+        :param response_path: Path to extract content from the API response.
+            Defaults to 'choices.0.message.content'.
+        :param system_prompt: System prompt for the nickname extraction task.
+        :param input_template: Template for building the model input.
+        :param output_pattern: Regular expression for parsing model output.
+        :param try_num: The number of retry attempts when there is an API
+            call error or output parsing error.
+        :param drop_text: If drop the text in the output.
+        :param api_params: Extra parameters passed to the API call.
+            e.g {'temperature': 0.9, 'top_p': 0.95}
+        :param kwargs: Extra keyword arguments.
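+
+        A minimal usage sketch (the model name is illustrative; the
+        DJ_API_URL and DJ_API_KEY environment variables are assumed to be
+        configured):
+
+            op = ExtractNicknameMapper(api_model='qwen2.5-72b-instruct')
+            sample = op.process_single({'text': '...'})
+            # sample[Fields.nickname]: a list of dicts with keys
+            # 'entity1', 'entity2', 'description' and 'relation'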
+ """ + super().__init__(**kwargs) + + self.nickname_key = nickname_key + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN + + self.api_params = api_params or {} + self.model_key = prepare_model(model_type='api', + api_model=api_model, + api_url=api_url, + api_key=api_key, + response_path=response_path) + + self.try_num = try_num + self.drop_text = drop_text + + def parse_output(self, raw_output): + pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) + matches = pattern.findall(raw_output) + + nickname_relations = [] + + for match in matches: + _, role1, role2, role1_tmp, role2_tmp, nickname = match + # for double check + if role1.strip() != role1_tmp.strip() or role2.strip( + ) != role2_tmp.strip(): + continue + role1 = role1.strip() + role2 = role2.strip() + nickname = nickname.strip() + # is name but not nickname + if role2 == nickname: + continue + if role1 and role2 and nickname: + nickname_relations.append((role1, role2, nickname)) + nickname_relations = list(set(nickname_relations)) + + nickname_relations = [{ + 'entity1': nr[0], + 'entity2': nr[1], + 'description': nr[2], + 'relation': 'nickname' + } for nr in nickname_relations] + + return nickname_relations + + def process_single(self, sample=None, rank=None): + client = get_model(self.model_key, rank=rank) + + input_prompt = self.input_template.format(text=sample[self.text_key]) + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': input_prompt + }] + nickname_relations = [] + for i in range(self.try_num): + try: + output = client(messages, **self.sampling_params) + nickname_relations = self.parse_output(output) + if len(nickname_relations) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + sample[self.nickname_key] = nickname_relations + return sample diff --git a/data_juicer/ops/mapper/text_chunk_mapper.py b/data_juicer/ops/mapper/text_chunk_mapper.py index a659da7d9..ca3b5e5ec 100644 --- a/data_juicer/ops/mapper/text_chunk_mapper.py +++ b/data_juicer/ops/mapper/text_chunk_mapper.py @@ -21,7 +21,8 @@ def __init__(self, max_len: Union[PositiveInt, None] = None, split_pattern: Union[str, None] = r'\n\n', overlap_len: NonNegativeInt = 0, - hf_tokenizer: Union[str, None] = None, + tokenizer: Union[str, None] = None, + tokenizer_type: str = 'huggingface', trust_remote_code: bool = False, *args, **kwargs): @@ -34,9 +35,11 @@ def __init__(self, and force cut if the length exceeds max_len. :param overlap_len: Overlap length of the split texts if not split in the split pattern. - :param hf_tokenizer: The tokenizer name of Hugging Face tokenizers. + :param tokenizer: The tokenizer name of Hugging Face tokenizers. The text length will be calculate as the token num if it is offerd. Otherwise, the text length equals to string length. + :param tokenizer_type: The type of tokenizer, it should be + 'huggingface' or 'api'. 
:trust_remote_code: for loading huggingface model :param args: extra args :param kwargs: extra args @@ -52,16 +55,16 @@ def __init__(self, self.max_len = max_len self.overlap_len = overlap_len self.split_pattern = split_pattern - self.hf_tokenizer = hf_tokenizer - if hf_tokenizer is not None: + self.tokenizer_name = tokenizer + if tokenizer is not None: self.model_key = prepare_model( - model_type='huggingface', - pretrained_model_name_or_path=hf_tokenizer, + model_type=tokenizer_type, + pretrained_model_name_or_path=tokenizer, return_model=False, trust_remote_code=trust_remote_code) def recursively_chunk(self, text): - if self.hf_tokenizer is not None: + if self.tokenizer_name is not None: tokenizer = get_model(self.model_key) tokens = tokenizer.tokenize(text) tokens = [t.decode(encoding='UTF-8') for t in tokens] @@ -77,7 +80,7 @@ def recursively_chunk(self, text): matches = list(re.finditer(self.split_pattern, sub_text)) if not matches: cur_text = sub_text - if self.hf_tokenizer is not None: + if self.tokenizer_name is not None: left_text = ''.join(tokens[self.max_len - self.overlap_len:]) else: left_text = text[self.max_len - self.overlap_len:] @@ -96,7 +99,7 @@ def get_text_chunks(self, text): elif self.split_pattern is None and self.max_len is not None: tokens = text total_len = len(text) - if self.hf_tokenizer is not None: + if self.tokenizer_name is not None: tokenizer = get_model(self.model_key) tokens = tokenizer.tokenize(text) tokens = [t.decode(encoding='UTF-8') for t in tokens] @@ -106,7 +109,7 @@ def get_text_chunks(self, text): chunks = [] for start in range(0, total_len, self.max_len - self.overlap_len): cur = tokens[start:start + self.max_len] - if self.hf_tokenizer is not None: + if self.tokenizer_name is not None: cur = ''.join(cur) chunks.append(cur) else: @@ -126,7 +129,7 @@ def process_batched(self, samples): if key != self.text_key: samples[key] = [[samples[key][i]] * len(samples[self.text_key][i]) - for i in range(len(sample_num))] + for i in range(sample_num)] for key in samples: samples[key] = list(chain(*samples[key])) diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 9d676d8a1..17cf54749 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -32,12 +32,16 @@ class Fields(object): event_description = DEFAULT_PREFIX + 'event_description__' # # a list of characters relavant to the event relavant_characters = DEFAULT_PREFIX + 'relavant_characters__' - # # entity in the knowlege graph - entity = DEFAULT_PREFIX + 'entity__' - # # a list of attributes for the entity - entity_attribute = DEFAULT_PREFIX + 'entity_attribute__' - # # extract from raw data for support some summary - support_text = DEFAULT_PREFIX + 'support_text__' + # # the given main entity for attribute extraction + main_entity = DEFAULT_PREFIX + 'main_entity__' + # # the given attribute to be extracted + attribute = DEFAULT_PREFIX + 'attribute__' + # # the extracted attribute description + attribute_description = DEFAULT_PREFIX + 'attribute_description__' + # # extract from raw data for support the attribute + attribute_support_text = DEFAULT_PREFIX + 'attribute_support_text__' + # # the nickname relationship + nickname = DEFAULT_PREFIX + 'nickname__' class StatsKeysMeta(type): diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py index 4eb8a9034..a14ee76ec 100644 --- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py +++ 
b/tests/ops/mapper/test_extract_entity_attribute_mapper.py @@ -14,10 +14,8 @@ class ExtractEntityAttributeMapperTest(DataJuicerTestCaseBase): def _run_op(self, api_model, response_path=None): - # query_entities = ["李莲花", "方多病"] - # query_attributes = ["语言风格", "角色性格"] - query_entities = ["李莲花"] - query_attributes = ["语言风格"] + query_entities = ["李莲花", "方多病"] + query_attributes = ["语言风格", "角色性格"] op = ExtractEntityAttributeMapper( query_entities=query_entities, @@ -48,13 +46,8 @@ def _run_op(self, api_model, response_path=None): dataset = Dataset.from_list(samples) dataset = dataset.map(op.process, batch_size=1) for sample in dataset: - response = json.loads(sample['response']) - result = response[Fields.entity_attribute] - self.assertNotEqual(len(result[Fields.support_text]), 0) - for attr in query_attributes: - if attr in result: - self.assertNotEqual(result[attr], '') - # self.assertEqual(len(result['text']), len(query_entities) * len(query_attributes)) + self.assertNotEqual(sample[Fields.attribute_description], '') + self.assertNotEqual(len(sample[Fields.attribute_support_text]), 0) def test(self): # before runing this test, set below environment variables: diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py index 4c57fb63d..735583650 100644 --- a/tests/ops/mapper/test_extract_event_mapper.py +++ b/tests/ops/mapper/test_extract_event_mapper.py @@ -1,8 +1,10 @@ import unittest import json +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.extract_event_mapper import ExtractEventMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields # Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. # These tests have been tested locally. @@ -51,9 +53,12 @@ def _run_op(self, api_model, response_path=None): 'text': raw_text, }] - for sample in samples: - result = op.process(sample) - self.assertNotEqual(result['response'], '[]') + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + self.assertNotEqual(len(dataset), 0) + for sample in dataset: + self.assertNotEqual(sample[Fields.event_description], '') + self.assertNotEqual(sample[Fields.relavant_characters], []) def test(self): # before runing this test, set below environment variables: diff --git a/tests/ops/mapper/test_text_chunk_mapper.py b/tests/ops/mapper/test_text_chunk_mapper.py index fde3387bf..84d135c4b 100644 --- a/tests/ops/mapper/test_text_chunk_mapper.py +++ b/tests/ops/mapper/test_text_chunk_mapper.py @@ -189,7 +189,11 @@ def test_max_len_and_split_pattern_text_chunk(self): 'text': '欢迎来到阿里巴巴!' }, ] - op = TextChunkMapper(max_len=20, overlap_len=2, split_pattern='\n') + op = TextChunkMapper( + max_len=20, + overlap_len=2, + split_pattern='\n' + ) self._run_helper(op, source, target) def test_tokenizer_text_chunk(self): @@ -223,9 +227,57 @@ def test_tokenizer_text_chunk(self): 'text': '欢迎来到阿里巴巴!' }, ] - op = TextChunkMapper(max_len=10, overlap_len=1, split_pattern=None, hf_tokenizer='Qwen/Qwen-7B-Chat', trust_remote_code=True) + op = TextChunkMapper( + max_len=10, + overlap_len=1, + split_pattern=None, + tokenizer='Qwen/Qwen-7B-Chat', + trust_remote_code=True + ) self._run_helper(op, source, target) + def test_api_tokenizer_text_chunk(self): + source = [ + { + 'text': "Today is Sunday and it's a happy day!" 
+ }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à " + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': "Sur la plateforme MT4, plusieurs manières" + }, + { + 'text': "ières d'accéder à ces fonctionnalités" + }, + { + 'text': "ités sont conçues simultanément." + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + op = TextChunkMapper( + max_len=10, + overlap_len=1, + split_pattern=None, + tokenizer='gpt-4o', + tokenizer_type='api', + trust_remote_code=True + ) + self._run_helper(op, source, target) + + def test_all_text_chunk(self): source = [ { @@ -260,7 +312,13 @@ def test_all_text_chunk(self): 'text': '欢迎来到阿里巴巴!' }, ] - op = TextChunkMapper(max_len=10, overlap_len=1, split_pattern='\n', hf_tokenizer='Qwen/Qwen-7B-Chat', trust_remote_code=True) + op = TextChunkMapper( + max_len=10, + overlap_len=1, + split_pattern='\n', + tokenizer='Qwen/Qwen-7B-Chat', + trust_remote_code=True + ) self._run_helper(op, source, target) From 9262777409fe9992032f372022d703f669def7ce Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Fri, 8 Nov 2024 15:02:09 +0800 Subject: [PATCH 020/118] nickname test done --- data_juicer/ops/mapper/calibrate_qa_mapper.py | 27 +++++++--- .../mapper/extract_entity_attribute_mapper.py | 15 +++--- .../ops/mapper/extract_event_mapper.py | 14 +++-- .../ops/mapper/extract_nickname_mapper.py | 12 +++-- data_juicer/ops/mapper/text_chunk_mapper.py | 33 ++++++------ .../test_extract_entity_attribute_mapper.py | 4 ++ tests/ops/mapper/test_extract_event_mapper.py | 5 ++ .../mapper/test_extract_nickname_mapper.py | 54 +++++++++++++++++++ tests/ops/mapper/test_text_chunk_mapper.py | 46 ++++++++++++++-- 9 files changed, 166 insertions(+), 44 deletions(-) create mode 100644 tests/ops/mapper/test_extract_nickname_mapper.py diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index a13c6dce4..b250f7cd7 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -1,6 +1,9 @@ import re from typing import Dict, Optional +from loguru import logger +from pydantic import PositiveInt + from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper from data_juicer.utils.model_utils import get_model, prepare_model @@ -38,8 +41,9 @@ def __init__(self, reference_template: Optional[str] = None, qa_pair_template: Optional[str] = None, output_pattern: Optional[str] = None, - model_params: Optional[Dict] = None, - sampling_params: Optional[Dict] = None, + try_num: PositiveInt = 3, + model_params: Optional[Dict] = {}, + sampling_params: Optional[Dict] = {}, **kwargs): """ Initialization method. @@ -56,6 +60,7 @@ def __init__(self, :param output_pattern: Regular expression for parsing model output. :param model_params: Parameters for initializing the model. :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} :param kwargs: Extra keyword arguments. 
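+        :param try_num: The number of retry attempts when there is an API
+            call error or output parsing error.
+
+        A minimal usage sketch (the model name and sample keys are
+        illustrative; DJ_API_URL and DJ_API_KEY are assumed configured):
+
+            op = CalibrateQAMapper(api_model='qwen2.5-72b-instruct')
+            sample = op.process_single(
+                {'text': '...', 'query': '...', 'response': '...'})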
""" super().__init__(**kwargs) @@ -68,8 +73,8 @@ def __init__(self, self.DEFAULT_QA_PAIR_TEMPLATE self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN - self.model_params = model_params or {} - self.sampling_params = sampling_params or {} + self.model_params = model_params + self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', api_model=api_model, api_url=api_url, @@ -77,6 +82,8 @@ def __init__(self, response_path=response_path, **model_params) + self.try_num = try_num + def build_input(self, sample): reference = self.reference_template.format(sample[self.text_key]) qa_pair = self.qa_pair_template.format(sample[self.query_key], @@ -102,9 +109,15 @@ def process_single(self, sample=None, rank=None): 'role': 'user', 'content': self.build_input(sample) }] - output = client(messages, **self.sampling_params) - - parsed_q, parsed_a = self.parse_output(output) + parsed_q, parsed_a = None, None + for i in range(self.try_num): + try: + output = client(messages, **self.sampling_params) + parsed_q, parsed_a = self.parse_output(output) + if parsed_q or parsed_a: + break + except Exception as e: + logger.warning(f'Exception: {e}') if parsed_q: sample[self.query_key] = parsed_q if parsed_a: diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py index b7fa84435..0055d3ad7 100644 --- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -57,7 +57,8 @@ def __init__(self, demo_pattern: Optional[str] = None, try_num: PositiveInt = 3, drop_text: bool = False, - api_params: Optional[Dict] = None, + model_params: Optional[Dict] = {}, + sampling_params: Optional[Dict] = {}, **kwargs): """ Initialization method. @@ -87,8 +88,8 @@ def __init__(self, output to support the attribute. :param try_num: The number of retry attempts when there is an API call error or output parsing error. - :param drop_text: If drop the text in the output. - :param api_params: Extra parameters passed to the API call. + :param model_params: Parameters for initializing the model. + :param sampling_params: Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} :param kwargs: Extra keyword arguments. 
""" @@ -109,12 +110,14 @@ def __init__(self, or self.DEFAULT_ATTR_PATTERN_TEMPLATE self.demo_pattern = demo_pattern or self.DEFAULT_DEMON_PATTERN - self.api_params = api_params or {} + self.model_params = model_params + self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', api_model=api_model, api_url=api_url, api_key=api_key, - response_path=response_path) + response_path=response_path, + **model_params) self.try_num = try_num self.drop_text = drop_text @@ -156,7 +159,7 @@ def _process_single_sample(self, text='', rank=None): desc, demos = '', [] for i in range(self.try_num): try: - output = client(messages, **self.api_params) + output = client(messages, **self.sampling_params) desc, demos = self.parse_output(output, attribute) if desc and len(demos) > 0: break diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py index d731eb746..c70c9ec6e 100644 --- a/data_juicer/ops/mapper/extract_event_mapper.py +++ b/data_juicer/ops/mapper/extract_event_mapper.py @@ -63,7 +63,8 @@ def __init__(self, output_pattern: Optional[str] = None, try_num: PositiveInt = 3, drop_text: bool = False, - api_params: Optional[Dict] = None, + model_params: Optional[Dict] = {}, + sampling_params: Optional[Dict] = {}, **kwargs): """ Initialization method. @@ -83,7 +84,8 @@ def __init__(self, :param try_num: The number of retry attempts when there is an API call error or output parsing error. :param drop_text: If drop the text in the output. - :param api_params: Extra parameters passed to the API call. + :param model_params: Parameters for initializing the model. + :param sampling_params: Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} :param kwargs: Extra keyword arguments. """ @@ -96,12 +98,14 @@ def __init__(self, self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN - self.api_params = api_params or {} + self.model_params = model_params + self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', api_model=api_model, api_url=api_url, api_key=api_key, - response_path=response_path) + response_path=response_path, + **model_params) self.try_num = try_num self.drop_text = drop_text @@ -136,7 +140,7 @@ def _process_single_sample(self, text='', rank=None): event_list, character_list = [], [] for i in range(self.try_num): try: - output = client(messages, **self.api_params) + output = client(messages, **self.sampling_params) event_list, character_list = self.parse_output(output) if len(event_list) > 0: break diff --git a/data_juicer/ops/mapper/extract_nickname_mapper.py b/data_juicer/ops/mapper/extract_nickname_mapper.py index 4dfe2b0e7..eee49a7b6 100644 --- a/data_juicer/ops/mapper/extract_nickname_mapper.py +++ b/data_juicer/ops/mapper/extract_nickname_mapper.py @@ -60,7 +60,8 @@ def __init__(self, output_pattern: Optional[str] = None, try_num: PositiveInt = 3, drop_text: bool = False, - api_params: Optional[Dict] = None, + model_params: Optional[Dict] = {}, + sampling_params: Optional[Dict] = {}, **kwargs): """ Initialization method. @@ -77,7 +78,8 @@ def __init__(self, :param try_num: The number of retry attempts when there is an API call error or output parsing error. :param drop_text: If drop the text in the output. - :param api_params: Extra parameters passed to the API call. + :param model_params: Parameters for initializing the model. 
+ :param sampling_params: Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} :param kwargs: Extra keyword arguments. """ @@ -89,12 +91,14 @@ def __init__(self, self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN - self.api_params = api_params or {} + self.model_params = model_params + self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', api_model=api_model, api_url=api_url, api_key=api_key, - response_path=response_path) + response_path=response_path, + **model_params) self.try_num = try_num self.drop_text = drop_text diff --git a/data_juicer/ops/mapper/text_chunk_mapper.py b/data_juicer/ops/mapper/text_chunk_mapper.py index ca3b5e5ec..2dd8a9441 100644 --- a/data_juicer/ops/mapper/text_chunk_mapper.py +++ b/data_juicer/ops/mapper/text_chunk_mapper.py @@ -22,7 +22,6 @@ def __init__(self, split_pattern: Union[str, None] = r'\n\n', overlap_len: NonNegativeInt = 0, tokenizer: Union[str, None] = None, - tokenizer_type: str = 'huggingface', trust_remote_code: bool = False, *args, **kwargs): @@ -37,9 +36,9 @@ def __init__(self, the split pattern. :param tokenizer: The tokenizer name of Hugging Face tokenizers. The text length will be calculate as the token num if it is offerd. - Otherwise, the text length equals to string length. - :param tokenizer_type: The type of tokenizer, it should be - 'huggingface' or 'api'. + Otherwise, the text length equals to string length. Support + tiktoken tokenizer (such as gpt-4o), dashscope tokenizer (such as + qwen2.5-72b-instruct) and huggingface tokenizer. :trust_remote_code: for loading huggingface model :param args: extra args :param kwargs: extra args @@ -57,19 +56,17 @@ def __init__(self, self.split_pattern = split_pattern self.tokenizer_name = tokenizer if tokenizer is not None: - self.model_key = prepare_model( - model_type=tokenizer_type, - pretrained_model_name_or_path=tokenizer, - return_model=False, - trust_remote_code=trust_remote_code) + self.model_key = prepare_model(model_type='api', + api_model=tokenizer, + return_processor=True, + trust_remote_code=trust_remote_code) def recursively_chunk(self, text): if self.tokenizer_name is not None: - tokenizer = get_model(self.model_key) - tokens = tokenizer.tokenize(text) - tokens = [t.decode(encoding='UTF-8') for t in tokens] + _, tokenizer = get_model(self.model_key) + tokens = tokenizer.encode(text) total_len = len(tokens) - sub_text = ''.join(tokens[:self.max_len]) + sub_text = tokenizer.decode(tokens[:self.max_len]) else: total_len = len(text) sub_text = text[:self.max_len] @@ -81,7 +78,8 @@ def recursively_chunk(self, text): if not matches: cur_text = sub_text if self.tokenizer_name is not None: - left_text = ''.join(tokens[self.max_len - self.overlap_len:]) + left_text = tokenizer.decode(tokens[self.max_len - + self.overlap_len:]) else: left_text = text[self.max_len - self.overlap_len:] else: @@ -100,9 +98,8 @@ def get_text_chunks(self, text): tokens = text total_len = len(text) if self.tokenizer_name is not None: - tokenizer = get_model(self.model_key) - tokens = tokenizer.tokenize(text) - tokens = [t.decode(encoding='UTF-8') for t in tokens] + _, tokenizer = get_model(self.model_key) + tokens = tokenizer.encode(text) total_len = len(tokens) if total_len <= self.max_len: return [text] @@ -110,7 +107,7 @@ def get_text_chunks(self, text): for start in range(0, total_len, self.max_len - self.overlap_len): cur = tokens[start:start + self.max_len] if 
self.tokenizer_name is not None: - cur = ''.join(cur) + cur = tokenizer.decode(cur) chunks.append(cur) else: chunks = self.recursively_chunk(text) diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py index a14ee76ec..e2e5d2738 100644 --- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py +++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py @@ -1,5 +1,8 @@ import unittest import json + +from loguru import logger + from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.extract_entity_attribute_mapper import ExtractEntityAttributeMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, @@ -46,6 +49,7 @@ def _run_op(self, api_model, response_path=None): dataset = Dataset.from_list(samples) dataset = dataset.map(op.process, batch_size=1) for sample in dataset: + logger.info(f'{sample[Fields.main_entity]} {sample[Fields.attribute]}: {sample[Fields.attribute_description]}') self.assertNotEqual(sample[Fields.attribute_description], '') self.assertNotEqual(len(sample[Fields.attribute_support_text]), 0) diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py index 735583650..c454464c9 100644 --- a/tests/ops/mapper/test_extract_event_mapper.py +++ b/tests/ops/mapper/test_extract_event_mapper.py @@ -1,5 +1,8 @@ import unittest import json + +from loguru import logger + from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.extract_event_mapper import ExtractEventMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, @@ -57,7 +60,9 @@ def _run_op(self, api_model, response_path=None): dataset = dataset.map(op.process, batch_size=2) self.assertNotEqual(len(dataset), 0) for sample in dataset: + logger.info(f"event: {sample[Fields.event_description]}") self.assertNotEqual(sample[Fields.event_description], '') + logger.info(f"characters: {sample[Fields.relavant_characters]}") self.assertNotEqual(sample[Fields.relavant_characters], []) def test(self): diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py new file mode 100644 index 000000000..4f7b44336 --- /dev/null +++ b/tests/ops/mapper/test_extract_nickname_mapper.py @@ -0,0 +1,54 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.extract_nickname_mapper import ExtractNicknameMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class ExtractNicknameMapperTest(DataJuicerTestCaseBase): + + + def _run_op(self, api_model, response_path=None): + + op = ExtractNicknameMapper(api_model=api_model, + response_path=response_path) + + raw_text = """△李莲花又指出刚才门框上的痕迹。 +△李莲花:门框上也是人的掌痕和爪印。指力能嵌入硬物寸余,七分力道主上,三分力道垫下,还有辅以的爪式,看样子这还有昆仑派的外家功夫。 +方多病看着李莲花,愈发生疑os:通过痕迹就能判断出功夫和门派,这绝对只有精通武艺之人才能做到,李莲花你到底是什么人?! +笛飞声环顾四周:有朝月派,还有昆仑派,看来必是一群武林高手在这发生了决斗! +李莲花:如果是武林高手过招,为何又会出现如此多野兽的痕迹。方小宝,你可听过江湖上有什么门派是驯兽来斗?方小宝?方小宝? +方多病回过神:不、不曾听过。 +李莲花:还有这些人都去了哪里? +笛飞声:打架不管是输是赢,自然是打完就走。 +李莲花摇头:就算打完便走,但这里是客栈,为何这么多年一直荒在这里,甚至没人来收拾一下? +笛飞声:闹鬼?这里死过这么多人,楼下又画了那么多符,所以不敢进来? 
+△这时,梁上又出现有东西移动的声响,李莲花、笛飞声都猛然回头看去。 +""" + samples = [{ + 'text': raw_text, + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + result = dataset[0][Fields.nickname] + result = [(d['entity1'], d['entity2'], d['description']) for d in result] + logger.info(f'result: {result}') + self.assertIn(("李莲花","方多病","方小宝"), result) + + def test(self): + # before runing this test, set below environment variables: + # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions + # export DJ_API_KEY=your_key + self._run_op('qwen2.5-72b-instruct') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_text_chunk_mapper.py b/tests/ops/mapper/test_text_chunk_mapper.py index 84d135c4b..8004d9ede 100644 --- a/tests/ops/mapper/test_text_chunk_mapper.py +++ b/tests/ops/mapper/test_text_chunk_mapper.py @@ -236,7 +236,47 @@ def test_tokenizer_text_chunk(self): ) self._run_helper(op, source, target) - def test_api_tokenizer_text_chunk(self): + def test_tiktoken_tokenizer_text_chunk(self): + source = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à " + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': "Sur la plateforme MT4, plusieurs manières d" + }, + { + 'text': " d'accéder à ces fonctionnalités sont conçues simult" + }, + { + 'text': " simultanément." + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + op = TextChunkMapper( + max_len=10, + overlap_len=1, + split_pattern=None, + tokenizer='gpt-4o', + trust_remote_code=True + ) + self._run_helper(op, source, target) + + def test_dashscope_tokenizer_text_chunk(self): source = [ { 'text': "Today is Sunday and it's a happy day!" 
@@ -271,13 +311,11 @@ def test_api_tokenizer_text_chunk(self): max_len=10, overlap_len=1, split_pattern=None, - tokenizer='gpt-4o', - tokenizer_type='api', + tokenizer='qwen2.5-72b-instruct', trust_remote_code=True ) self._run_helper(op, source, target) - def test_all_text_chunk(self): source = [ { From c7dc28ed9db5fb750dc690cdcd1716e7d2c9be4d Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 11 Nov 2024 15:27:20 +0800 Subject: [PATCH 021/118] lightRAG to OP --- data_juicer/ops/mapper/__init__.py | 7 +- data_juicer/ops/mapper/calibrate_qa_mapper.py | 4 +- .../mapper/extract_entity_attribute_mapper.py | 4 +- .../mapper/extract_entity_relation_mapper.py | 342 ++++++++++++++++++ .../ops/mapper/extract_event_mapper.py | 4 +- .../ops/mapper/extract_keyword_mapper.py | 193 ++++++++++ .../ops/mapper/extract_nickname_mapper.py | 14 +- data_juicer/utils/common_utils.py | 8 + data_juicer/utils/constant.py | 6 + .../test_extract_entity_relation_mapper.py | 86 +++++ .../ops/mapper/test_extract_keyword_mapper.py | 72 ++++ .../mapper/test_extract_nickname_mapper.py | 2 +- 12 files changed, 728 insertions(+), 14 deletions(-) create mode 100644 data_juicer/ops/mapper/extract_entity_relation_mapper.py create mode 100644 data_juicer/ops/mapper/extract_keyword_mapper.py create mode 100644 tests/ops/mapper/test_extract_entity_relation_mapper.py create mode 100644 tests/ops/mapper/test_extract_keyword_mapper.py diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 0350049a4..41bf092a3 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -10,7 +10,9 @@ from .clean_links_mapper import CleanLinksMapper from .expand_macro_mapper import ExpandMacroMapper from .extract_entity_attribute_mapper import ExtractEntityAttributeMapper +from .extract_entity_relation_mapper import ExtractEntityRelationMapper from .extract_event_mapper import ExtractEventMapper +from .extract_keyword_mapper import ExtractKeywordMapper from .extract_nickname_mapper import ExtractNicknameMapper from .fix_unicode_mapper import FixUnicodeMapper from .generate_qa_from_examples_mapper import GenerateQAFromExamplesMapper @@ -63,8 +65,9 @@ 'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper', 'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper', 'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper', - 'ExpandMacroMapper', 'ExtractEntityAttributeMapper', 'ExtractEventMapper', - 'ExtractNicknameMapper', 'FixUnicodeMapper', + 'ExpandMacroMapper', 'ExtractEntityAttributeMapper', + 'ExtractEntityRelationMapper', 'ExtractEventMapper', + 'ExtractKeywordMapper', 'ExtractNicknameMapper', 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper', 'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper', 'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper', diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index b250f7cd7..cf61ff73d 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -42,8 +42,8 @@ def __init__(self, qa_pair_template: Optional[str] = None, output_pattern: Optional[str] = None, try_num: PositiveInt = 3, - model_params: Optional[Dict] = {}, - sampling_params: Optional[Dict] = {}, + model_params: Dict = {}, + sampling_params: Dict = {}, **kwargs): """ Initialization method. 
diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py
index 0055d3ad7..7d1ae7c8c 100644
--- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py
+++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py
@@ -57,8 +57,8 @@ def __init__(self,
                  demo_pattern: Optional[str] = None,
                  try_num: PositiveInt = 3,
                  drop_text: bool = False,
-                 model_params: Optional[Dict] = {},
-                 sampling_params: Optional[Dict] = {},
+                 model_params: Dict = {},
+                 sampling_params: Dict = {},
                  **kwargs):
         """
         Initialization method.
diff --git a/data_juicer/ops/mapper/extract_entity_relation_mapper.py b/data_juicer/ops/mapper/extract_entity_relation_mapper.py
new file mode 100644
index 000000000..2de5f477e
--- /dev/null
+++ b/data_juicer/ops/mapper/extract_entity_relation_mapper.py
@@ -0,0 +1,342 @@
+# This OP is modified from light RAG
+# https://github.com/HKUDS/LightRAG
+
+# flake8: noqa: E501
+
+import re
+from typing import Dict, List, Optional
+
+from loguru import logger
+from pydantic import NonNegativeInt, PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
+from data_juicer.utils.common_utils import is_float
+from data_juicer.utils.constant import Fields
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from ..common import split_text_by_punctuation
+
+OP_NAME = 'extract_entity_relation_mapper'
+
+
+# TODO: LLM-based inference.
+@UNFORKABLE.register_module(OP_NAME)
+@OPERATORS.register_module(OP_NAME)
+class ExtractEntityRelationMapper(Mapper):
+    """
+    Extract entities and relations in the text for knowledge graph.
+    """
+
+    DEFAULT_PROMPT_TEMPLATE = """-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.

-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity
- entity_type: One of the following types: [{entity_types}]
- entity_description: Comprehensive description of the entity's attributes and activities
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)

2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>)

3. Return output in the language of the given text as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.

4. 
When finished, output {completion_delimiter} + +###################### +-Examples- +###################### +Example 1: + +Entity_types: [person, technology, mission, organization, location] +Text: +``` +while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order. + +Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.” + +The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce. + +It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths +``` +################ +Output: +("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter} +("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter} +("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter} +("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter} +("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter} +("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter} +("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter} +("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential 
impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter} +############################# +Example 2: + +Entity_types: [人物, 技术, 任务, 组织, 地点] +Text: +``` +他们不再是单纯的执行者;他们已成为某个超越星辰与条纹的领域的信息守护者。这一使命的提升不能被规则和既定协议所束缚——它需要一种新的视角,一种新的决心。 + +随着与华盛顿的通讯在背景中嗡嗡作响,对话中的紧张情绪通过嘟嘟声和静电噪音贯穿始终。团队站立着,一股不祥的气息笼罩着他们。显然,他们在接下来几个小时内做出的决定可能会重新定义人类在宇宙中的位置,或者将他们置于无知和潜在危险之中。 + +随着与星辰的联系变得更加牢固,小组开始处理逐渐成形的警告,从被动接受者转变为积极参与者。梅瑟后来的直觉占据了上风——团队的任务已经演变,不再仅仅是观察和报告,而是互动和准备。一场蜕变已经开始,而“杜尔塞行动”则以他们大胆的新频率震动,这种基调不是由世俗设定的 +``` +############# +Output: +("entity"{tuple_delimiter}"华盛顿"{tuple_delimiter}"地点"{tuple_delimiter}"华盛顿是正在接收通讯的地方,表明其在决策过程中的重要性。"){record_delimiter} +("entity"{tuple_delimiter}"杜尔塞行动"{tuple_delimiter}"任务"{tuple_delimiter}"杜尔塞行动被描述为一项已演变为互动和准备的任务,显示出目标和活动的重大转变。"){record_delimiter} +("entity"{tuple_delimiter}"团队"{tuple_delimiter}"组织"{tuple_delimiter}"团队被描绘成一群从被动观察者转变为积极参与者的人,展示了他们角色的动态变化。"){record_delimiter} +("relationship"{tuple_delimiter}"团队"{tuple_delimiter}"华盛顿"{tuple_delimiter}"团队收到来自华盛顿的通讯,这影响了他们的决策过程。"{tuple_delimiter}"决策、外部影响"{tuple_delimiter}7){record_delimiter} +("relationship"{tuple_delimiter}"团队"{tuple_delimiter}"杜尔塞行动"{tuple_delimiter}"团队直接参与杜尔塞行动,执行其演变后的目标和活动。"{tuple_delimiter}"任务演变、积极参与"{tuple_delimiter}9){completion_delimiter} +############################# +Example 3: + +Entity_types: [person, role, technology, organization, event, location, concept] +Text: +``` +their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data. + +"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning." + +Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back." + +Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history. 
+ +The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation +``` +############# +Output: +("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter} +("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter} +("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter} +("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter} +("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter} +("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter} +("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}"communication, learning process"{tuple_delimiter}9){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}"leadership, exploration"{tuple_delimiter}10){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter} +("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter} +############################# +-Real Data- +###################### +Entity_types: [{entity_types}] +Text: +``` +{input_text} +``` +###################### +Output: +""" + DEFAULT_CONTINUE_PROMPT = 'MANY entities were missed in the last extraction. Add them below using the same format:\n' + DEFAULT_IF_LOOP_PROMPT = 'It appears some entities may have still been missed. 
Answer YES | NO if there are still entities that need to be added.\n'
+
+    DEFAULT_ENTITY_TYPES = ['organization', 'person', 'geo', 'event']
+    DEFAULT_TUPLE_DELIMITER = '<|>'
+    DEFAULT_RECORD_DELIMITER = '##'
+    DEFAULT_COMPLETION_DELIMITER = '<|COMPLETE|>'
+    DEFAULT_ENTITY_PATTERN = r'\("entity"(.*?)\)'
+    DEFAULT_RELATION_PATTERN = r'\("relationship"(.*?)\)'
+
+    def __init__(self,
+                 api_model: str = 'gpt-4o',
+                 entity_types: List[str] = None,
+                 *,
+                 entity_key: str = Fields.entity,
+                 relation_key: str = Fields.relation,
+                 api_url: Optional[str] = None,
+                 api_key: Optional[str] = None,
+                 response_path: Optional[str] = None,
+                 prompt_template: Optional[str] = None,
+                 tuple_delimiter: Optional[str] = None,
+                 record_delimiter: Optional[str] = None,
+                 completion_delimiter: Optional[str] = None,
+                 max_gleaning: NonNegativeInt = 1,
+                 continue_prompt: Optional[str] = None,
+                 if_loop_prompt: Optional[str] = None,
+                 entity_pattern: Optional[str] = None,
+                 relation_pattern: Optional[str] = None,
+                 try_num: PositiveInt = 3,
+                 drop_text: bool = False,
+                 model_params: Dict = {},
+                 sampling_params: Dict = {},
+                 **kwargs):
+        """
+        Initialization method.
+        :param api_model: API model name.
+        :param entity_types: Pre-defined entity types for knowledge graph.
+        :param entity_key: The field name to store the entities. It's
+            "__dj__entity__" in default.
+        :param relation_key: The field name to store the relations between
+            entities. It's "__dj__relation__" in default.
+        :param api_url: API URL. Defaults to DJ_API_URL environment variable.
+        :param api_key: API key. Defaults to DJ_API_KEY environment variable.
+        :param response_path: Path to extract content from the API response.
+            Defaults to 'choices.0.message.content'.
+        :param prompt_template: The template of input prompt.
+        :param tuple_delimiter: Delimiter to separate items in outputs.
+        :param record_delimiter: Delimiter to separate records in outputs.
+        :param completion_delimiter: To mark the end of the output.
+        :param max_gleaning: The maximum number of extra LLM calls to glean
+            entities and relations.
+        :param continue_prompt: The prompt for gleaning entities and
+            relations.
+        :param if_loop_prompt: The prompt to determine whether to stop
+            gleaning.
+        :param entity_pattern: Regular expression for parsing entity record.
+        :param relation_pattern: Regular expression for parsing relation
+            record.
+        :param try_num: The number of retry attempts when there is an API
+            call error or output parsing error.
+        :param drop_text: If drop the text in the output.
+        :param model_params: Parameters for initializing the model.
+        :param sampling_params: Extra parameters passed to the API call.
+            e.g {'temperature': 0.9, 'top_p': 0.95}
+        :param kwargs: Extra keyword arguments.
+        """
+        super().__init__(**kwargs)
+
+        self.entity_types = entity_types or self.DEFAULT_ENTITY_TYPES
+
+        self.entity_key = entity_key
+        self.relation_key = relation_key
+
+        self.prompt_template = prompt_template or self.DEFAULT_PROMPT_TEMPLATE
+        self.tuple_delimiter = tuple_delimiter or self.DEFAULT_TUPLE_DELIMITER
+        self.record_delimiter = record_delimiter or self.DEFAULT_RECORD_DELIMITER
+        self.completion_delimiter = completion_delimiter or \
+            self.DEFAULT_COMPLETION_DELIMITER
+        self.max_gleaning = max_gleaning
+        self.continue_prompt = continue_prompt or self.DEFAULT_CONTINUE_PROMPT
+        self.if_loop_prompt = if_loop_prompt or self.DEFAULT_IF_LOOP_PROMPT
+        self.entity_pattern = entity_pattern or self.DEFAULT_ENTITY_PATTERN
+        self.relation_pattern = relation_pattern or \
+            self.DEFAULT_RELATION_PATTERN
+
+        self.model_params = model_params
+        self.sampling_params = sampling_params
+        self.model_key = prepare_model(model_type='api',
+                                       api_model=api_model,
+                                       api_url=api_url,
+                                       api_key=api_key,
+                                       response_path=response_path,
+                                       **model_params)
+
+        self.try_num = try_num
+        self.drop_text = drop_text
+
+    def parse_output(self, raw_output):
+        entities, relations = [], []
+
+        def remove_outer_quotes(text):
+            # Records start with the tuple delimiter, so splitting yields an
+            # empty first item; guard against it to avoid an IndexError.
+            if not text:
+                return text
+            if (text[0] == '"' and text[-1] == '"') or (text[0] == "'"
+                                                        and text[-1] == "'"):
+                return text[1:-1]
+            else:
+                return text
+
+        def split_by_tuple_delimiter(record):
+            items = record.split(self.tuple_delimiter)
+            items = [remove_outer_quotes(item.strip()) for item in items]
+            items = [item.strip() for item in items if item.strip()]
+            return tuple(items)
+
+        entity_pattern = re.compile(self.entity_pattern,
+                                    re.VERBOSE | re.DOTALL)
+        matches = entity_pattern.findall(raw_output)
+        for record in matches:
+            items = split_by_tuple_delimiter(record)
+            if len(items) != 3:
+                continue
+            entities.append(items)
+        entities = list(set(entities))
+        entities = [{
+            'entity': e[0],
+            'type': e[1],
+            'description': e[2]
+        } for e in entities]
+
+        relation_pattern = re.compile(self.relation_pattern,
+                                      re.VERBOSE | re.DOTALL)
+        matches = relation_pattern.findall(raw_output)
+        for record in matches:
+            items = split_by_tuple_delimiter(record)
+            if len(items) != 5 or not is_float(items[4]):
+                continue
+            relations.append(items)
+        relations = list(set(relations))
+        relations = [{
+            'source_entity': r[0],
+            'target_entity': r[1],
+            'description': r[2],
+            'keywords': split_text_by_punctuation(r[3]),
+            'strength': float(r[4])
+        } for r in relations]
+
+        return entities, relations
+
+    def add_message(self, messages, role, content):
+        return messages + [{'role': role, 'content': content}]
+
+    def light_rag_extraction(self, messages, rank=None):
+        client = get_model(self.model_key, rank=rank)
+
+        final_result = client(messages, **self.sampling_params)
+        history = self.add_message(messages, 'assistant', final_result)
+
+        # Glean up to max_gleaning extra rounds, stopping early once the
+        # model no longer answers YES to the if_loop prompt.
+        for glean_index in range(self.max_gleaning):
+            messages = self.add_message(history, 'user', self.continue_prompt)
+            glean_result = client(messages, **self.sampling_params)
+            history = self.add_message(messages, 'assistant', glean_result)
+            final_result += glean_result
+
+            if glean_index == self.max_gleaning - 1:
+                break
+
+            messages = self.add_message(history, 'user', self.if_loop_prompt)
+            if_loop_result = client(messages, **self.sampling_params)
+            if_loop_result = if_loop_result.strip().strip('"').strip(
+                "'").lower()
+            if if_loop_result != 'yes':
+                break
+
+        return final_result
+
+    def process_single(self, sample=None, rank=None):
+
+        input_prompt = self.prompt_template.format(
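+            # Note: str.format ignores unused keyword arguments, so a custom
+            # prompt_template only needs the placeholders it actually uses.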
+            tuple_delimiter=self.tuple_delimiter,
+            record_delimiter=self.record_delimiter,
+            completion_delimiter=self.completion_delimiter,
+            entity_types=', '.join(self.entity_types),
+            input_text=sample[self.text_key])
+        messages = [{'role': 'user', 'content': input_prompt}]
+
+        entities, relations = [], []
+        for i in range(self.try_num):
+            try:
+                result = self.light_rag_extraction(messages, rank=rank)
+                entities, relations = self.parse_output(result)
+                if len(entities) > 0:
+                    break
+            except Exception as e:
+                logger.warning(f'Exception: {e}')
+
+        sample[self.entity_key] = entities
+        sample[self.relation_key] = relations
+        return sample
diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py
index c70c9ec6e..562dd6800 100644
--- a/data_juicer/ops/mapper/extract_event_mapper.py
+++ b/data_juicer/ops/mapper/extract_event_mapper.py
@@ -63,8 +63,8 @@ def __init__(self,
                  output_pattern: Optional[str] = None,
                  try_num: PositiveInt = 3,
                  drop_text: bool = False,
-                 model_params: Optional[Dict] = {},
-                 sampling_params: Optional[Dict] = {},
+                 model_params: Dict = {},
+                 sampling_params: Dict = {},
                  **kwargs):
         """
         Initialization method.
diff --git a/data_juicer/ops/mapper/extract_keyword_mapper.py b/data_juicer/ops/mapper/extract_keyword_mapper.py
new file mode 100644
index 000000000..ef531be96
--- /dev/null
+++ b/data_juicer/ops/mapper/extract_keyword_mapper.py
@@ -0,0 +1,193 @@
+# flake8: noqa: E501
+
+import re
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
+from data_juicer.utils.constant import Fields
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from ..common import split_text_by_punctuation
+
+OP_NAME = 'extract_keyword_mapper'
+
+
+# TODO: LLM-based inference.
+@UNFORKABLE.register_module(OP_NAME)
+@OPERATORS.register_module(OP_NAME)
+class ExtractKeywordMapper(Mapper):
+    """
+    Generate keywords for the text.
+    """
+
+    # This prompt is modified from LightRAG
+    # https://github.com/HKUDS/LightRAG
+    DEFAULT_PROMPT_TEMPLATE = """-Goal-
+Given a text document that is potentially relevant to this activity, identify the high-level keywords that summarize its main concepts, themes, and topics.
+
+-Steps-
+1. Identify high-level keywords that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
+Format the content-level keywords as ("content_keywords" <high_level_keywords>)
+
+2. Return output in the language of the given text.
+
+3. When finished, output {completion_delimiter}
+
+######################
+-Examples-
+######################
+Example 1:
+
+Text:
+```
+while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
+
+Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
+
+The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
+
+It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
+```
+################
+Output:
+("content_keywords" "power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
+#############################
+Example 2:
+
+Text:
+```
+他们不再是单纯的执行者;他们已成为某个超越星辰与条纹的领域的信息守护者。这一使命的提升不能被规则和既定协议所束缚——它需要一种新的视角,一种新的决心。
+
+随着与华盛顿的通讯在背景中嗡嗡作响,对话中的紧张情绪通过嘟嘟声和静电噪音贯穿始终。团队站立着,一股不祥的气息笼罩着他们。显然,他们在接下来几个小时内做出的决定可能会重新定义人类在宇宙中的位置,或者将他们置于无知和潜在危险之中。
+
+随着与星辰的联系变得更加牢固,小组开始处理逐渐成形的警告,从被动接受者转变为积极参与者。梅瑟后来的直觉占据了上风——团队的任务已经演变,不再仅仅是观察和报告,而是互动和准备。一场蜕变已经开始,而“杜尔塞行动”则以他们大胆的新频率震动,这种基调不是由世俗设定的
+```
+#############
+Output:
+("content_keywords" "任务演变, 决策制定, 积极参与, 宇宙意义"){completion_delimiter}
+#############################
+Example 3:
+
+Text:
+```
+their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
+
+"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning."
+
+Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
+
+Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
+
+The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
+```
+#############
+Output:
+("content_keywords" "first contact, control, communication, cosmic significance"){completion_delimiter}
+#############################
+-Real Data-
+######################
+Text:
+```
+{input_text}
+```
+######################
+Output:
+"""
+
+    DEFAULT_COMPLETION_DELIMITER = '<|COMPLETE|>'
+    DEFAULT_OUTPUT_PATTERN = r'\("content_keywords"(.*?)\)'
+
+    def __init__(self,
+                 api_model: str = 'gpt-4o',
+                 *,
+                 keyword_key: str = Fields.keyword,
+                 api_url: Optional[str] = None,
+                 api_key: Optional[str] = None,
+                 response_path: Optional[str] = None,
+                 prompt_template: Optional[str] = None,
+                 completion_delimiter: Optional[str] = None,
+                 output_pattern: Optional[str] = None,
+                 try_num: PositiveInt = 3,
+                 drop_text: bool = False,
+                 model_params: Dict = {},
+                 sampling_params: Dict = {},
+                 **kwargs):
+        """
+        Initialization method.
+        :param api_model: API model name.
+        :param keyword_key: The field name to store the keywords. It's
+            "__dj__keyword__" in default.
+        :param api_url: API URL. Defaults to DJ_API_URL environment variable.
+        :param api_key: API key. Defaults to DJ_API_KEY environment variable.
+        :param response_path: Path to extract content from the API response.
+            Defaults to 'choices.0.message.content'.
+        :param prompt_template: The template of input prompt.
+        :param completion_delimiter: To mark the end of the output.
+        :param output_pattern: Regular expression for parsing keywords.
+        :param try_num: The number of retry attempts when there is an API
+            call error or output parsing error.
+        :param drop_text: If drop the text in the output.
+        :param model_params: Parameters for initializing the model.
+        :param sampling_params: Extra parameters passed to the API call.
+            e.g {'temperature': 0.9, 'top_p': 0.95}
+        :param kwargs: Extra keyword arguments.
+        """
+        super().__init__(**kwargs)
+
+        self.keyword_key = keyword_key
+
+        self.prompt_template = prompt_template or self.DEFAULT_PROMPT_TEMPLATE
+        self.completion_delimiter = completion_delimiter or \
+            self.DEFAULT_COMPLETION_DELIMITER
+        self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
+
+        self.model_params = model_params
+        self.sampling_params = sampling_params
+        self.model_key = prepare_model(model_type='api',
+                                       api_model=api_model,
+                                       api_url=api_url,
+                                       api_key=api_key,
+                                       response_path=response_path,
+                                       **model_params)
+
+        self.try_num = try_num
+        self.drop_text = drop_text
+
+    def parse_output(self, raw_output):
+        keywords = []
+
+        output_pattern = re.compile(self.output_pattern,
+                                    re.VERBOSE | re.DOTALL)
+        matches = output_pattern.findall(raw_output)
+        for record in matches:
+            items = split_text_by_punctuation(record)
+            # Flatten the keywords from every matched record into one list.
+            keywords.extend(items)
+
+        return keywords
+
+    def process_single(self, sample=None, rank=None):
+        client = get_model(self.model_key, rank=rank)
+
+        input_prompt = self.prompt_template.format(
+            completion_delimiter=self.completion_delimiter,
+            input_text=sample[self.text_key])
+        messages = [{'role': 'user', 'content': input_prompt}]
+
+        keywords = []
+        for i in range(self.try_num):
+            try:
+                result = client(messages, **self.sampling_params)
+                keywords = self.parse_output(result)
+                if len(keywords) > 0:
+                    break
+            except Exception as e:
+                logger.warning(f'Exception: {e}')
+
+        sample[self.keyword_key] = keywords
+        if self.drop_text:
+            sample.pop(self.text_key)
+
+        return sample
diff --git a/data_juicer/ops/mapper/extract_nickname_mapper.py b/data_juicer/ops/mapper/extract_nickname_mapper.py
index eee49a7b6..89adc7bee 100644
--- a/data_juicer/ops/mapper/extract_nickname_mapper.py
+++ b/data_juicer/ops/mapper/extract_nickname_mapper.py
@@ -60,8 +60,8 @@ def __init__(self,
                  output_pattern: Optional[str] = None,
                  try_num: PositiveInt = 3,
                  drop_text: bool = False,
-                 model_params: Optional[Dict] = {},
-                 sampling_params: Optional[Dict] = {},
+                 model_params: Dict = {},
+                 sampling_params: Dict = {},
                  **kwargs):
         """
         Initialization method.
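The hunk below aligns the nickname output with the relation schema produced by extract_entity_relation_mapper, so both ops emit interchangeable records. A minimal sketch of the shared record shape, with values taken from the nickname test case further down (the comments are illustrative):
```
# One relation record as produced after this change; values come from the
# test assertion below, in which 李莲花 addresses 方多病 as 方小宝.
relation = {
    'source_entity': '李莲花',    # the speaker who uses the nickname
    'target_entity': '方多病',    # the person being addressed
    'description': '方小宝',      # the nickname itself
    'keywords': ['nickname'],     # fixed marker for nickname relations
    'strength': None,             # nickname relations carry no strength score
}
```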
@@ -126,10 +126,11 @@ def parse_output(self, raw_output): nickname_relations = list(set(nickname_relations)) nickname_relations = [{ - 'entity1': nr[0], - 'entity2': nr[1], + 'source_entity': nr[0], + 'target_entity': nr[1], 'description': nr[2], - 'relation': 'nickname' + 'keywords': ['nickname'], + 'strength': None } for nr in nickname_relations] return nickname_relations @@ -156,4 +157,7 @@ def process_single(self, sample=None, rank=None): logger.warning(f'Exception: {e}') sample[self.nickname_key] = nickname_relations + if self.drop_text: + sample.pop(self.text_key) + return sample diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py index 5bd336b9b..46a638987 100644 --- a/data_juicer/utils/common_utils.py +++ b/data_juicer/utils/common_utils.py @@ -19,3 +19,11 @@ def stats_to_number(s, reverse=True): return -sys.maxsize else: return sys.maxsize + + +def is_float(s): + try: + float(s) + return True + except ValueError: + return False diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 17cf54749..df6bad68d 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -42,6 +42,12 @@ class Fields(object): attribute_support_text = DEFAULT_PREFIX + 'attribute_support_text__' # # the nickname relationship nickname = DEFAULT_PREFIX + 'nickname__' + # # the entity for knowledge graph + entity = DEFAULT_PREFIX + 'entity__' + # # the relationship for knowledge graph + relation = DEFAULT_PREFIX + 'relation__' + # # the keyword in a text + keyword = DEFAULT_PREFIX + 'keyword__' class StatsKeysMeta(type): diff --git a/tests/ops/mapper/test_extract_entity_relation_mapper.py b/tests/ops/mapper/test_extract_entity_relation_mapper.py new file mode 100644 index 000000000..2aee2389b --- /dev/null +++ b/tests/ops/mapper/test_extract_entity_relation_mapper.py @@ -0,0 +1,86 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.extract_entity_relation_mapper import ExtractEntityRelationMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class ExtractEntityRelationMapperTest(DataJuicerTestCaseBase): + + + def _run_op(self, op): + + raw_text = """△芩婆走到中间,看着众人。 +芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。 +封磬震惊:二子?不是只有一个儿子吗? +芩婆:我和漆木山得知这个噩耗后,到处寻找李家那两个孩子的下落。只可惜等我们找他们时,李家长子李相显已经病死。 +李莲花似回忆起了什么:李相显...... +芩婆:我们只从乞丐堆里带回了年纪尚且未满四岁的李相夷,以及,(看向单孤刀)二个一直护着李相夷,与李相显年纪相仿的小乞丐...... +闪回/ +李相显将李且给他的玉佩塞给单孤刀,恳切托付:我没什么值钱的东西,这个玉佩是我唯一的家当了、送给你,我弟弟、相夷......求你照顾他一阵...... +△李相显还想再说什么已气绝而亡,小相夷唤着哥哥大哭,单孤刀愕然看着手里的玉佩有点不知所措。 +△话刚说完,哐当一声破庙门倒进来,几个其他少年乞丐进来。少年乞丐老大:这地儿不错,诶,你俩,出去! +△单孤刀把小相夷护在身后,抓住靠在墙边的木棍。单孤刀:这儿,是我,和我弟弟的。 +乞丐们要抢李相夷的馒头,小李相夷哭着死死护住自馒头不放。 +乞丐甲野蛮地抢:给我拿来! +小单孤刀:放开他! +△单孤刀用力撞向几个乞丐,救下小李相夷。乞丐甲:小子,活腻了! +△几个乞丐围攻小单孤刀,小单孤刀和众乞丐厮打到一起。突然其中一个乞丐掏出一把生锈的刀就朝单孤刀砍去、一个点燃火把棍戳他。单孤刀侧手一挡,火把棍在他手腕上烫出一道伤口,身后几根棍子打得他痛苦倒地! +/闪回结束 +△单孤刀拿着自己手里的玉佩看着,又看看自己手上的印记,不肯相信。单孤刀:胡说!全都是胡说!这些事我为何不知道?都是你在信口雌黄! +芩婆:那我问你,我们将你带回云隐山之前的事你又记得多少? 
+△单孤刀突然愣住,他意识到那之前的事自己竟都想不起来。 +芩婆:怎么?都想不起来了?(拽起单孤刀手腕,露出他的伤痕)你当日被你师父找到时,手腕上就受了伤,也正因为这处伤,高烧不退,醒来后便忘记了不少从前的事。 +△单孤刀呆住。 +芩婆:而相夷当年不过孩童,尚未到记事的年纪,很多事自然不知道。 +△李莲花得知真相,闭目叹息。 +△封磬震惊地看看单孤刀,又看看李莲花,终于想明白了一切,颓然、懊恼。 +封磬:自萱公主之子下落不明后,这近百年来我们整个家族都一直在不遗余力地寻找萱公主的子嗣后代,直到二十几年前终于让我寻得了线索,知道萱公主的曾孙被漆木山夫妇收为徒,但......我只知道萱公主之孙有一年约十岁的儿子,却不知......原来竟还有一幼子!我......我凭着南胤皇族的玉佩、孩子的年纪和他身上的印记来与主上相认,可没想到......这竟是一个错误!全错了! +△封磬神情复杂地看向李莲花,封磬:你,你才是我的主上...... +△封磬颓然地跪倒下来。 +△李莲花对眼前的一切有些意外、无措。 +笛飞声冷声:怪不得单孤刀的血对业火独毫无作用,李莲花的血才能毁掉这东西。 +△笛飞声不禁冷笑一下。 +""" + samples = [{ + 'text': raw_text, + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + sample = dataset[0] + logger.info(f"entitis: {sample[Fields.entity]}") + logger.info(f"relations: {sample[Fields.relation]}") + + def test_default(self): + # before runing this test, set below environment variables: + # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions + # export DJ_API_KEY=your_key + op = ExtractEntityRelationMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op) + + def test_entity_types(self): + op = ExtractEntityRelationMapper( + api_model='qwen2.5-72b-instruct', + entity_types=['人物', '组织', '地点', '物件', '武器', '武功'], + ) + self._run_op(op) + + def test_max_gleaning(self): + op = ExtractEntityRelationMapper( + api_model='qwen2.5-72b-instruct', + entity_types=['人物', '组织', '地点', '物件', '武器', '武功'], + max_gleaning=5, + ) + self._run_op(op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_extract_keyword_mapper.py b/tests/ops/mapper/test_extract_keyword_mapper.py new file mode 100644 index 000000000..f11df6ed4 --- /dev/null +++ b/tests/ops/mapper/test_extract_keyword_mapper.py @@ -0,0 +1,72 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.extract_keyword_mapper import ExtractKeywordMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class ExtractKeywordMapperTest(DataJuicerTestCaseBase): + + + def _run_op(self, api_model, response_path=None): + + op = ExtractKeywordMapper(api_model=api_model, + response_path=response_path) + + raw_text = """△芩婆走到中间,看着众人。 +芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。 +封磬震惊:二子?不是只有一个儿子吗? +芩婆:我和漆木山得知这个噩耗后,到处寻找李家那两个孩子的下落。只可惜等我们找他们时,李家长子李相显已经病死。 +李莲花似回忆起了什么:李相显...... +芩婆:我们只从乞丐堆里带回了年纪尚且未满四岁的李相夷,以及,(看向单孤刀)二个一直护着李相夷,与李相显年纪相仿的小乞丐...... +闪回/ +李相显将李且给他的玉佩塞给单孤刀,恳切托付:我没什么值钱的东西,这个玉佩是我唯一的家当了、送给你,我弟弟、相夷......求你照顾他一阵...... +△李相显还想再说什么已气绝而亡,小相夷唤着哥哥大哭,单孤刀愕然看着手里的玉佩有点不知所措。 +△话刚说完,哐当一声破庙门倒进来,几个其他少年乞丐进来。少年乞丐老大:这地儿不错,诶,你俩,出去! +△单孤刀把小相夷护在身后,抓住靠在墙边的木棍。单孤刀:这儿,是我,和我弟弟的。 +乞丐们要抢李相夷的馒头,小李相夷哭着死死护住自馒头不放。 +乞丐甲野蛮地抢:给我拿来! +小单孤刀:放开他! +△单孤刀用力撞向几个乞丐,救下小李相夷。乞丐甲:小子,活腻了! +△几个乞丐围攻小单孤刀,小单孤刀和众乞丐厮打到一起。突然其中一个乞丐掏出一把生锈的刀就朝单孤刀砍去、一个点燃火把棍戳他。单孤刀侧手一挡,火把棍在他手腕上烫出一道伤口,身后几根棍子打得他痛苦倒地! +/闪回结束 +△单孤刀拿着自己手里的玉佩看着,又看看自己手上的印记,不肯相信。单孤刀:胡说!全都是胡说!这些事我为何不知道?都是你在信口雌黄! +芩婆:那我问你,我们将你带回云隐山之前的事你又记得多少? 
+△单孤刀突然愣住,他意识到那之前的事自己竟都想不起来。 +芩婆:怎么?都想不起来了?(拽起单孤刀手腕,露出他的伤痕)你当日被你师父找到时,手腕上就受了伤,也正因为这处伤,高烧不退,醒来后便忘记了不少从前的事。 +△单孤刀呆住。 +芩婆:而相夷当年不过孩童,尚未到记事的年纪,很多事自然不知道。 +△李莲花得知真相,闭目叹息。 +△封磬震惊地看看单孤刀,又看看李莲花,终于想明白了一切,颓然、懊恼。 +封磬:自萱公主之子下落不明后,这近百年来我们整个家族都一直在不遗余力地寻找萱公主的子嗣后代,直到二十几年前终于让我寻得了线索,知道萱公主的曾孙被漆木山夫妇收为徒,但......我只知道萱公主之孙有一年约十岁的儿子,却不知......原来竟还有一幼子!我......我凭着南胤皇族的玉佩、孩子的年纪和他身上的印记来与主上相认,可没想到......这竟是一个错误!全错了! +△封磬神情复杂地看向李莲花,封磬:你,你才是我的主上...... +△封磬颓然地跪倒下来。 +△李莲花对眼前的一切有些意外、无措。 +笛飞声冷声:怪不得单孤刀的血对业火独毫无作用,李莲花的血才能毁掉这东西。 +△笛飞声不禁冷笑一下。 +""" + samples = [{ + 'text': raw_text, + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + sample = dataset[0] + logger.info(f"keywords: {sample[Fields.keyword]}") + + def test(self): + # before runing this test, set below environment variables: + # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions + # export DJ_API_KEY=your_key + self._run_op('qwen2.5-72b-instruct') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py index 4f7b44336..780c7e4b2 100644 --- a/tests/ops/mapper/test_extract_nickname_mapper.py +++ b/tests/ops/mapper/test_extract_nickname_mapper.py @@ -39,7 +39,7 @@ def _run_op(self, api_model, response_path=None): dataset = Dataset.from_list(samples) dataset = dataset.map(op.process, batch_size=2) result = dataset[0][Fields.nickname] - result = [(d['entity1'], d['entity2'], d['description']) for d in result] + result = [(d['source_entity'], d['target_entity'], d['description']) for d in result] logger.info(f'result: {result}') self.assertIn(("李莲花","方多病","方小宝"), result) From 0e51a43be8b01a257736e1f8cd914f86b3b1316d Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 11 Nov 2024 17:33:13 +0800 Subject: [PATCH 022/118] doc done --- configs/config_all.yaml | 95 +++++++++++++++++-- .../mapper/extract_entity_attribute_mapper.py | 3 +- .../ops/mapper/extract_event_mapper.py | 2 +- .../ops/mapper/extract_nickname_mapper.py | 6 +- docs/Operators.md | 8 +- docs/Operators_ZH.md | 8 +- .../mapper/test_extract_nickname_mapper.py | 4 +- 7 files changed, 110 insertions(+), 16 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 48c9f172e..d620a505a 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -63,8 +63,8 @@ process: reference_template: null # Template for formatting the reference text. qa_pair_template: null # Template for formatting question-answer pairs. output_pattern: null # Regular expression for parsing model output. - model_params: null # Parameters for initializing the model. - sampling_params: null # Extra parameters passed to the API call. + model_params: {} # Parameters for initializing the model. + sampling_params: {} # Extra parameters passed to the API call. - calibrate_query_mapper: # calibrate query in question-answer pairs based on reference text. - calibrate_response_mapper: # calibrate response in question-answer pairs based on reference text. - chinese_convert_mapper: # convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji. @@ -75,6 +75,81 @@ process: - clean_links_mapper: # remove web links from text. - clean_copyright_mapper: # remove copyright comments. - expand_macro_mapper: # expand macro definitions in Latex text. + - extract_entity_attribute_mapper: # Extract attributes for given entities from the text. + query_entities: ["孙悟空", "猪八戒"] # Entity list to be queried. 
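+      # The entity and attribute lists above and below are example values;
+      # set them to whatever should be extracted from your own corpus.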
+ query_attributes: ["人物性格"] # Attribute list to be queried. + api_model: 'gpt-4o' # API model name. + entity_key: '__dj__entity__' # The field name to store the given main entity for attribute extraction. + entity_attribute_key: '__dj__attribute__' # The field name to store the given attribute to be extracted. + attribute_desc_key: '__dj__attribute_description__' # The field name to store the extracted attribute description. + support_text_key: '__dj__support_text__' # The field name to store the attribute support text extracted from the raw text. + api_url: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt_template: null # System prompt template for the task. Need to be specified by given entity and attribute. + input_template: null # Template for building the model input. + attr_pattern_template: null # Pattern for parsing the attribute from output. Need to be specified by given attribute. + demo_pattern: null # Pattern for parsing the demonstraction from output to support the attribute. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - extract_entity_relation_mapper: # Extract entities and relations in the text for knowledge graph. + api_model: 'gpt-4o' # API model name. + entity_types: ['person', 'organization', 'location'] # Pre-defined entity types for knowledge graph. + entity_key: '__dj__entity__' # The field name to store the entities. + relation_key: '__dj__relation__' # The field name to store the relations between entities. + api_url: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + prompt_template: null # The template of input prompt. + tuple_delimiter: null # Delimiter to separate items in outputs. + record_delimiter: null # Delimiter to separate records in outputs. + completion_delimiter: null # To mark the end of the output. + max_gleaning: 1 # the extra max num to call LLM to glean entities and relations. + continue_prompt: null # the prompt for gleaning entities and relations. + if_loop_prompt: null # the prompt to determine whether to stop gleaning. + entity_pattern: null # Regular expression for parsing entity record. + relation_pattern: null # Regular expression for parsing relation record. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - extract_event_mapper: # Extract events and relavant characters in the text + api_model: 'gpt-4o' # API model name. + event_desc_key: '__dj__event_description__' # The field name to store the event descriptions. + relavant_char_key: '__dj__relavant_characters__' # The field name to store the relavant characters to the events. + api_url: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # System prompt for the task. 
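+      # A null value falls back to the operator's built-in default (or, for
+      # API settings, to the corresponding DJ_* environment variable).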
+ input_template: null # Template for building the model input. + output_pattern: null # Regular expression for parsing model output. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - extract_keyword_mapper: # Generate keywords for the text. + api_model: 'gpt-4o' # API model name. + keyword_key: '__dj__keyword__' # The field name to store the keywords. + api_url: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + prompt_template: null # The template of input prompt. + completion_delimiter: null # To mark the end of the output. + output_pattern: null # Regular expression for parsing keywords. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - extract_nickname_mapper: # Extract nickname relationship in the text. + api_model: 'gpt-4o' # API model name. + nickname_key: '__dj__nickname__' # The field name to store the nickname relationship. + api_url: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # System prompt for the task. + input_template: null # Template for building the model input. + output_pattern: null # Regular expression for parsing model output. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - fix_unicode_mapper: # fix unicode errors in text. - generate_qa_from_examples_mapper: # mapper to generate question and answer pairs from examples. hf_model: 'Qwen/Qwen2.5-7B-Instruct' # Model name on huggingface to generate question and answer pairs. @@ -87,14 +162,14 @@ process: qa_pair_template: null # Template for formatting a single QA pair within each example. output_pattern: null # Regular expression pattern to extract questions and answers from model response. enable_vllm: false # Whether to use vllm for inference acceleration. - model_params: null # Parameters for initializing the model. + model_params: {} # Parameters for initializing the model. sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} - generate_qa_from_text_mapper: # mapper to generate question and answer pairs from text. hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa' # Model name on huggingface to generate question and answer pairs. output_pattern: null # Regular expression pattern to extract questions and answers from model response. enable_vllm: false # Whether to use vllm for inference acceleration. - model_params: null # Parameters for initializing the model. - sampling_params: null # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} + model_params: {} # Parameters for initializing the model. 
+ sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} - image_blur_mapper: # mapper to blur images. p: 0.2 # probability of the image being blured blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] @@ -163,8 +238,8 @@ process: qa_pair_template: null # Template for formatting the question and answer pair. output_pattern: null # Regular expression pattern to extract question and answer from model response. enable_vllm: false # whether to use vllm for inference acceleration. - model_params: null # Parameters for initializing the model. - sampling_params: null # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} + model_params: {} # Parameters for initializing the model. + sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} - optimize_query_mapper: # optimize query in question-answer pairs. - optimize_response_mapper: # optimize response in question-answer pairs. - punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations. @@ -197,6 +272,12 @@ process: substrings: ['http', 'www', '.com', 'href', '//'] # incorrect substrings to remove - sentence_split_mapper: # split text to multiple sentences and join them with '\n' lang: 'en' # split text in what language + - text_chunk_mapper: # Split input text to chunks. + max_len: 2000 # Split text into multi texts with this max len if not None. + split_pattern: '\n\n' # Make sure split in this pattern if it is not None and force cut if the length exceeds max_len. + overlap_len: 200 # Overlap length of the split texts if not split in the split pattern. + tokenizer: 'gpt-4o' # The tokenizer name of Hugging Face tokenizers. The text length will be calculate as the token num if it is offerd. Otherwise, the text length equals to string length. + trust_remote_code: True # for loading huggingface model. - video_captioning_from_audio_mapper: # caption a video according to its audio streams based on Qwen-Audio model keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only captioned sample in the final datasets and the original sample will be removed. It's True in default. mem_required: '30GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py index 5b3ffbae4..fc259a1b6 100644 --- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -77,7 +77,7 @@ def __init__(self, :param api_url: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. - :param system_prompt_template: System prompt for the calibration + :param system_prompt_template: System prompt template for the task. Need to be specified by given entity and attribute. :param input_template: Template for building the model input. :param attr_pattern_template: Pattern for parsing the attribute from @@ -86,6 +86,7 @@ def __init__(self, output to support the attribute. :param try_num: The number of retry attempts when there is an API call error or output parsing error. + :param drop_text: If drop the text in the output. 
:param model_params: Parameters for initializing the API model. :param sampling_params: Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py index b5338d082..37e5aeb1c 100644 --- a/data_juicer/ops/mapper/extract_event_mapper.py +++ b/data_juicer/ops/mapper/extract_event_mapper.py @@ -76,7 +76,7 @@ def __init__(self, :param api_url: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. - :param system_prompt: System prompt for the calibration task. + :param system_prompt: System prompt for the task. :param input_template: Template for building the model input. :param output_pattern: Regular expression for parsing model output. :param try_num: The number of retry attempts when there is an API diff --git a/data_juicer/ops/mapper/extract_nickname_mapper.py b/data_juicer/ops/mapper/extract_nickname_mapper.py index 60e9741e2..265a02441 100644 --- a/data_juicer/ops/mapper/extract_nickname_mapper.py +++ b/data_juicer/ops/mapper/extract_nickname_mapper.py @@ -16,7 +16,7 @@ @OPERATORS.register_module(OP_NAME) class ExtractNicknameMapper(Mapper): """ - Extract nickname relationship in the text + Extract nickname relationship in the text. """ DEFAULT_SYSTEM_PROMPT = ('给定你一段文本,你的任务是将人物之间的称呼方式(昵称)提取出来。\n' @@ -70,7 +70,7 @@ def __init__(self, :param api_url: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. - :param system_prompt: System prompt for the calibration task. + :param system_prompt: System prompt for the task. :param input_template: Template for building the model input. :param output_pattern: Regular expression for parsing model output. :param try_num: The number of retry attempts when there is an API @@ -119,7 +119,7 @@ def parse_output(self, raw_output): continue if role1 and role2 and nickname: nickname_relations.append((role1, role2, nickname)) - nickname_relations = list(set(nickname_relations)) + nickname_relations = list(set(nickname_relations)) nickname_relations = [{ 'source_entity': nr[0], diff --git a/docs/Operators.md b/docs/Operators.md index a4cbba709..9ec31799b 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. 
| Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 52 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 58 | Edits and transforms samples | | [ Filter ]( #filter ) | 43 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -67,6 +67,11 @@ All the specific operators are listed below, each featured with several capabili | clean_ip_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes IP addresses | [code](../data_juicer/ops/mapper/clean_ip_mapper.py) | [tests](../tests/ops/mapper/test_clean_ip_mapper.py) | | clean_links_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes links, such as those starting with http or ftp | [code](../data_juicer/ops/mapper/clean_links_mapper.py) | [tests](../tests/ops/mapper/test_clean_links_mapper.py) | | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Expands macros usually defined at the top of TeX documents | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | +| extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract attributes for given entities from the text. | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | +| extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities and relations in the text for knowledge graph. | [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) | +| extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract events and relavant characters in the text. 
| [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) | +| extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Generate keywords for the text. | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) | +| extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract nickname relationship in the text. | [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) | | fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) | | generate_qa_from_examples_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate question and answer pairs based on examples. | [code](../data_juicer/ops/mapper/generate_qa_from_examples_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_examples_mapper.py) | | generate_qa_from_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate question and answer pairs from text. 
| [code](../data_juicer/ops/mapper/generate_qa_from_text_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_text_mapper.py) | @@ -93,6 +98,7 @@ All the specific operators are listed below, each featured with several capabili | remove_words_with_incorrect_ substrings_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes words containing specified substrings | [code](../data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py) | [tests](../tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py) | | replace_content_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Replace all content in the text that matches a specific regular expression pattern with a designated replacement string | [code](../data_juicer/ops/mapper/replace_content_mapper.py) | [tests](../tests/ops/mapper/test_replace_content_mapper.py) | | sentence_split_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) | Splits and reorganizes sentences according to semantics | [code](../data_juicer/ops/mapper/sentence_split_mapper.py) | [tests](../tests/ops/mapper/test_sentence_split_mapper.py) | +| text_chunk_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Split input text to chunks. | [code](../data_juicer/ops/mapper/text_chunk_mapper.py) | [tests](../tests/ops/mapper/test_text_chunk_mapper.py) | | video_captioning_from_audio_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Caption a video according to its audio streams based on Qwen-Audio model | [code](../data_juicer/ops/mapper/video_captioning_from_audio_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_audio_mapper.py) | | video_captioning_from_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | generate samples whose captions are generated based on an image-to-text model and sampled video frames. Captions from different frames will be concatenated to a single string | [code](../data_juicer/ops/mapper/video_captioning_from_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_frames_mapper.py) | | video_captioning_from_summarizer_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate video captions by summarizing several kinds of generated texts (captions from video/audio/frames, tags from audio/frames, ...) 
| [code](../data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 5dfe64141..304c0c5e6 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 52 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 58 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 43 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -66,6 +66,11 @@ Data-Juicer 中的算子分为以下 5 种类型。 | clean_ip_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 IP 地址 | [code](../data_juicer/ops/mapper/clean_ip_mapper.py) | [tests](../tests/ops/mapper/test_clean_ip_mapper.py) | | clean_links_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除链接,例如以 http 或 ftp 开头的 | [code](../data_juicer/ops/mapper/clean_links_mapper.py) | [tests](../tests/ops/mapper/test_clean_links_mapper.py) | | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 扩展通常在 TeX 文档顶部定义的宏 | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | +| extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 给定主体和属性名,从文本中抽取主体的属性 | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | +| extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从文本中抽取知识图谱的实体和关系 | [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) | +| extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从文本中抽取出事件和事件相关人物 | [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) | +| extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) 
![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 构造文本的关键词 | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) | +| extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取昵称称呼关系 | [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) | | fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 修复损坏的 Unicode(借助 [ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) | | generate_qa_from_examples_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 根据种子数据,生成新的对话样本。 | [code](../data_juicer/ops/mapper/generate_qa_from_examples_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_examples_mapper.py) | | generate_qa_from_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 从文本中生成问答对 | [code](../data_juicer/ops/mapper/generate_qa_from_text_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_text_mapper.py) | @@ -92,6 +97,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | remove_words_with_incorrect_ substrings_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除包含指定子字符串的单词 | [code](../data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py) | [tests](../tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py) | | replace_content_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用一个指定的替换字符串替换文本中满足特定正则表达式模版的所有内容 | [code](../data_juicer/ops/mapper/replace_content_mapper.py) | [tests](../tests/ops/mapper/test_replace_content_mapper.py) | | sentence_split_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) | 根据语义拆分和重组句子 | [code](../data_juicer/ops/mapper/sentence_split_mapper.py) | [tests](../tests/ops/mapper/test_sentence_split_mapper.py) | +| text_chunk_mapper | 
![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 对文本进行分片处理 | [code](../data_juicer/ops/mapper/text_chunk_mapper.py) | [tests](../tests/ops/mapper/test_text_chunk_mapper.py) | | video_captioning_from_audio_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 基于 Qwen-Audio 模型根据视频的音频流为视频生成新的标题描述 | [code](../data_juicer/ops/mapper/video_captioning_from_audio_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_audio_mapper.py) | | video_captioning_from_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 生成样本,其标题是基于一个文字生成图片的模型和原始样本视频中指定帧的图像。不同帧产出的标题会拼接为一条单独的字符串。 | [code](../data_juicer/ops/mapper/video_captioning_from_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_frames_mapper.py) | | video_captioning_from_summarizer_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 通过对多种不同方式生成的文本进行摘要以生成样本的标题(从视频/音频/帧生成标题,从音频/帧生成标签,...) | [code](../data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py) | diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py index 5714b4576..780c7e4b2 100644 --- a/tests/ops/mapper/test_extract_nickname_mapper.py +++ b/tests/ops/mapper/test_extract_nickname_mapper.py @@ -45,8 +45,8 @@ def _run_op(self, api_model, response_path=None): def test(self): # before runing this test, set below environment variables: - # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ - # export OPENAI_API_KEY=your_dashscope_key + # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions + # export DJ_API_KEY=your_key self._run_op('qwen2.5-72b-instruct') From 6d9d8a5a2227a480cfb1827cf75053c250c42527 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 11 Nov 2024 17:45:03 +0800 Subject: [PATCH 023/118] remove extra test --- tests/ops/mapper/test_extract_qa_mapper.py | 53 ------------------- .../test_generate_instruction_mapper.py | 43 --------------- .../test_optimize_instruction_mapper.py | 36 ------------- 3 files changed, 132 deletions(-) delete mode 100644 tests/ops/mapper/test_extract_qa_mapper.py delete mode 100644 tests/ops/mapper/test_generate_instruction_mapper.py delete mode 100644 tests/ops/mapper/test_optimize_instruction_mapper.py diff --git a/tests/ops/mapper/test_extract_qa_mapper.py b/tests/ops/mapper/test_extract_qa_mapper.py deleted file mode 100644 index 415efad4e..000000000 --- a/tests/ops/mapper/test_extract_qa_mapper.py +++ /dev/null @@ -1,53 +0,0 @@ -import unittest -import json -from data_juicer.ops.mapper.extract_qa_mapper import ExtractQAMapper -from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, - DataJuicerTestCaseBase) - -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. -# These tests have been tested locally. 
-@SKIPPED_TESTS.register_module() -class ExtractQAMapperTest(DataJuicerTestCaseBase): - text_key = 'text' - - def _run_extract_qa(self, samples, enable_vllm=False, sampling_params={}, **kwargs): - op = ExtractQAMapper( - hf_model='alibaba-pai/pai-qwen1_5-7b-doc2qa', - trust_remote_code=True, - qa_format='chatml', - enable_vllm=enable_vllm, - sampling_params=sampling_params, - **kwargs - ) - for sample in samples: - result = op.process(sample) - out_text = json.loads(result[self.text_key]) - print(f'Output sample: {out_text}') - - # test one output qa sample - qa_sample = out_text[0] - self.assertIn('role', qa_sample['messages'][0]) - self.assertIn('content', qa_sample['messages'][0]) - - def test_extract_qa(self): - samples = [ - { - self.text_key: '蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n' - }] - self._run_extract_qa(samples) - - def test_extract_qa_vllm(self): - samples = [ - { - self.text_key: '蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n' - }] - self._run_extract_qa( - samples, - enable_vllm=True, - max_model_len=1024, - max_num_seqs=16, - sampling_params={'temperature': 0.9, 'top_p': 0.95, 'max_tokens': 256}) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/ops/mapper/test_generate_instruction_mapper.py b/tests/ops/mapper/test_generate_instruction_mapper.py deleted file mode 100644 index a250fbcc4..000000000 --- a/tests/ops/mapper/test_generate_instruction_mapper.py +++ /dev/null @@ -1,43 +0,0 @@ -import unittest -import json -from data_juicer.ops.mapper.generate_instruction_mapper import GenerateInstructionMapper -from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, - DataJuicerTestCaseBase) - -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. -# These tests have been tested locally. -@SKIPPED_TESTS.register_module() -class GenerateInstructionMapperTest(DataJuicerTestCaseBase): - - text_key = 'text' - - def _run_generate_instruction(self, enable_vllm=False): - op = GenerateInstructionMapper( - hf_model='Qwen/Qwen-7B-Chat', - seed_file='demos/data/demo-dataset-chatml.jsonl', - instruct_num=2, - trust_remote_code=True, - enable_vllm=enable_vllm - ) - - from data_juicer.format.empty_formatter import EmptyFormatter - dataset = EmptyFormatter(3, [self.text_key]).load_dataset() - - dataset = dataset.map(op.process) - - for item in dataset: - out_sample = json.loads(item[self.text_key]) - print(f'Output sample: {out_sample}') - # test one output qa sample - self.assertIn('role', out_sample['messages'][0]) - self.assertIn('content', out_sample['messages'][0]) - - def test_generate_instruction(self): - self._run_generate_instruction() - - def test_generate_instruction_vllm(self): - self._run_generate_instruction(enable_vllm=True) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/ops/mapper/test_optimize_instruction_mapper.py b/tests/ops/mapper/test_optimize_instruction_mapper.py deleted file mode 100644 index 4b3e4562b..000000000 --- a/tests/ops/mapper/test_optimize_instruction_mapper.py +++ /dev/null @@ -1,36 +0,0 @@ -import unittest -from data_juicer.ops.mapper.optimize_instruction_mapper import OptimizeInstructionMapper -from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, - DataJuicerTestCaseBase) - -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. -# These tests have been tested locally. 
-@SKIPPED_TESTS.register_module() -class OptimizeInstructionMapperTest(DataJuicerTestCaseBase): - - text_key = 'text' - - def _run_optimize_instruction(self, enable_vllm=False): - op = OptimizeInstructionMapper( - hf_model='alibaba-pai/Qwen2-7B-Instruct-Refine', - enable_vllm=enable_vllm - ) - - samples = [ - {self.text_key: '鱼香肉丝怎么做?'} - ] - - for sample in samples: - result = op.process(sample) - print(f'Output results: {result}') - self.assertIn(self.text_key, result) - - def test_optimize_instruction(self): - self._run_optimize_instruction() - - def test_optimize_instruction_vllm(self): - self._run_optimize_instruction(enable_vllm=True) - - -if __name__ == '__main__': - unittest.main() From a637a64060e9071df0e293cf3aabcf5745ea0643 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 11 Nov 2024 17:59:31 +0800 Subject: [PATCH 024/118] relavant -> relevant --- configs/config_all.yaml | 4 ++-- data_juicer/ops/mapper/extract_event_mapper.py | 12 ++++++------ data_juicer/utils/constant.py | 4 ++-- docs/Operators.md | 2 +- tests/ops/mapper/test_extract_event_mapper.py | 4 ++-- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index d620a505a..6bd4389b3 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -113,10 +113,10 @@ process: drop_text: false # If drop the text in the output. model_params: {} # Parameters for initializing the API model. sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - - extract_event_mapper: # Extract events and relavant characters in the text + - extract_event_mapper: # Extract events and relevant characters in the text api_model: 'gpt-4o' # API model name. event_desc_key: '__dj__event_description__' # The field name to store the event descriptions. - relavant_char_key: '__dj__relavant_characters__' # The field name to store the relavant characters to the events. + relevant_char_key: '__dj__relevant_characters__' # The field name to store the relevant characters to the events. api_url: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. system_prompt: null # System prompt for the task. diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py index 37e5aeb1c..1efc5915c 100644 --- a/data_juicer/ops/mapper/extract_event_mapper.py +++ b/data_juicer/ops/mapper/extract_event_mapper.py @@ -19,7 +19,7 @@ @OPERATORS.register_module(OP_NAME) class ExtractEventMapper(Mapper): """ - Extract events and relavant characters in the text + Extract events and relevant characters in the text """ _batched_op = True @@ -54,7 +54,7 @@ def __init__(self, api_model: str = 'gpt-4o', *, event_desc_key: str = Fields.event_description, - relavant_char_key: str = Fields.relavant_characters, + relevant_char_key: str = Fields.relevant_characters, api_url: Optional[str] = None, response_path: Optional[str] = None, system_prompt: Optional[str] = None, @@ -70,8 +70,8 @@ def __init__(self, :param api_model: API model name. :param event_desc_key: The field name to store the event descriptions. It's "__dj__event_description__" in default. - :param relavant_char_key: The field name to store the relavant - characters to the events. It's "__dj__relavant_characters__" in + :param relevant_char_key: The field name to store the relevant + characters to the events. It's "__dj__relevant_characters__" in default. 
:param api_url: URL endpoint for the API. :param response_path: Path to extract content from the API response. @@ -90,7 +90,7 @@ def __init__(self, super().__init__(**kwargs) self.event_desc_key = event_desc_key - self.relavant_char_key = relavant_char_key + self.relevant_char_key = relevant_char_key self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE @@ -162,7 +162,7 @@ def process_batched(self, samples): samples[key] = [[samples[key][i]] * len(events[i]) for i in range(sample_num)] samples[self.event_desc_key] = events - samples[self.relavant_char_key] = characters + samples[self.relevant_char_key] = characters for key in samples: samples[key] = list(chain(*samples[key])) diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index df6bad68d..78c373d45 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -30,8 +30,8 @@ class Fields(object): # field names for info extraction event_description = DEFAULT_PREFIX + 'event_description__' - # # a list of characters relavant to the event - relavant_characters = DEFAULT_PREFIX + 'relavant_characters__' + # # a list of characters relevant to the event + relevant_characters = DEFAULT_PREFIX + 'relevant_characters__' # # the given main entity for attribute extraction main_entity = DEFAULT_PREFIX + 'main_entity__' # # the given attribute to be extracted diff --git a/docs/Operators.md b/docs/Operators.md index 9ec31799b..e24744848 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -69,7 +69,7 @@ All the specific operators are listed below, each featured with several capabili | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Expands macros usually defined at the top of TeX documents | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | | extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract attributes for given entities from the text. | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | | extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities and relations in the text for knowledge graph. | [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) | -| extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract events and relavant characters in the text. 
| [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) | +| extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract events and relevant characters in the text. | [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) | | extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Generate keywords for the text. | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) | | extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract nickname relationship in the text. | [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) | | fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) | diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py index c690d104c..1652c8db2 100644 --- a/tests/ops/mapper/test_extract_event_mapper.py +++ b/tests/ops/mapper/test_extract_event_mapper.py @@ -62,8 +62,8 @@ def _run_op(self, api_model, response_path=None): for sample in dataset: logger.info(f"event: {sample[Fields.event_description]}") self.assertNotEqual(sample[Fields.event_description], '') - logger.info(f"characters: {sample[Fields.relavant_characters]}") - self.assertNotEqual(sample[Fields.relavant_characters], []) + logger.info(f"characters: {sample[Fields.relevant_characters]}") + self.assertNotEqual(sample[Fields.relevant_characters], []) def test(self): # before runing this test, set below environment variables: From 56e7988d0013e747f6825b1baa7088a19736890e Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 11 Nov 2024 19:32:02 +0800 Subject: [PATCH 025/118] fix minor error --- data_juicer/ops/mapper/extract_entity_attribute_mapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py index fc259a1b6..6fd3d2b42 100644 --- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -40,8 +40,8 @@ class ExtractEntityAttributeMapper(Mapper): DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例(\d+):\s*(.*?)(?=\#\#\#|\Z)' def __init__(self, - query_entities: List[str], - query_attributes: List[str], + query_entities: 
List[str] = [],
+                 query_attributes: List[str] = [],
                  api_model: str = 'gpt-4o',
                  *,
                  entity_key: str = Fields.main_entity,

From 03880b764060e9071df0e293cf3aabcf5745ea0643 Mon Sep 17 00:00:00 2001
From: Haibin <1400012807@pku.edu.cn>
Date: Tue, 12 Nov 2024 16:43:38 +0800
Subject: [PATCH 026/118] group by op done

---
 configs/config_all.yaml                       |  10 +
 data_juicer/ops/__init__.py                   |   5 +-
 data_juicer/ops/base_op.py                    |  61 +++++-
 data_juicer/ops/filter/__init__.py            |   8 +-
 .../ops/filter/video_motion_score_filter.py   |  67 +++----
 .../filter/video_motion_score_raft_filter.py  |  80 ++++++++
 data_juicer/ops/grouper/__init__.py           |   4 +
 data_juicer/ops/grouper/key_value_grouper.py  |  51 +++++
 data_juicer/ops/grouper/naive_grouper.py      |  28 +++
 data_juicer/utils/common_utils.py             |  33 ++++
 data_juicer/utils/file_utils.py               |  18 +-
 data_juicer/utils/mm_utils.py                 |  54 ++++-
 docs/Operators.md                             |   3 +-
 docs/Operators_ZH.md                          |   3 +-
 .../filter/test_video_motion_score_filter.py  |   6 +-
 .../test_video_motion_score_raft_filter.py    | 186 ++++++++++++++++++
 tests/ops/grouper/__init__.py                 |   0
 tests/ops/grouper/test_key_value_grouper.py   |  54 +++++
 tests/ops/grouper/test_naive_grouper.py       |  47 +++++
 19 files changed, 653 insertions(+), 65 deletions(-)
 create mode 100644 data_juicer/ops/filter/video_motion_score_raft_filter.py
 create mode 100644 data_juicer/ops/grouper/__init__.py
 create mode 100644 data_juicer/ops/grouper/key_value_grouper.py
 create mode 100644 data_juicer/ops/grouper/naive_grouper.py
 create mode 100644 tests/ops/filter/test_video_motion_score_raft_filter.py
 create mode 100644 tests/ops/grouper/__init__.py
 create mode 100644 tests/ops/grouper/test_key_value_grouper.py
 create mode 100644 tests/ops/grouper/test_naive_grouper.py

diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 48c9f172e..d9951ca02 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -465,6 +465,16 @@ process:
     sampling_fps: 2 # the samplig rate of frames_per_second to compute optical flow
     size: null # resize frames along the smaller edge before computing optical flow, or a sequence like (h, w)
     max_size: null # maximum allowed for the longer edge of resized frames
+    divisible: 1 # The number that the dimensions must be divisible by.
+    relative: false # whether to normalize the optical flow magnitude to [0, 1], relative to the frame's diagonal length
+    any_or_all: any # keep this sample when any/all videos meet the filter condition
+  - video_motion_score_raft_filter: # Keep samples with video motion scores (based on RAFT model) within a specific range.
+    min_score: 1.0 # the minimum motion score to keep samples
+    max_score: 10000.0 # the maximum motion score to keep samples
+    sampling_fps: 2 # the sampling rate of frames_per_second to compute optical flow
+    size: null # resize frames along the smaller edge before computing optical flow, or a sequence like (h, w)
+    max_size: null # maximum allowed for the longer edge of resized frames
+    divisible: 8 # The number that the dimensions must be divisible by.
     relative: false # whether to normalize the optical flow magnitude to [0, 1], relative to the frame's diagonal length
     any_or_all: any # keep this sample when any/all videos meet the filter condition
   - video_nsfw_filter: # filter samples according to the nsfw scores of videos in them
diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py
index c7ab44c25..38116efd5 100644
--- a/data_juicer/ops/__init__.py
+++ b/data_juicer/ops/__init__.py
@@ -1,6 +1,6 @@
-from . import deduplicator, filter, mapper, selector
-from .base_op import (OPERATORS, UNFORKABLE, Deduplicator, Filter, Mapper,
-                      Selector)
+from .base_op import (OPERATORS, UNFORKABLE, Deduplicator, Filter, Grouper,
+                      Mapper, Selector)
 from .load import load_ops
 
 __all__ = [
@@ -9,4 +9,5 @@
     'Mapper',
     'Deduplicator',
     'Selector',
+    'Grouper',
 ]
diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index 13f3b61ae..48ecf4947 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -7,6 +7,7 @@
 from loguru import logger
 
 from data_juicer import is_cuda_available
+from data_juicer.core.data import NestedDataset
 from data_juicer.utils.constant import Fields
 from data_juicer.utils.mm_utils import size_to_bytes
 from data_juicer.utils.process_utils import calculate_np
@@ -128,6 +129,10 @@ def __init__(self, *args, **kwargs):
             to be processed
         :param video_key: the key name of field that stores sample video list
             to be processed
+        :param query_key: the key name of field that stores sample queries
+        :param response_key: the key name of field that stores responses
+        :param history_key: the key name of field that stores history of
+            queries and responses
         """
         # init data keys
         self.text_key = kwargs.get('text_key', 'text')
@@ -211,7 +216,6 @@ def add_parameters(self, init_parameter_dict, **extra_param_dict):
         return related_parameters
 
     def run(self, dataset):
-        from data_juicer.core.data import NestedDataset
         if not isinstance(dataset, NestedDataset):
             dataset = NestedDataset(dataset)
         return dataset
@@ -234,6 +238,10 @@ def __init__(self, *args, **kwargs):
             to be processed
         :param video_key: the key name of field that stores sample video list
             to be processed
+        :param query_key: the key name of field that stores sample queries
+        :param response_key: the key name of field that stores responses
+        :param history_key: the key name of field that stores history of
+            queries and responses
         """
         super(Mapper, self).__init__(*args, **kwargs)
 
@@ -303,6 +311,10 @@ def __init__(self, *args, **kwargs):
             to be processed
         :param video_key: the key name of field that stores sample video list
             to be processed
+        :param query_key: the key name of field that stores sample queries
+        :param response_key: the key name of field that stores responses
+        :param history_key: the key name of field that stores history of
+            queries and responses
         """
         super(Filter, self).__init__(*args, **kwargs)
         self.stats_export_path = kwargs.get('stats_export_path', None)
@@ -410,6 +422,10 @@ def __init__(self, *args, **kwargs):
             to be processed
         :param video_key: the key name of field that stores sample video list
             to be processed
+        :param query_key: the key name of field that stores sample queries
+        :param response_key: the key name of field that stores responses
+        :param history_key: the key name of field that stores history of
+            queries and responses
         """
         super(Deduplicator, self).__init__(*args, **kwargs)
 
@@ -469,6 +485,10 @@ def __init__(self, *args, **kwargs):
             to be processed
         :param video_key: the key name of field that stores sample video list
             to be processed
+        :param query_key: the key name of field that stores sample queries
+        :param response_key: the key name of field that stores responses
+        :param history_key: the key name of field that stores history of
+            queries and responses
         """
         super(Selector, self).__init__(*args, **kwargs)
 
@@ -487,3 +507,42 @@ def run(self, dataset, *, exporter=None, tracer=None):
         if tracer:
             tracer.trace_filter(self._name, dataset, new_dataset)
         return new_dataset
+
+
+class Grouper(OP):
+
+    def __init__(self, *args, **kwargs):
+        """
+        Base class that groups samples.
+
+        :param text_key: the key name of field that stores sample texts
+            to be processed
+        :param image_key: the key name of field that stores sample image list
+            to be processed
+        :param audio_key: the key name of field that stores sample audio list
+            to be processed
+        :param video_key: the key name of field that stores sample video list
+            to be processed
+        :param query_key: the key name of field that stores sample queries
+        :param response_key: the key name of field that stores responses
+        :param history_key: the key name of field that stores history of
+            queries and responses
+        """
+        super(Grouper, self).__init__(*args, **kwargs)
+
+    def process(self, dataset):
+        """
+        Dataset --> dataset.
+
+        :param dataset: input dataset
+        :return: dataset of batched samples.
+        """
+        raise NotImplementedError
+
+    def run(self, dataset, *, exporter=None, tracer=None):
+        dataset = super(Grouper, self).run(dataset)
+        batched_samples = self.process(dataset)
+        new_dataset = NestedDataset.from_list(batched_samples)
+        if tracer:
+            tracer.trace_filter(self._name, dataset, new_dataset)
+        return new_dataset
diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py
index 718f06cd3..dad6818e1 100644
--- a/data_juicer/ops/filter/__init__.py
+++ b/data_juicer/ops/filter/__init__.py
@@ -35,6 +35,7 @@
 from .video_frames_text_similarity_filter import \
     VideoFramesTextSimilarityFilter
 from .video_motion_score_filter import VideoMotionScoreFilter
+from .video_motion_score_raft_filter import VideoMotionScoreRaftFilter
 from .video_nsfw_filter import VideoNSFWFilter
 from .video_ocr_area_ratio_filter import VideoOcrAreaRatioFilter
 from .video_resolution_filter import VideoResolutionFilter
@@ -57,7 +58,8 @@
     'TextActionFilter', 'TextEntityDependencyFilter', 'TextLengthFilter',
     'TokenNumFilter', 'VideoAestheticsFilter', 'VideoAspectRatioFilter',
     'VideoDurationFilter', 'VideoFramesTextSimilarityFilter',
-    'VideoMotionScoreFilter', 'VideoNSFWFilter', 'VideoOcrAreaRatioFilter',
-    'VideoResolutionFilter', 'VideoTaggingFromFramesFilter',
-    'VideoWatermarkFilter', 'WordRepetitionFilter', 'WordsNumFilter'
+    'VideoMotionScoreFilter', 'VideoMotionScoreRaftFilter', 'VideoNSFWFilter',
+    'VideoOcrAreaRatioFilter', 'VideoResolutionFilter',
+    'VideoTaggingFromFramesFilter', 'VideoWatermarkFilter',
+    'WordRepetitionFilter', 'WordsNumFilter'
 ]
diff --git a/data_juicer/ops/filter/video_motion_score_filter.py b/data_juicer/ops/filter/video_motion_score_filter.py
index 43dd38a01..dddd3a1a1 100644
--- a/data_juicer/ops/filter/video_motion_score_filter.py
+++ b/data_juicer/ops/filter/video_motion_score_filter.py
@@ -7,6 +7,7 @@
 
 from data_juicer.utils.constant import Fields, StatsKeys
 from data_juicer.utils.lazy_loader import LazyLoader
+from data_juicer.utils.mm_utils import calculate_resized_dimensions
 
 from ..base_op import OPERATORS, UNFORKABLE, Filter
 
@@ -48,6 +49,7 @@ def __init__(self,
                  size: Union[PositiveInt, Tuple[PositiveInt],
                              Tuple[PositiveInt, PositiveInt], None] = None,
                  max_size: Optional[PositiveInt] = None,
+                 divisible: PositiveInt = 1,
                  relative: bool = False,
                  any_or_all: str = 'any',
                  *args,
@@ -69,6 +71,7 @@ def __init__(self,
             being resized according to size, size will be overruled so that the
             longer edge is equal to max_size. As a result, the smaller edge may
             be shorter than size. This is only supported if size is an int.
+        :param divisible: The number that the dimensions must be divisible by.
:param relative: If `True`, the optical flow magnitude is normalized to a [0, 1] range, relative to the frame's diagonal length. :param any_or_all: keep this sample with 'any' or 'all' strategy of @@ -92,6 +95,7 @@ def __init__(self, size = (size, ) self.size = size self.max_size = max_size + self.divisible = divisible self.relative = relative self.extra_kwargs = self._default_kwargs @@ -104,7 +108,21 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats_single(self, sample, context=False): + def setup_model(self, rank=None): + self.model = cv2.calcOpticalFlowFarneback + + def compute_flow(self, prev_frame, curr_frame): + curr_frame = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY) + if prev_frame is None: + flow = None + else: + flow = self.model(prev_frame, curr_frame, None, + **self.extra_kwargs) + return flow, curr_frame + + def compute_stats_single(self, sample, rank=None, context=False): + self.rank = rank + # check if it's computed already if StatsKeys.video_motion_score in sample[Fields.stats]: return sample @@ -115,6 +133,8 @@ def compute_stats_single(self, sample, context=False): [], dtype=np.float64) return sample + self.setup_model(rank) + # load videos loaded_video_keys = sample[self.video_key] unique_motion_scores = {} @@ -133,6 +153,11 @@ def compute_stats_single(self, sample, context=False): # at least two frames for computing optical flow sampling_step = max(min(sampling_step, total_frames - 1), 1) + height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) + width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) + new_size = calculate_resized_dimensions( + (height, width), self.size, self.max_size, + self.divisible) prev_frame = None frame_count = 0 @@ -143,27 +168,21 @@ def compute_stats_single(self, sample, context=False): # a corrupt frame or reaching the end of the video. 
break - height, width, _ = frame.shape - new_size = _compute_resized_output_size( - (height, width), self.size, self.max_size) if new_size != (height, width): frame = cv2.resize(frame, new_size, interpolation=cv2.INTER_AREA) - gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) - if prev_frame is None: - prev_frame = gray_frame + # return flow of shape (H, W, 2) and transformed frame + # of shape (H, W, 3) in BGR mode + flow, prev_frame = self.compute_flow(prev_frame, frame) + if flow is None: continue - - flow = cv2.calcOpticalFlowFarneback( - prev_frame, gray_frame, None, **self.extra_kwargs) mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) frame_motion_score = np.mean(mag) if self.relative: - frame_motion_score /= np.hypot(*flow.shape[:2]) + frame_motion_score /= np.hypot(*frame.shape[:2]) video_motion_scores.append(frame_motion_score) - prev_frame = gray_frame # quickly skip frames frame_count += sampling_step @@ -197,27 +216,3 @@ def process_single(self, sample): return keep_bools.any() else: return keep_bools.all() - - -def _compute_resized_output_size( - frame_size: Tuple[int, int], - size: Union[Tuple[PositiveInt], Tuple[PositiveInt, PositiveInt]], - max_size: Optional[int] = None, -) -> Tuple[int, int]: - h, w = frame_size - short, long = (w, h) if w <= h else (h, w) - - if size is None: # no change - new_short, new_long = short, long - elif len(size) == 1: # specified size only for the smallest edge - new_short = size[0] - new_long = int(new_short * long / short) - else: # specified both h and w - new_short, new_long = min(size), max(size) - - if max_size is not None and new_long > max_size: - new_short = int(max_size * new_short / new_long) - new_long = max_size - - new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) - return new_h, new_w diff --git a/data_juicer/ops/filter/video_motion_score_raft_filter.py b/data_juicer/ops/filter/video_motion_score_raft_filter.py new file mode 100644 index 000000000..0be2199c2 --- /dev/null +++ b/data_juicer/ops/filter/video_motion_score_raft_filter.py @@ -0,0 +1,80 @@ +import sys +from typing import Optional, Tuple, Union + +from pydantic import PositiveFloat, PositiveInt + +from data_juicer import cuda_device_count +from data_juicer.ops.filter.video_motion_score_filter import \ + VideoMotionScoreFilter +from data_juicer.utils.lazy_loader import LazyLoader + +from ..base_op import OPERATORS, UNFORKABLE + +torch = LazyLoader('torch', 'torch') +tvm = LazyLoader('tvm', 'torchvision.models') +tvt = LazyLoader('tvt', 'torchvision.transforms') + +OP_NAME = 'video_motion_score_raft_filter' + + +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class VideoMotionScoreRaftFilter(VideoMotionScoreFilter): + """Filter to keep samples with video motion scores within a specified range. + This operator utilizes the RAFT (Recurrent All-Pairs Field Transforms) + model from torchvision to predict optical flow between video frames. 
+
+    For further details, refer to the official torchvision documentation:
+    https://pytorch.org/vision/main/models/raft.html
+
+    The original paper on RAFT is available here:
+    https://arxiv.org/abs/2003.12039
+    """
+
+    _accelerator = 'cuda'
+    _default_kwargs = {}
+
+    def __init__(self,
+                 min_score: float = 1.0,
+                 max_score: float = sys.float_info.max,
+                 sampling_fps: PositiveFloat = 2,
+                 size: Union[PositiveInt, Tuple[PositiveInt],
+                             Tuple[PositiveInt, PositiveInt], None] = None,
+                 max_size: Optional[PositiveInt] = None,
+                 divisible: PositiveInt = 8,
+                 relative: bool = False,
+                 any_or_all: str = 'any',
+                 *args,
+                 **kwargs):
+        super().__init__(min_score, max_score, sampling_fps, size, max_size,
+                         divisible, relative, any_or_all, *args, **kwargs)
+
+    def setup_model(self, rank=None):
+        self.model = tvm.optical_flow.raft_large(
+            weights=tvm.optical_flow.Raft_Large_Weights.DEFAULT,
+            progress=False)
+        if self.use_cuda():
+            rank = rank if rank is not None else 0
+            rank = rank % cuda_device_count()
+            self.device = f'cuda:{rank}'
+        else:
+            self.device = 'cpu'
+        self.model.to(self.device)
+        self.model.eval()
+
+        self.transforms = tvt.Compose([
+            tvt.ToTensor(),
+            tvt.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
+            tvt.Lambda(lambda img: img.flip(-3).unsqueeze(0)),  # BGR to RGB
+        ])
+
+    def compute_flow(self, prev_frame, curr_frame):
+        curr_frame = self.transforms(curr_frame).to(self.device)
+        if prev_frame is None:
+            flow = None
+        else:
+            with torch.inference_mode():
+                flows = self.model(prev_frame, curr_frame)
+                flow = flows[-1][0].cpu().numpy().transpose(
+                    (1, 2, 0))  # 2, H, W -> H, W, 2
+        return flow, curr_frame
diff --git a/data_juicer/ops/grouper/__init__.py b/data_juicer/ops/grouper/__init__.py
new file mode 100644
index 000000000..048b305e4
--- /dev/null
+++ b/data_juicer/ops/grouper/__init__.py
@@ -0,0 +1,4 @@
+from .key_value_grouper import KeyValueGrouper
+from .naive_grouper import NaiveGrouper
+
+__all__ = ['NaiveGrouper', 'KeyValueGrouper']
diff --git a/data_juicer/ops/grouper/key_value_grouper.py b/data_juicer/ops/grouper/key_value_grouper.py
new file mode 100644
index 000000000..15b8d0328
--- /dev/null
+++ b/data_juicer/ops/grouper/key_value_grouper.py
@@ -0,0 +1,51 @@
+from typing import List, Optional
+
+from data_juicer.utils.common_utils import dict_to_hash, get_val_by_nested_key
+
+from ..base_op import OPERATORS, Grouper
+from .naive_grouper import NaiveGrouper
+
+
+@OPERATORS.register_module('key_value_grouper')
+class KeyValueGrouper(Grouper):
+    """Group samples into batched samples according to values in given keys. """
+
+    def __init__(self,
+                 group_by_keys: Optional[List[str]] = None,
+                 *args,
+                 **kwargs):
+        """
+        Initialization method.
+
+        :param group_by_keys: group samples according to values in the keys.
+            Support for nested keys such as "__dj__stats__.text_len".
+            It is [self.text_key] in default.
+ :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + self.group_by_keys = group_by_keys or [self.text_key] + self.naive_grouper = NaiveGrouper() + + def process(self, dataset): + + if len(dataset) == 0: + return dataset + + sample_map = {} + for sample in dataset: + cur_dict = {} + for key in self.group_by_keys: + cur_dict[key] = get_val_by_nested_key(sample, key) + sample_key = dict_to_hash(cur_dict) + if sample_key in sample_map: + sample_map[sample_key].append(sample) + else: + sample_map[sample_key] = [sample] + + batched_samples = [ + self.naive_grouper.process(sample_map[k])[0] for k in sample_map + ] + + return batched_samples diff --git a/data_juicer/ops/grouper/naive_grouper.py b/data_juicer/ops/grouper/naive_grouper.py new file mode 100644 index 000000000..92da22875 --- /dev/null +++ b/data_juicer/ops/grouper/naive_grouper.py @@ -0,0 +1,28 @@ +from ..base_op import OPERATORS, Grouper + + +@OPERATORS.register_module('naive_grouper') +class NaiveGrouper(Grouper): + """Group all samples to one batched sample. """ + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + def process(self, dataset): + + if len(dataset) == 0: + return dataset + + keys = dataset[0].keys() + batched_sample = {k: [None] * len(dataset) for k in keys} + for i, sample in enumerate(dataset): + for k in keys: + batched_sample[k][i] = sample[k] + + return [batched_sample] diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py index 5bd336b9b..0da2cb017 100644 --- a/data_juicer/utils/common_utils.py +++ b/data_juicer/utils/common_utils.py @@ -1,6 +1,8 @@ +import hashlib import sys import numpy as np +from loguru import logger def stats_to_number(s, reverse=True): @@ -19,3 +21,34 @@ def stats_to_number(s, reverse=True): return -sys.maxsize else: return sys.maxsize + + +def dict_to_hash(input_dict: dict, hash_length=None): + """ + hash a dict to a string with length hash_length + + :param input_dict: the given dict + """ + sorted_items = sorted(input_dict.items()) + dict_string = str(sorted_items).encode() + hasher = hashlib.sha256() + hasher.update(dict_string) + hash_value = hasher.hexdigest() + if hash_length: + hash_value = hash_value[:hash_length] + return hash_value + + +def get_val_by_nested_key(input_dict: dict, nested_key: str): + """ + return val of the dict in the nested key. 
+
+    :param nested_key: the nested key, such as "__dj__stats__.text_len"
+    """
+    keys = nested_key.split('.')
+    cur = input_dict
+    for key in keys:
+        if key not in cur:
+            logger.warning(f'Unreachable nested key: {nested_key}!')
+            return None
+        cur = cur[key]
+    return cur
diff --git a/data_juicer/utils/file_utils.py b/data_juicer/utils/file_utils.py
index e2fc241cd..7a8618660 100644
--- a/data_juicer/utils/file_utils.py
+++ b/data_juicer/utils/file_utils.py
@@ -1,6 +1,5 @@
 import asyncio
 import copy
-import hashlib
 import os
 import re
 import shutil
@@ -10,6 +9,7 @@
 
 from datasets.utils.extract import ZstdExtractor as Extractor
 
+from data_juicer.utils.common_utils import dict_to_hash
 from data_juicer.utils.constant import DEFAULT_PREFIX, Fields
 
 
@@ -127,22 +127,6 @@ def add_suffix_to_filename(filename, suffix):
     return new_name
 
 
-def dict_to_hash(input_dict, hash_length=None):
-    """
-    hash a dict to a string with length hash_length
-
-    :param input_dict: the given dict
-    """
-    sorted_items = sorted(input_dict.items())
-    dict_string = str(sorted_items).encode()
-    hasher = hashlib.sha256()
-    hasher.update(dict_string)
-    hash_value = hasher.hexdigest()
-    if hash_length:
-        hash_value = hash_value[:hash_length]
-    return hash_value
-
-
 def create_directory_if_not_exists(directory_path):
     """
     create a directory if not exists, this function is process safe
diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py
index 7aa2df833..a1c09668d 100644
--- a/data_juicer/utils/mm_utils.py
+++ b/data_juicer/utils/mm_utils.py
@@ -3,7 +3,7 @@
 import os
 import re
 import shutil
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import av
 import numpy as np
@@ -164,6 +164,58 @@ def iou(box1, box2):
     return 1.0 * intersection / union
 
 
+def calculate_resized_dimensions(
+        original_size: Tuple[PositiveInt, PositiveInt],
+        target_size: Union[PositiveInt, Tuple[PositiveInt, PositiveInt]],
+        max_length: Optional[int] = None,
+        divisible: PositiveInt = 1) -> Tuple[int, int]:
+    """
+    Resize dimensions based on specified constraints.
+
+    :param original_size: The original dimensions as (height, width).
+    :param target_size: Desired target size; can be a single integer
+        (short edge) or a tuple (height, width).
+    :param max_length: Maximum allowed length for the longer edge.
+    :param divisible: The number that the dimensions must be divisible by.
+    :return: Resized dimensions as (height, width).
+ """ + + height, width = original_size + short_edge, long_edge = sorted((width, height)) + + # Normalize target_size to a tuple + if isinstance(target_size, int): + target_size = (target_size, ) + + # Initialize new dimensions + if target_size: + if len(target_size) == 1: # Only the smaller edge is specified + new_short_edge = target_size[0] + new_long_edge = int(new_short_edge * long_edge / short_edge) + else: # Both dimensions are specified + new_short_edge = min(target_size) + new_long_edge = max(target_size) + else: # No change + new_short_edge, new_long_edge = short_edge, long_edge + + # Enforce maximum length constraint + if max_length is not None and new_long_edge > max_length: + scaling_factor = max_length / new_long_edge + new_short_edge = int(new_short_edge * scaling_factor) + new_long_edge = max_length + + # Determine final dimensions based on original orientation + resized_dimensions = ((new_short_edge, + new_long_edge) if width <= height else + (new_long_edge, new_short_edge)) + + # Ensure final dimensions are divisible by the specified value + resized_dimensions = tuple( + int(dim / divisible) * divisible for dim in resized_dimensions) + + return resized_dimensions + + # Audios def load_audios(paths): return [load_audio(path) for path in paths] diff --git a/docs/Operators.md b/docs/Operators.md index a4cbba709..2a25c4847 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -12,7 +12,7 @@ The operators in Data-Juicer are categorized into 5 types. |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | | [ Mapper ]( #mapper ) | 52 | Edits and transforms samples | -| [ Filter ]( #filter ) | 43 | Filters out low-quality samples | +| [ Filter ]( #filter ) | 44 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -149,6 +149,7 @@ All the specific operators are listed below, each featured with several capabili | video_duration_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Keep data samples whose videos' durations are within a specified range | [code](../data_juicer/ops/filter/video_duration_filter.py) | [tests](../tests/ops/filter/test_video_duration_filter.py) | | video_frames_text_similarity_filter | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Keep data samples whose similarities between sampled video frame images and text are within a specific range | [code](../data_juicer/ops/filter/video_frames_text_similarity_filter.py) | [tests](../tests/ops/filter/test_video_frames_text_similarity_filter.py) | | video_motion_score_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Keep samples with video motion scores within a specific range | [code](../data_juicer/ops/filter/video_motion_score_filter.py) | [tests](../tests/ops/filter/test_video_motion_score_filter.py) | +| video_motion_score_raft_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Keep samples with video motion scores (based on RAFT model) within a specific range | [code](../data_juicer/ops/filter/video_motion_score_raft_filter.py) | [tests](../tests/ops/filter/test_video_motion_score_raft_filter.py) | | video_nsfw_filter | 
![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Keeps samples containing videos with NSFW scores below the threshold | [code](../data_juicer/ops/filter/video_nsfw_filter.py) | [tests](../tests/ops/filter/test_video_nsfw_filter.py) |
 | video_ocr_area_ratio_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Keep data samples whose detected text area ratios for specified frames in the video are within a specified range | [code](../data_juicer/ops/filter/video_ocr_area_ratio_filter.py) | [tests](../tests/ops/filter/test_video_ocr_area_ratio_filter.py) |
 | video_resolution_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Keeps samples containing videos with horizontal and vertical resolutions within the specified range | [code](../data_juicer/ops/filter/video_resolution_filter.py) | [tests](../tests/ops/filter/test_video_resolution_filter.py) |
diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md
index 5dfe64141..88d739d66 100644
--- a/docs/Operators_ZH.md
+++ b/docs/Operators_ZH.md
@@ -12,7 +12,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
 |------------------------------------|:--:|---------------|
 | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 |
 | [ Mapper ]( #mapper ) | 52 | 对数据样本进行编辑和转换 |
-| [ Filter ]( #filter ) | 43 | 过滤低质量样本 |
+| [ Filter ]( #filter ) | 44 | 过滤低质量样本 |
 | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 |
 | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 |
 
@@ -148,6 +148,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
 | video_aesthetics_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 保留指定帧的美学分数在指定范围内的样本 | [code](../data_juicer/ops/filter/video_duration_filter.py) | [tests](../tests/ops/filter/test_video_duration_filter.py) |
 | video_frames_text_similarity_filter | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留视频中指定帧的图像-文本的特征余弦相似度(基于CLIP模型)在指定范围内的样本 | [code](../data_juicer/ops/filter/video_frames_text_similarity_filter.py) | [tests](../tests/ops/filter/test_video_frames_text_similarity_filter.py) |
 | video_motion_score_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 保留包含视频的运动分数(基于稠密光流)在指定范围内的样本 | [code](../data_juicer/ops/filter/video_motion_score_filter.py) | [tests](../tests/ops/filter/test_video_motion_score_filter.py) |
+| video_motion_score_raft_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 保留包含视频的运动分数(基于 RAFT 模型估计的稠密光流)在指定范围内的样本 | [code](../data_juicer/ops/filter/video_motion_score_raft_filter.py) | [tests](../tests/ops/filter/test_video_motion_score_raft_filter.py) |
 | video_nsfw_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留包含视频的NSFW分数在指定阈值之下的样本 | [code](../data_juicer/ops/filter/video_nsfw_filter.py) | [tests](../tests/ops/filter/test_video_nsfw_filter.py) |
 | video_ocr_area_ratio_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留包含视频的特定帧中检测出的文本的面积占比在指定范围内的样本 | [code](../data_juicer/ops/filter/video_ocr_area_ratio_filter.py) | [tests](../tests/ops/filter/test_video_ocr_area_ratio_filter.py) |
 | video_resolution_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 保留包含视频的分辨率(包括横向分辨率和纵向分辨率)在指定范围内的样本 | 
[code](../data_juicer/ops/filter/video_resolution_filter.py) | [tests](../tests/ops/filter/test_video_resolution_filter.py) | diff --git a/tests/ops/filter/test_video_motion_score_filter.py b/tests/ops/filter/test_video_motion_score_filter.py index c83ce2fd0..c7dea74a1 100644 --- a/tests/ops/filter/test_video_motion_score_filter.py +++ b/tests/ops/filter/test_video_motion_score_filter.py @@ -13,9 +13,9 @@ class VideoMotionScoreFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data') - vid1_path = os.path.join(data_path, 'video1.mp4') # 1.8210126 - vid2_path = os.path.join(data_path, 'video2.mp4') # 3.600746 - vid3_path = os.path.join(data_path, 'video3.mp4') # 1.1822891 + vid1_path = os.path.join(data_path, 'video1.mp4') # 1.869317 + vid2_path = os.path.join(data_path, 'video2.mp4') # 3.52111 + vid3_path = os.path.join(data_path, 'video3.mp4') # 1.1731424 def _run_helper(self, op, source_list, target_list, np=1): dataset = Dataset.from_list(source_list) diff --git a/tests/ops/filter/test_video_motion_score_raft_filter.py b/tests/ops/filter/test_video_motion_score_raft_filter.py new file mode 100644 index 000000000..abd8e7374 --- /dev/null +++ b/tests/ops/filter/test_video_motion_score_raft_filter.py @@ -0,0 +1,186 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.video_motion_score_raft_filter import \ + VideoMotionScoreRaftFilter +from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoMotionScoreRaftFilterTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') # 10.766147 + vid2_path = os.path.join(data_path, 'video2.mp4') # 10.098914 + vid3_path = os.path.join(data_path, 'video3.mp4') # 2.0731936 + + def _run_helper(self, op, source_list, target_list, np=1): + dataset = Dataset.from_list(source_list) + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=np) + dataset = dataset.filter(op.process, num_proc=np) + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_default(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + op = VideoMotionScoreRaftFilter() + self._run_helper(op, ds_list, tgt_list) + + def test_downscale(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }] + op = VideoMotionScoreRaftFilter(min_score=1.0, size=128) + self._run_helper(op, ds_list, tgt_list) + + def test_downscale_max(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }] + op = VideoMotionScoreRaftFilter(min_score=1.0, size=256, max_size=256) + self._run_helper(op, ds_list, tgt_list) + + def test_downscale_relative(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': 
[self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }] + op = VideoMotionScoreRaftFilter(min_score=0.005, size=(128, 160), relative=True) + self._run_helper(op, ds_list, tgt_list) + + def test_high(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }] + op = VideoMotionScoreRaftFilter(min_score=10) + self._run_helper(op, ds_list, tgt_list) + + def test_low(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid3_path]}] + op = VideoMotionScoreRaftFilter(min_score=0.0, max_score=3) + self._run_helper(op, ds_list, tgt_list) + + def test_middle(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid2_path]}] + op = VideoMotionScoreRaftFilter(min_score=3, max_score=10.5) + self._run_helper(op, ds_list, tgt_list) + + def test_any(self): + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }] + op = VideoMotionScoreRaftFilter(min_score=3, + max_score=10.5, + any_or_all='any') + self._run_helper(op, ds_list, tgt_list) + + def test_all(self): + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [] + op = VideoMotionScoreRaftFilter(min_score=3, + max_score=10.5, + any_or_all='all') + self._run_helper(op, ds_list, tgt_list) + + def test_parallel(self): + import multiprocess as mp + mp.set_start_method('spawn', force=True) + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid2_path]}] + op = VideoMotionScoreRaftFilter(min_score=3, max_score=10.5) + self._run_helper(op, ds_list, tgt_list, np=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/grouper/__init__.py b/tests/ops/grouper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ops/grouper/test_key_value_grouper.py b/tests/ops/grouper/test_key_value_grouper.py new file mode 100644 index 000000000..1ac186423 --- /dev/null +++ b/tests/ops/grouper/test_key_value_grouper.py @@ -0,0 +1,54 @@ +import unittest + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.grouper.key_value_grouper import KeyValueGrouper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class KeyValueGrouperTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples, target): + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for batched_sample in new_dataset: + lang = batched_sample['meta'][0]['language'] + self.assertEqual(batched_sample['text'], target[lang]) + + def test_key_value_grouper(self): + + source = [ + { + 'text': "Today is Sunday and it's a happy day!", + 'meta': { + 'language': 'en' + } + }, + { + 'text': "Welcome to Alibaba.", + 'meta': { + 'language': 'en' + } + }, + { + 'text': '欢迎来到阿里巴巴!', + 'meta': { + 'language': 
'zh' + } + }, + ] + target = { + 'en':[ + "Today is Sunday and it's a happy day!", + "Welcome to Alibaba." + ], + 'zh':[ + '欢迎来到阿里巴巴!' + ] + } + + op = KeyValueGrouper(['meta.language']) + self._run_helper(op, source, target) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/grouper/test_naive_grouper.py b/tests/ops/grouper/test_naive_grouper.py new file mode 100644 index 000000000..4e69a8ba2 --- /dev/null +++ b/tests/ops/grouper/test_naive_grouper.py @@ -0,0 +1,47 @@ +import unittest + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.grouper.naive_grouper import NaiveGrouper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class NaiveGrouperTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples, target): + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for d, t in zip(new_dataset, target): + self.assertEqual(d['text'], t['text']) + + def test_naive_group(self): + + source = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text':[ + "Today is Sunday and it's a happy day!", + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.', + '欢迎来到阿里巴巴!' + ] + } + ] + + op = NaiveGrouper() + self._run_helper(op, source, target) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 23174fd8b6ba618d0a32486e82597c6d37e4812c Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 12 Nov 2024 16:45:19 +0800 Subject: [PATCH 027/118] ValueError -> Exception --- data_juicer/utils/common_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py index 46a638987..959831c5d 100644 --- a/data_juicer/utils/common_utils.py +++ b/data_juicer/utils/common_utils.py @@ -25,5 +25,5 @@ def is_float(s): try: float(s) return True - except ValueError: + except Exception: return False From 20a8deec1889a4f329bf39353b0c6dc8d8ee1d77 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 12 Nov 2024 17:54:51 +0800 Subject: [PATCH 028/118] fix config_all error --- configs/config_all.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 738b16c6f..0cf35e02c 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -55,14 +55,14 @@ process: - audio_ffmpeg_wrapped_mapper: # simple wrapper for FFmpeg audio filters - calibrate_qa_mapper: # calibrate question-answer pairs based on reference text. api_model: 'gpt-4o' # API model name. - api_url: null # API URL. Defaults to DJ_API_URL environment variable. - api_key: null # API key. Defaults to DJ_API_KEY environment variable. + api_url: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. system_prompt: null # System prompt for the calibration task. input_template: null # Template for building the model input. reference_template: null # Template for formatting the reference text. qa_pair_template: null # Template for formatting question-answer pairs. output_pattern: null # Regular expression for parsing model output. 
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. model_params: {} # Parameters for initializing the model. sampling_params: {} # Extra parameters passed to the API call. - calibrate_query_mapper: # calibrate query in question-answer pairs based on reference text. From 38a95111b2aad577cd4c80d2f987ca2f8b7d9a8d Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 13 Nov 2024 17:26:35 +0800 Subject: [PATCH 029/118] fix prepare_api_model --- data_juicer/ops/mapper/calibrate_qa_mapper.py | 2 +- .../mapper/extract_entity_attribute_mapper.py | 2 +- .../mapper/extract_entity_relation_mapper.py | 9 +-------- .../ops/mapper/extract_event_mapper.py | 2 +- .../ops/mapper/extract_keyword_mapper.py | 2 +- .../ops/mapper/extract_nickname_mapper.py | 2 +- data_juicer/ops/mapper/text_chunk_mapper.py | 9 +++++---- data_juicer/utils/model_utils.py | 20 +++++++++++++------ 8 files changed, 25 insertions(+), 23 deletions(-) diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index d28ef6e4a..c0b2c3165 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -74,7 +74,7 @@ def __init__(self, self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', - model=api_model, + api_model=api_model, url=api_url, response_path=response_path, **model_params) diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py index 6fd3d2b42..6d22f5513 100644 --- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -111,7 +111,7 @@ def __init__(self, self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', - model=api_model, + api_model=api_model, url=api_url, response_path=response_path, **model_params) diff --git a/data_juicer/ops/mapper/extract_entity_relation_mapper.py b/data_juicer/ops/mapper/extract_entity_relation_mapper.py index f738aecae..85396dec3 100644 --- a/data_juicer/ops/mapper/extract_entity_relation_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_relation_mapper.py @@ -146,13 +146,6 @@ class ExtractEntityRelationMapper(Mapper): DEFAULT_ENTITY_PATTERN = r'\("entity"(.*?)\)' DEFAULT_RELATION_PATTERN = r'\("relationship"(.*?)\)' - # DEFAULT_OUTPUT_PATTERN = r""" - # \#\#\#\s*称呼方式(\d+)\s* - # -\s*\*\*说话人\*\*\s*:\s*(.*?)\s* - # -\s*\*\*被称呼人\*\*\s*:\s*(.*?)\s* - # -\s*\*\*(.*?)对(.*?)的昵称\*\*\s*:\s*(.*?)(?=\#\#\#|\Z) # for double check - # """ - def __init__(self, api_model: str = 'gpt-4o', entity_types: List[str] = None, @@ -228,7 +221,7 @@ def __init__(self, self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', - model=api_model, + api_model=api_model, url=api_url, response_path=response_path, **model_params) diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py index 1efc5915c..c7d04d696 100644 --- a/data_juicer/ops/mapper/extract_event_mapper.py +++ b/data_juicer/ops/mapper/extract_event_mapper.py @@ -98,7 +98,7 @@ def __init__(self, self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', - model=api_model, + api_model=api_model, url=api_url, response_path=response_path, **model_params) diff --git a/data_juicer/ops/mapper/extract_keyword_mapper.py b/data_juicer/ops/mapper/extract_keyword_mapper.py index 
c4f90fb21..be87c9bcb 100644 --- a/data_juicer/ops/mapper/extract_keyword_mapper.py +++ b/data_juicer/ops/mapper/extract_keyword_mapper.py @@ -144,7 +144,7 @@ def __init__(self, self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', - model=api_model, + api_model=api_model, url=api_url, response_path=response_path, **model_params) diff --git a/data_juicer/ops/mapper/extract_nickname_mapper.py b/data_juicer/ops/mapper/extract_nickname_mapper.py index 265a02441..d86dda6b9 100644 --- a/data_juicer/ops/mapper/extract_nickname_mapper.py +++ b/data_juicer/ops/mapper/extract_nickname_mapper.py @@ -91,7 +91,7 @@ def __init__(self, self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', - model=api_model, + api_model=api_model, url=api_url, response_path=response_path, **model_params) diff --git a/data_juicer/ops/mapper/text_chunk_mapper.py b/data_juicer/ops/mapper/text_chunk_mapper.py index 2dd8a9441..80aee40c7 100644 --- a/data_juicer/ops/mapper/text_chunk_mapper.py +++ b/data_juicer/ops/mapper/text_chunk_mapper.py @@ -56,10 +56,11 @@ def __init__(self, self.split_pattern = split_pattern self.tokenizer_name = tokenizer if tokenizer is not None: - self.model_key = prepare_model(model_type='api', - api_model=tokenizer, - return_processor=True, - trust_remote_code=trust_remote_code) + self.model_key = prepare_model( + model_type='api', + api_model=tokenizer, + return_processor=True, + processor_config={'trust_remote_code': trust_remote_code}) def recursively_chunk(self, text): if self.tokenizer_name is not None: diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 7f9687079..bb6e06788 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -203,7 +203,7 @@ def _filter_arguments(func, args_dict): return filtered_args -def prepare_api_model(model, +def prepare_api_model(api_model, *, url=None, response_path=None, @@ -214,7 +214,7 @@ def prepare_api_model(model, The callable supports custom response parsing and works with proxy servers that may be incompatible. - :param model: The name of the model to interact with. + :param api_model: The name of the model to interact with. :param url: URL endpoint for the API. :param response_path: The dot-separated path to extract desired content from the API response. Defaults to 'choices.0.message.content'. @@ -229,7 +229,7 @@ def prepare_api_model(model, :return: A tuple containing the callable API model object and optionally a processor if `return_processor` is True. """ - model = APIModel(model=model, + model = APIModel(model=api_model, url=url, response_path=response_path, **model_params) @@ -240,13 +240,20 @@ def prepare_api_model(model, def get_processor(): try: import tiktoken - return tiktoken.encoding_for_model(model) + return tiktoken.encoding_for_model(api_model) except Exception: pass try: import dashscope - return dashscope.get_tokenizer(model) + return dashscope.get_tokenizer(api_model) + except Exception: + pass + + try: + processor = transformers.AutoProcessor.from_pretrained( + pretrained_model_name_or_path=api_model, **processor_config) + return processor except Exception: pass @@ -257,7 +264,8 @@ def get_processor(): "- For custom models: Use the 'processor_config' parameter to configure a Hugging Face processor." 
# noqa: E501 ) - if processor_config is not None: + if processor_config is not None \ + and 'pretrained_model_name_or_path' in processor_config: processor = transformers.AutoProcessor.from_pretrained( **processor_config) else: From 35f0eb36ae3cc5d0d6723b3ed1870b2152397044 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 13 Nov 2024 19:51:15 +0800 Subject: [PATCH 030/118] fix rank sample None --- data_juicer/ops/mapper/calibrate_qa_mapper.py | 2 +- .../ops/mapper/extract_entity_attribute_mapper.py | 4 ++-- data_juicer/ops/mapper/extract_entity_relation_mapper.py | 2 +- data_juicer/ops/mapper/extract_event_mapper.py | 5 +++-- data_juicer/ops/mapper/extract_keyword_mapper.py | 2 +- data_juicer/ops/mapper/extract_nickname_mapper.py | 2 +- .../ops/mapper/generate_qa_from_examples_mapper.py | 2 +- data_juicer/ops/mapper/optimize_qa_mapper.py | 2 +- data_juicer/ops/mapper/text_chunk_mapper.py | 9 +++++---- 9 files changed, 16 insertions(+), 14 deletions(-) diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index c0b2c3165..e21c18ff1 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -96,7 +96,7 @@ def parse_output(self, raw_output): else: return None, None - def process_single(self, sample=None, rank=None): + def process_single(self, sample, rank=None): client = get_model(self.model_key, rank=rank) messages = [{ diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py index 6d22f5513..ebff93485 100644 --- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -169,13 +169,13 @@ def _process_single_sample(self, text='', rank=None): return entities, attributes, descs, demo_lists - def process_batched(self, samples): + def process_batched(self, samples, rank=None): sample_num = len(samples[self.text_key]) entities, attributes, descs, demo_lists = [], [], [], [] for text in samples[self.text_key]: - res = self._process_single_sample(text) + res = self._process_single_sample(text, rank=rank) cur_ents, cur_attrs, cur_descs, cur_demos = res entities.append(cur_ents) attributes.append(cur_attrs) diff --git a/data_juicer/ops/mapper/extract_entity_relation_mapper.py b/data_juicer/ops/mapper/extract_entity_relation_mapper.py index 85396dec3..efb110188 100644 --- a/data_juicer/ops/mapper/extract_entity_relation_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_relation_mapper.py @@ -308,7 +308,7 @@ def light_rag_extraction(self, messages, rank=None): return final_result - def process_single(self, sample=None, rank=None): + def process_single(self, sample, rank=None): input_prompt = self.prompt_template.format( tuple_delimiter=self.tuple_delimiter, diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py index c7d04d696..50e8a14e8 100644 --- a/data_juicer/ops/mapper/extract_event_mapper.py +++ b/data_juicer/ops/mapper/extract_event_mapper.py @@ -145,13 +145,14 @@ def _process_single_sample(self, text='', rank=None): return event_list, character_list - def process_batched(self, samples): + def process_batched(self, samples, rank=None): sample_num = len(samples[self.text_key]) events, characters = [], [] for text in samples[self.text_key]: - cur_events, cur_characters = self._process_single_sample(text) + cur_events, cur_characters = self._process_single_sample(text, + 
rank=rank) events.append(cur_events) characters.append(cur_characters) diff --git a/data_juicer/ops/mapper/extract_keyword_mapper.py b/data_juicer/ops/mapper/extract_keyword_mapper.py index be87c9bcb..91e712822 100644 --- a/data_juicer/ops/mapper/extract_keyword_mapper.py +++ b/data_juicer/ops/mapper/extract_keyword_mapper.py @@ -164,7 +164,7 @@ def parse_output(self, raw_output): return keywords - def process_single(self, sample=None, rank=None): + def process_single(self, sample, rank=None): client = get_model(self.model_key, rank=rank) input_prompt = self.prompt_template.format( diff --git a/data_juicer/ops/mapper/extract_nickname_mapper.py b/data_juicer/ops/mapper/extract_nickname_mapper.py index d86dda6b9..afeb3332e 100644 --- a/data_juicer/ops/mapper/extract_nickname_mapper.py +++ b/data_juicer/ops/mapper/extract_nickname_mapper.py @@ -131,7 +131,7 @@ def parse_output(self, raw_output): return nickname_relations - def process_single(self, sample=None, rank=None): + def process_single(self, sample, rank=None): client = get_model(self.model_key, rank=rank) input_prompt = self.input_template.format(text=sample[self.text_key]) diff --git a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py index 4d7ff01bd..6f5ad7dab 100644 --- a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py +++ b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py @@ -210,7 +210,7 @@ def parse_output(self, raw_output): output_qa_pairs.append((question.strip(), answer.strip())) return output_qa_pairs - def process_single(self, sample=None, rank=None): + def process_single(self, sample, rank=None): model, _ = get_model(self.model_key, rank, self.use_cuda()) random_qa_samples = random.sample(self.seed_qa_samples, diff --git a/data_juicer/ops/mapper/optimize_qa_mapper.py b/data_juicer/ops/mapper/optimize_qa_mapper.py index cd5a0aba7..3563a112b 100644 --- a/data_juicer/ops/mapper/optimize_qa_mapper.py +++ b/data_juicer/ops/mapper/optimize_qa_mapper.py @@ -113,7 +113,7 @@ def parse_output(self, raw_output): else: return None, None - def process_single(self, sample=None, rank=None): + def process_single(self, sample, rank=None): model, _ = get_model(self.model_key, rank, self.use_cuda()) input_prompt = self.build_input(sample) diff --git a/data_juicer/ops/mapper/text_chunk_mapper.py b/data_juicer/ops/mapper/text_chunk_mapper.py index 80aee40c7..4f7e03f1c 100644 --- a/data_juicer/ops/mapper/text_chunk_mapper.py +++ b/data_juicer/ops/mapper/text_chunk_mapper.py @@ -90,7 +90,7 @@ def recursively_chunk(self, text): return [cur_text] + self.recursively_chunk(left_text) - def get_text_chunks(self, text): + def get_text_chunks(self, text, rank=None): if self.split_pattern is not None and self.max_len is None: chunks = re.split(f'({self.split_pattern})', text) @@ -99,7 +99,7 @@ def get_text_chunks(self, text): tokens = text total_len = len(text) if self.tokenizer_name is not None: - _, tokenizer = get_model(self.model_key) + _, tokenizer = get_model(self.model_key, rank=rank) tokens = tokenizer.encode(text) total_len = len(tokens) if total_len <= self.max_len: @@ -115,12 +115,13 @@ def get_text_chunks(self, text): return chunks - def process_batched(self, samples): + def process_batched(self, samples, rank=None): sample_num = len(samples[self.text_key]) samples[self.text_key] = [ - self.get_text_chunks(text) for text in samples[self.text_key] + self.get_text_chunks(text, rank=rank) + for text in samples[self.text_key] ] for key in samples: From 
155d3dda957dc09e881e88a66a5222df15470d80 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 13 Nov 2024 20:15:12 +0800 Subject: [PATCH 031/118] constant fix key --- .../ops/mapper/extract_entity_relation_mapper.py | 16 ++++++++-------- .../ops/mapper/extract_nickname_mapper.py | 10 +++++----- data_juicer/utils/constant.py | 16 ++++++++++++++++ tests/ops/mapper/test_extract_nickname_mapper.py | 6 +++++- 4 files changed, 34 insertions(+), 14 deletions(-) diff --git a/data_juicer/ops/mapper/extract_entity_relation_mapper.py b/data_juicer/ops/mapper/extract_entity_relation_mapper.py index efb110188..c548a44e4 100644 --- a/data_juicer/ops/mapper/extract_entity_relation_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_relation_mapper.py @@ -257,9 +257,9 @@ def split_by_tuple_delimiter(record): entities.append(items) entities = list(set(entities)) entities = [{ - 'entity': e[0], - 'type': e[1], - 'description': e[2] + Fields.entity_name: e[0], + Fields.entity_type: e[1], + Fields.entity_description: e[2] } for e in entities] relation_pattern = re.compile(self.relation_pattern, @@ -272,11 +272,11 @@ def split_by_tuple_delimiter(record): relations.append(items) relations = list(set(relations)) relations = [{ - 'source_entity': r[0], - 'target_entity': r[1], - 'description': r[2], - 'keywords': split_text_by_punctuation(r[3]), - 'strength': float(r[4]) + Fields.source_entity: r[0], + Fields.target_entity: r[1], + Fields.relation_description: r[2], + Fields.relation_keywords: split_text_by_punctuation(r[3]), + Fields.relation_strength: float(r[4]) } for r in relations] return entities, relations diff --git a/data_juicer/ops/mapper/extract_nickname_mapper.py b/data_juicer/ops/mapper/extract_nickname_mapper.py index afeb3332e..bf5382af9 100644 --- a/data_juicer/ops/mapper/extract_nickname_mapper.py +++ b/data_juicer/ops/mapper/extract_nickname_mapper.py @@ -122,11 +122,11 @@ def parse_output(self, raw_output): nickname_relations = list(set(nickname_relations)) nickname_relations = [{ - 'source_entity': nr[0], - 'target_entity': nr[1], - 'description': nr[2], - 'keywords': ['nickname'], - 'strength': None + Fields.source_entity: nr[0], + Fields.target_entity: nr[1], + Fields.relation_description: nr[2], + Fields.relation_keywords: ['nickname'], + Fields.relation_strength: None } for nr in nickname_relations] return nickname_relations diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 78c373d45..ab88035b9 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -44,8 +44,24 @@ class Fields(object): nickname = DEFAULT_PREFIX + 'nickname__' # # the entity for knowledge graph entity = DEFAULT_PREFIX + 'entity__' + # # # the name of entity + entity_name = DEFAULT_PREFIX + 'entity_name__' + # # # the type of entity + entity_type = DEFAULT_PREFIX + 'entity_type__' + # # # the description of entity + entity_description = DEFAULT_PREFIX + 'entity_entity_description__' # # the relationship for knowledge graph relation = DEFAULT_PREFIX + 'relation__' + # # # the source entity of the relation + source_entity = DEFAULT_PREFIX + 'relation_source_entity__' + # # # the target entity of the relation + target_entity = DEFAULT_PREFIX + 'relation_target_entity__' + # # # the description of the relation + relation_description = DEFAULT_PREFIX + 'relation_description__' + # # # the keywords of the relation + relation_keywords = DEFAULT_PREFIX + 'relation_keywords__' + # # # the strength of the relation + relation_strength = DEFAULT_PREFIX + 
'relation_strength__' # # the keyword in a text keyword = DEFAULT_PREFIX + 'keyword__' diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py index 780c7e4b2..635801155 100644 --- a/tests/ops/mapper/test_extract_nickname_mapper.py +++ b/tests/ops/mapper/test_extract_nickname_mapper.py @@ -39,7 +39,11 @@ def _run_op(self, api_model, response_path=None): dataset = Dataset.from_list(samples) dataset = dataset.map(op.process, batch_size=2) result = dataset[0][Fields.nickname] - result = [(d['source_entity'], d['target_entity'], d['description']) for d in result] + result = [( + d[Fields.source_entity], + d[Fields.target_entity], + d[Fields.relation_description]) + for d in result] logger.info(f'result: {result}') self.assertIn(("李莲花","方多病","方小宝"), result) From f862897fb6824a94571b842fa1951358cbd52b83 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 14 Nov 2024 19:22:31 +0800 Subject: [PATCH 032/118] aggregator op --- data_juicer/ops/__init__.py | 5 +- data_juicer/ops/aggregator/__init__.py | 4 + .../aggregator/entity_attribute_aggregator.py | 233 ++++++++++++++++++ .../ops/aggregator/nested_aggregator.py | 211 ++++++++++++++++ data_juicer/ops/base_op.py | 50 +++- data_juicer/ops/grouper/key_value_grouper.py | 4 +- data_juicer/ops/grouper/naive_grouper.py | 8 +- data_juicer/utils/common_utils.py | 57 +++++ data_juicer/utils/model_utils.py | 42 +++- tests/ops/Aggregator/__init__.py | 0 .../test_entity_attribute_aggregator.py | 161 ++++++++++++ .../ops/Aggregator/test_nested_aggregator.py | 139 +++++++++++ 12 files changed, 897 insertions(+), 17 deletions(-) create mode 100644 data_juicer/ops/aggregator/__init__.py create mode 100644 data_juicer/ops/aggregator/entity_attribute_aggregator.py create mode 100644 data_juicer/ops/aggregator/nested_aggregator.py create mode 100644 tests/ops/Aggregator/__init__.py create mode 100644 tests/ops/Aggregator/test_entity_attribute_aggregator.py create mode 100644 tests/ops/Aggregator/test_nested_aggregator.py diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py index 38116efd5..4f9e4f4bc 100644 --- a/data_juicer/ops/__init__.py +++ b/data_juicer/ops/__init__.py @@ -1,6 +1,6 @@ from . 
import deduplicator, filter, mapper, selector -from .base_op import (OPERATORS, UNFORKABLE, Deduplicator, Filter, Grouper, - Mapper, Selector) +from .base_op import (OPERATORS, UNFORKABLE, Aggregator, Deduplicator, Filter, + Grouper, Mapper, Selector) from .load import load_ops __all__ = [ @@ -10,4 +10,5 @@ 'Deduplicator', 'Selector', 'Grouper', + 'Aggregator', ] diff --git a/data_juicer/ops/aggregator/__init__.py b/data_juicer/ops/aggregator/__init__.py new file mode 100644 index 000000000..cdc67def9 --- /dev/null +++ b/data_juicer/ops/aggregator/__init__.py @@ -0,0 +1,4 @@ +from .entity_attribute_aggregator import EntityAttributeAggregator +from .nested_aggregator import NestedAggregator + +__all__ = ['NestedAggregator', 'EntityAttributeAggregator'] diff --git a/data_juicer/ops/aggregator/entity_attribute_aggregator.py b/data_juicer/ops/aggregator/entity_attribute_aggregator.py new file mode 100644 index 000000000..23986a646 --- /dev/null +++ b/data_juicer/ops/aggregator/entity_attribute_aggregator.py @@ -0,0 +1,233 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Aggregator +from data_juicer.utils.common_utils import (avg_split_string_list_under_limit, + get_val_by_nested_key, + is_string_list) +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import (get_model, parse_model_response, + prepare_model) + +from .nested_aggregator import NestedAggregator + +torch = LazyLoader('torch', 'torch') +vllm = LazyLoader('vllm', 'vllm') + +OP_NAME = 'entity_attribute_aggregator' + + +# TODO: LLM-based inference. +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class EntityAttributeAggregator(Aggregator): + """ + Return conclusion of the given entity's attribute from some docs. + """ + + DEFAULT_SYSTEM_TEMPLATE = ( + '给定与`{entity}`相关的一些文档,总结`{entity}`的`{attribute}`。\n' + '要求:\n' + '- 尽量使用原文专有名词\n' + '- 联系上下文,自动忽略上下文不一致的细节错误\n' + '- 只对文档中与`{entity}`的`{attribute}`有关的内容进行总结\n' + '- 字数限制在**{word_limit}字以内**\n' + '- 要求输出格式如下:\n' + '# {entity}\n' + '## {attribute}\n' + '...\n' + '{example}') + + DEFAULT_EXAMPLE_PROMPT = ('- 例如,根据相关文档总结`孙悟空`的`出身背景`,**100字**以内的样例如下:\n' + '`孙悟空`的`出身背景`总结:\n' + '# 孙悟空\n' + '## 出身背景\n' + '号称齐天大圣,花果山水帘洞的美猴王、西行取经队伍中的大师兄。' + '师父是唐僧玄奘,曾拜菩提祖师学艺。' + '亲生父母未知,自石头中孕育而生。自认斗战胜佛,最怕观世音菩萨和紧箍咒。\n') + + DEFAULT_INPUT_TEMPLATE = ('`{entity}`的相关文档:\n' + '{sub_docs}\n\n' + '`{entity}`的`{attribute}`总结:\n') + + DEFAULT_OUTPUT_PATTERN_TEMPLATE = r'\#\s*{entity}\s*\#\#\s*{attribute}\s*(.*?)\Z' # noqa: E501 + + def __init__(self, + hf_or_api_model: str = 'gpt-4o', + entity: str = None, + attribute: str = None, + input_key: str = None, + output_key: str = None, + word_limit: PositiveInt = 100, + max_token_num: Optional[PositiveInt] = None, + *, + api_url: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt_template: Optional[str] = None, + example_prompt: Optional[str] = None, + input_template: Optional[str] = None, + output_pattern_template: Optional[str] = None, + try_num: PositiveInt = 3, + is_hf_model: bool = False, + enable_vllm: bool = False, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param hf_or_api_model: Huggingface model or API model name. + :param entity: The given entity. + :param attribute: The given attribute. + :param input_key: The input field key in the samples. 
Support for
+            nested keys such as "__dj__stats__.text_len". It is text_key
+            by default.
+        :param output_key: The output field key in the samples. Support for
+            nested keys such as "__dj__stats__.text_len". It is the same as
+            the input_key by default.
+        :param word_limit: The word limit for the output, stated in the
+            prompt.
+        :param max_token_num: The max token num of the total tokens of the
+            sub documents. No limitation if it is None.
+        :param api_url: URL endpoint for the API.
+        :param response_path: Path to extract content from the API response.
+            Defaults to 'choices.0.message.content'.
+        :param system_prompt_template: The system prompt template.
+        :param example_prompt: The example part in the system prompt.
+        :param input_template: The input template.
+        :param output_pattern_template: The output pattern template (a
+            regular expression).
+        :param try_num: The number of retry attempts when there is an API
+            call error or output parsing error.
+        :param is_hf_model: Whether hf_or_api_model is a Hugging Face model.
+        :param enable_vllm: Whether to use VLLM for inference acceleration.
+        :param model_params: Parameters for initializing the API model.
+        :param sampling_params: Extra parameters passed to the API call.
+            e.g. {'temperature': 0.9, 'top_p': 0.95}
+        :param kwargs: Extra keyword arguments.
+        """
+        super().__init__(**kwargs)
+
+        if entity is None or attribute is None:
+            raise ValueError('The entity and attribute cannot be None!')
+
+        self.entity = entity
+        self.attribute = attribute
+        self.input_key = input_key or self.text_key
+        self.output_key = output_key or self.input_key
+        self.word_limit = word_limit
+        self.max_token_num = max_token_num
+
+        self.system_prompt_template = system_prompt_template or \
+            self.DEFAULT_SYSTEM_TEMPLATE
+        self.example_prompt = example_prompt or self.DEFAULT_EXAMPLE_PROMPT
+        self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+        output_pattern_template = output_pattern_template or \
+            self.DEFAULT_OUTPUT_PATTERN_TEMPLATE
+        self.output_pattern = output_pattern_template.format(
+            entity=entity, attribute=attribute)
+
+        self.sampling_params = sampling_params
+        self.is_hf_model = is_hf_model
+        self.enable_vllm = enable_vllm
+        if is_hf_model and enable_vllm:
+            assert torch.cuda.device_count() >= 1, 'must be executed in CUDA'
+            # cannot initialize vllm replicas on different GPUs
+            self.num_proc = 1
+            if model_params.get('tensor_parallel_size') is None:
+                tensor_parallel_size = torch.cuda.device_count()
+                logger.info(f'Set tensor_parallel_size to \
+                    {tensor_parallel_size} for vllm.')
+                model_params['tensor_parallel_size'] = tensor_parallel_size
+            self.model_key = prepare_model(
+                model_type='vllm',
+                pretrained_model_name_or_path=hf_or_api_model,
+                **model_params)
+            self.sampling_params = vllm.SamplingParams(**sampling_params)
+        elif is_hf_model:
+            self.model_key = prepare_model(
+                model_type='huggingface',
+                pretrained_model_name_or_path=hf_or_api_model,
+                return_pipe=True,
+                **model_params)
+            self.sampling_params = sampling_params
+        else:
+            self.model_key = prepare_model(model_type='api',
+                                           api_model=hf_or_api_model,
+                                           url=api_url,
+                                           response_path=response_path,
+                                           return_processor=True,
+                                           **model_params)
+
+        self.try_num = try_num
+        self.nested_sum = NestedAggregator(hf_or_api_model=hf_or_api_model,
+                                           max_token_num=max_token_num,
+                                           api_url=api_url,
+                                           response_path=response_path,
+                                           try_num=try_num,
+                                           is_hf_model=is_hf_model,
+                                           enable_vllm=enable_vllm,
+                                           model_params=model_params,
+                                           sampling_params=sampling_params)
+
+    def parse_output(self, response):
+        response = parse_model_response(response)
+
+        pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
+        matches = pattern.findall(response)
+        if matches:
+            result = matches[0].strip()
+        else:
+            result = ''
+
+        return result
+
+    def attribute_summary(self, sub_docs, rank=None):
+        if not sub_docs:
+            return ''
+
+        model, tokenizer = get_model(self.model_key, rank, self.use_cuda())
+        token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs]
+        group_docs = avg_split_string_list_under_limit(sub_docs, token_nums,
+                                                       self.max_token_num)
+        results = []
+        for docs in group_docs:
+            system_prompt = self.system_prompt_template.format(
+                entity=self.entity,
+                attribute=self.attribute,
+                word_limit=self.word_limit,
+                example=self.example_prompt)
+            doc_str = '\n\n'.join(docs)
+            input_prompt = self.input_template.format(entity=self.entity,
+                                                      attribute=self.attribute,
+                                                      sub_docs=doc_str)
+            messages = [{
+                'role': 'system',
+                'content': system_prompt
+            }, {
+                'role': 'user',
+                'content': input_prompt
+            }]
+            result = ''
+            for i in range(self.try_num):
+                try:
+                    response = model(messages, **self.sampling_params)
+                    result = self.parse_output(response)
+                    if len(result) > 0:
+                        break
+                except Exception as e:
+                    logger.warning(f'Exception: {e}')
+            results.append(result)
+
+        return self.nested_sum.recursive_summary(results)
+
+    def process_single(self, sample=None, rank=None):
+
+        # skip the sample if it is not a batched (grouped) sample
+        sub_docs = get_val_by_nested_key(sample, self.input_key)
+        if not is_string_list(sub_docs):
+            return sample
+
+        sample[self.output_key] = self.attribute_summary(sub_docs, rank=rank)
+
+        return sample
diff --git a/data_juicer/ops/aggregator/nested_aggregator.py b/data_juicer/ops/aggregator/nested_aggregator.py
new file mode 100644
index 000000000..d843358b0
--- /dev/null
+++ b/data_juicer/ops/aggregator/nested_aggregator.py
@@ -0,0 +1,211 @@
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Aggregator
+from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
+                                            get_val_by_nested_key,
+                                            is_string_list)
+from data_juicer.utils.lazy_loader import LazyLoader
+from data_juicer.utils.model_utils import (get_model, parse_model_response,
+                                           prepare_model)
+
+torch = LazyLoader('torch', 'torch')
+vllm = LazyLoader('vllm', 'vllm')
+
+OP_NAME = 'nested_aggregator'
+
+
+# TODO: LLM-based inference.
+@UNFORKABLE.register_module(OP_NAME)
+@OPERATORS.register_module(OP_NAME)
+class NestedAggregator(Aggregator):
+    """
+    Considering the limitation of input length, aggregate the contents
+    of the given samples in a nested (recursive) manner.
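+
+    Sub documents are first summarized in groups that fit within the
+    token limit; the partial summaries are then aggregated again,
+    recursively, until a single summary remains.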
+ """ + + DEFAULT_SYSTEM_PROMPT = ('给定一些文档碎片,将这些文档整合成一个文档总结。\n' + '要求:\n' + '- 总结的长度与文档碎片的平均长度基本一致\n' + '- 不要包含主观看法\n' + '- 注意要尽可能保留文本的专有名词\n' + '- 只输出文档总结不要输出其他内容\n' + '- 参考如下样例:\n' + '文档碎片:\n' + '唐僧师徒四人行至白虎岭,遇上了变化多端的白骨精。\n\n' + '文档碎片:\n' + '白骨精首次变身少女送斋,被孙悟空识破打死,唐僧责怪悟空。\n\n' + '文档碎片:\n' + '妖怪再变老妇寻女,又被悟空击毙,师傅更加不满,念紧箍咒惩罚。\n\n' + '文档碎片:\n' + '不甘心的白骨精第三次化作老公公来诱骗,依旧逃不过金睛火眼。\n\n' + '文档碎片:\n' + '最终,在观音菩萨的帮助下,真相大白,唐僧明白了自己的误解。\n\n' + '\n' + '文档总结:\n' + '唐僧师徒在白虎岭三遇白骨精变化诱惑,悟空屡次识破击毙妖怪却遭误解,最终观音相助真相大白。') + + DEFAULT_INPUT_TEMPLATE = ('{sub_docs}\n\n' + '文档总结:\n') + + DEFAULT_SUB_DOC_TEMPLATE = '文档碎片:\n{text}\n' + + def __init__(self, + hf_or_api_model: str = 'gpt-4o', + input_key: str = None, + output_key: str = None, + max_token_num: Optional[PositiveInt] = None, + *, + api_url: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + sub_doc_template: Optional[str] = None, + input_template: Optional[str] = None, + try_num: PositiveInt = 3, + is_hf_model: bool = False, + enable_vllm: bool = False, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param hf_or_api_model: Huggingface model or API model name. + :param input_key: The input field key in the samples. Support for + nested keys such as "__dj__stats__.text_len". It is text_key + in default. + :param output_key: The output field key in the samples. Support for + nested keys such as "__dj__stats__.text_len". It is same as the + input_key in default. + :param max_token_num: The max token num of the total tokens of the + sub documents. Without limitation if it is None. + :param api_url: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: The system prompt. + :param sub_doc_template: The template for input text in each sample. + :param input_template: The input template. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param is_hf_model: If the hf_or_api_model is huggingface model. + :param enable_vllm: Whether to use VLLM for inference acceleration. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. 
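+
+        A minimal usage sketch (the model name and the grouped dataset
+        below are illustrative assumptions, not fixed defaults)::
+
+            op = NestedAggregator(hf_or_api_model='qwen2.5-72b-instruct')
+            # `grouped_dataset` is the output of a Grouper OP
+            aggregated = op.run(grouped_dataset)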
+ """ + super().__init__(**kwargs) + + self.input_key = input_key or self.text_key + self.output_key = output_key or self.input_key + self.max_token_num = max_token_num + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.sub_doc_template = sub_doc_template or \ + self.DEFAULT_SUB_DOC_TEMPLATE + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + + self.sampling_params = sampling_params + self.is_hf_model = is_hf_model + self.enable_vllm = enable_vllm + if is_hf_model and enable_vllm: + assert torch.cuda.device_count() >= 1, 'must be executed in CUDA' + # cannot initialize vllm replicas on different GPUs + self.num_proc = 1 + if model_params.get('tensor_parallel_size') is None: + tensor_parallel_size = torch.cuda.device_count() + logger.info(f'Set tensor_parallel_size to \ + {tensor_parallel_size} for vllm.') + model_params['tensor_parallel_size'] = tensor_parallel_size + self.model_key = prepare_model( + model_type='vllm', + pretrained_model_name_or_path=hf_or_api_model, + **model_params) + self.sampling_params = vllm.SamplingParams(**sampling_params) + elif is_hf_model: + self.model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=hf_or_api_model, + return_pipe=True, + **model_params) + self.sampling_params = sampling_params + else: + self.model_key = prepare_model(model_type='api', + api_model=hf_or_api_model, + url=api_url, + response_path=response_path, + return_processor=True, + **model_params) + + self.try_num = try_num + + def parse_output(self, response): + response = parse_model_response(response) + + def if_match(text): + quotes = [("'", "'"), ('"', '"'), ('“', '”'), ('‘', '’'), + ('`', '`')] + if len(text) < 2: + return False + if (text[0], text[-1]) in quotes: + return True + else: + return False + + text = response.strip() + while if_match(text): + text = text[1:-1].strip() + return text + + def recursive_summary(self, sub_docs, rank=None): + if not sub_docs: + return '' + if len(sub_docs) == 1: + return sub_docs[0] + model, tokenizer = get_model(self.model_key, rank, self.use_cuda()) + token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs] + group_docs = avg_split_string_list_under_limit(sub_docs, token_nums, + self.max_token_num) + # merge every two if every single sub doc is a group + group_num = len(group_docs) + if group_num == len(sub_docs): + group_docs = [ + group_docs[i] + + group_docs[i + 1] if i + 1 < group_num else group_docs[i] + for i in range(0, group_num, 2) + ] + results = [] + for docs in group_docs: + doc_strs = [self.sub_doc_template.format(text=d) for d in docs] + input_prompt = self.input_template.format( + sub_docs='\n'.join(doc_strs)) + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': input_prompt + }] + result = '' + for i in range(self.try_num): + try: + response = model(messages, **self.sampling_params) + result = self.parse_output(response) + if len(result) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + results.append(result) + return self.recursive_summary(results) + + def process_single(self, sample=None, rank=None): + + # if not batched sample + sub_docs = get_val_by_nested_key(sample, self.input_key) + if not is_string_list(sub_docs): + return sample + + sample[self.output_key] = self.recursive_summary(sub_docs, rank=rank) + + return sample diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 48ecf4947..361bf7671 100644 --- a/data_juicer/ops/base_op.py 
+++ b/data_juicer/ops/base_op.py
@@ -7,7 +7,6 @@
 from loguru import logger
 
 from data_juicer import is_cuda_available
-from data_juicer.core.data import NestedDataset
 from data_juicer.utils.constant import Fields
 from data_juicer.utils.mm_utils import size_to_bytes
 from data_juicer.utils.process_utils import calculate_np
@@ -216,6 +215,7 @@ def add_parameters(self, init_parameter_dict, **extra_param_dict):
         return related_parameters
 
     def run(self, dataset):
+        from data_juicer.core.data import NestedDataset
        if not isinstance(dataset, NestedDataset):
             dataset = NestedDataset(dataset)
         return dataset
@@ -542,7 +542,55 @@ def process(self, dataset):
 
     def run(self, dataset, *, exporter=None, tracer=None):
         dataset = super(Grouper, self).run(dataset)
         batched_samples = self.process(dataset)
+        from data_juicer.core.data import NestedDataset
         new_dataset = NestedDataset.from_list(batched_samples)
         if tracer:
             tracer.trace_filter(self._name, dataset, new_dataset)
         return new_dataset
+
+
+class Aggregator(OP):
+
+    def __init__(self, *args, **kwargs):
+        """
+        Base class that aggregates batched samples into individual samples.
+
+        :param text_key: the key name of field that stores sample texts
+            to be processed
+        :param image_key: the key name of field that stores sample image list
+            to be processed
+        :param audio_key: the key name of field that stores sample audio list
+            to be processed
+        :param video_key: the key name of field that stores sample video list
+            to be processed
+        :param query_key: the key name of field that stores sample queries
+        :param response_key: the key name of field that stores responses
+        :param history_key: the key name of field that stores history of
+            queries and responses
+        """
+        super(Aggregator, self).__init__(*args, **kwargs)
+        self.process = catch_map_single_exception(self.process_single)
+
+    def process_single(self, sample):
+        """
+        For sample level, batched sample --> sample. The input must be
+        the output of some Grouper OP.
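+
+        For illustration (assuming the default ``text_key``), a batched
+        sample from ``NaiveGrouper`` looks like
+        ``{'text': ['doc 1', 'doc 2', ...]}``; an aggregator reduces such
+        a list back to a single value on the sample.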
+
+        :param sample: batched sample to aggregate
+        :return: aggregated sample
+        """
+        raise NotImplementedError
+
+    def run(self, dataset, *, exporter=None, tracer=None):
+        dataset = super(Aggregator, self).run(dataset)
+        new_dataset = dataset.map(
+            self.process,
+            num_proc=self.runtime_np(),
+            with_rank=self.use_cuda(),
+            batch_size=self.batch_size,
+            desc=self._name + '_process',
+        )
+        if tracer:
+            tracer.trace_mapper(self._name, dataset, new_dataset,
+                                self.text_key)
+        return new_dataset
diff --git a/data_juicer/ops/grouper/key_value_grouper.py b/data_juicer/ops/grouper/key_value_grouper.py
index 15b8d0328..4b6e94bf9 100644
--- a/data_juicer/ops/grouper/key_value_grouper.py
+++ b/data_juicer/ops/grouper/key_value_grouper.py
@@ -2,7 +2,7 @@
 
 from data_juicer.utils.common_utils import dict_to_hash, get_val_by_nested_key
 
-from ..base_op import OPERATORS, Grouper
+from ..base_op import OPERATORS, Grouper, convert_list_dict_to_dict_list
 from .naive_grouper import NaiveGrouper
 
 
@@ -45,7 +45,7 @@ def process(self, dataset):
             sample_map[sample_key] = [sample]
 
         batched_samples = [
-            self.naive_grouper.process(sample_map[k])[0] for k in sample_map
+            convert_list_dict_to_dict_list(sample_map[k]) for k in sample_map
         ]
 
         return batched_samples
diff --git a/data_juicer/ops/grouper/naive_grouper.py b/data_juicer/ops/grouper/naive_grouper.py
index 92da22875..4633dc48e 100644
--- a/data_juicer/ops/grouper/naive_grouper.py
+++ b/data_juicer/ops/grouper/naive_grouper.py
@@ -1,4 +1,4 @@
-from ..base_op import OPERATORS, Grouper
+from ..base_op import OPERATORS, Grouper, convert_list_dict_to_dict_list
 
 
 @OPERATORS.register_module('naive_grouper')
@@ -19,10 +19,6 @@ def process(self, dataset):
         if len(dataset) == 0:
             return dataset
 
-        keys = dataset[0].keys()
-        batched_sample = {k: [None] * len(dataset) for k in keys}
-        for i, sample in enumerate(dataset):
-            for k in keys:
-                batched_sample[k][i] = sample[k]
+        batched_sample = convert_list_dict_to_dict_list(dataset)
 
         return [batched_sample]
diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py
index 0da2cb017..2c4c3f68e 100644
--- a/data_juicer/utils/common_utils.py
+++ b/data_juicer/utils/common_utils.py
@@ -50,5 +50,62 @@ def get_val_by_nested_key(input_dict: dict, nested_key: str):
     for key in keys:
         if key not in cur:
             logger.warning(f'Unvisitable nested key: {nested_key}!')
+            return None
         cur = cur[key]
     return cur
+
+
+def is_string_list(var):
+    """
+    Return True if the variable is a list of strings.
+
+    :param var: the variable to check
+    """
+    return isinstance(var, list) and all(isinstance(it, str) for it in var)
+
+
+def avg_split_string_list_under_limit(str_list: list,
+                                      token_nums: list,
+                                      max_token_num=None):
+    """
+    Split the string list into several sub-lists, such that the total
+    token num of each sub-list is less than max_token_num, keeping the
+    total token nums of the sub-lists similar.
+
+    :param str_list: the input string list.
+    :param token_nums: the token num of each string in the list.
+    :param max_token_num: the max token num of each sub string list.
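+
+        A worked example of the intended behavior: four strings with
+        token_nums [40, 40, 40, 40] and max_token_num=100 are split into
+        (160 // 100 + 1) = 2 groups of two strings, about 80 tokens each.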
+ """ + if max_token_num is None: + return [str_list] + + if len(str_list) != len(token_nums): + logger.warning('The length of str_list and token_nums must be equal!') + return [str_list] + + total_num = sum(token_nums) + if total_num <= max_token_num: + return [str_list] + + group_num = total_num // max_token_num + 1 + avg_num = total_num / group_num + res = [] + cur_list = [] + cur_sum = 0 + for text, token_num in zip(str_list, token_nums): + if token_num > max_token_num: + logger.warning( + 'Token num is greater than max_token_num in one sample!') + if cur_sum + token_num > max_token_num and cur_list: + res.append(cur_list) + cur_list = [] + cur_sum = 0 + cur_list.append(text) + cur_sum += token_num + if cur_sum > avg_num: + res.append(cur_list) + cur_list = [] + cur_sum = 0 + if cur_list: + res.append(cur_list) + return res diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 7f9687079..4849c2caa 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -203,7 +203,7 @@ def _filter_arguments(func, args_dict): return filtered_args -def prepare_api_model(model, +def prepare_api_model(api_model, *, url=None, response_path=None, @@ -214,7 +214,7 @@ def prepare_api_model(model, The callable supports custom response parsing and works with proxy servers that may be incompatible. - :param model: The name of the model to interact with. + :param api_model: The name of the model to interact with. :param url: URL endpoint for the API. :param response_path: The dot-separated path to extract desired content from the API response. Defaults to 'choices.0.message.content'. @@ -229,7 +229,7 @@ def prepare_api_model(model, :return: A tuple containing the callable API model object and optionally a processor if `return_processor` is True. """ - model = APIModel(model=model, + model = APIModel(model=api_model, url=url, response_path=response_path, **model_params) @@ -240,13 +240,20 @@ def prepare_api_model(model, def get_processor(): try: import tiktoken - return tiktoken.encoding_for_model(model) + return tiktoken.encoding_for_model(api_model) except Exception: pass try: import dashscope - return dashscope.get_tokenizer(model) + return dashscope.get_tokenizer(api_model) + except Exception: + pass + + try: + processor = transformers.AutoProcessor.from_pretrained( + pretrained_model_name_or_path=api_model, **processor_config) + return processor except Exception: pass @@ -257,7 +264,8 @@ def get_processor(): "- For custom models: Use the 'processor_config' parameter to configure a Hugging Face processor." # noqa: E501 ) - if processor_config is not None: + if processor_config is not None and \ + 'pretrained_model_name_or_path' in processor_config: processor = transformers.AutoProcessor.from_pretrained( **processor_config) else: @@ -804,3 +812,25 @@ def free_models(): except Exception: pass MODEL_ZOO.clear() + + +def parse_model_response(response): + """ + Parse model response of LLM to text. 
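+
+    Accepts a plain string, a message dict with a 'content' field, or a
+    list of role/content messages; for a list, the content of the last
+    'assistant' message is returned.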
+ """ + if isinstance(response, str): + return response + elif isinstance(response, dict) and 'content' in response: + return response['content'] + elif isinstance(response, list): + res = None + for msg in response: + if not isinstance( + msg, dict) or 'role' not in msg or 'content' not in msg: + logger.warning('Unvalid response of LLM!') + return None + if msg['role'] == 'assistant': + res = msg['content'] + return res + logger.warning('Unvalid response of LLM!') + return None diff --git a/tests/ops/Aggregator/__init__.py b/tests/ops/Aggregator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ops/Aggregator/test_entity_attribute_aggregator.py b/tests/ops/Aggregator/test_entity_attribute_aggregator.py new file mode 100644 index 000000000..cacbb2b6b --- /dev/null +++ b/tests/ops/Aggregator/test_entity_attribute_aggregator.py @@ -0,0 +1,161 @@ +import unittest + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.aggregator import EntityAttributeAggregator +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class EntityAttributeAggregatorTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples): + + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for data in new_dataset: + for k in data: + logger.info(f"{k}: {data[k]}") + + self.assertEqual(len(new_dataset), len(samples)) + + def test_default_aggregator(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = EntityAttributeAggregator( + hf_or_api_model='qwen2.5-72b-instruct', + entity='李莲花', + attribute='主要经历' + ) + self._run_helper(op, samples) + + def test_input_output(self): + samples = [ + { + 'sub_docs': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = EntityAttributeAggregator( + hf_or_api_model='qwen2.5-72b-instruct', + entity='李莲花', + attribute='身份背景', + input_key='sub_docs', + output_key='text' + ) + self._run_helper(op, samples) + + def test_max_token_num(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = EntityAttributeAggregator( + hf_or_api_model='qwen2.5-72b-instruct', + entity='李莲花', + attribute='身份背景', + max_token_num=200 + ) + self._run_helper(op, samples) + + def test_word_limit_num(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + 
'小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = EntityAttributeAggregator( + hf_or_api_model='qwen2.5-72b-instruct', + entity='李莲花', + attribute='身份背景', + word_limit=20 + ) + self._run_helper(op, samples) + + + def test_example_prompt(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + example_prompt=( + '- 例如,根据相关文档总结`孙悟空`的`另外身份`,样例如下:\n' + '`孙悟空`的`另外身份`总结:\n' + '# 孙悟空\n' + '## 另外身份\n' + '孙行者、齐天大圣、美猴王\n' + ) + op = EntityAttributeAggregator( + hf_or_api_model='qwen2.5-72b-instruct', + entity='李莲花', + attribute='另外身份', + example_prompt=example_prompt, + word_limit=20 + ) + self._run_helper(op, samples) + + + # def test_hf_model(self): + # samples = [ + # { + # 'text': [ + # "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + # "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + # '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + # '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + # '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + # ] + # }, + # ] + # op = EntityAttributeAggregator( + # hf_or_api_model='/mnt/workspace/shared/checkpoints/qwen/qwen2.5/Qwen2.5-3B-Instruct', + # entity='李莲花', + # attribute='身份背景', + # is_hf_model=True, + # model_params={ + # 'trust_remote_code': True + # }, + # sampling_params={ + # 'max_new_tokens': 50 + # } + # ) + # ## TODO:返回的是message而不是text + # self._run_helper(op, samples) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/Aggregator/test_nested_aggregator.py b/tests/ops/Aggregator/test_nested_aggregator.py new file mode 100644 index 000000000..107971d5f --- /dev/null +++ b/tests/ops/Aggregator/test_nested_aggregator.py @@ -0,0 +1,139 @@ +import unittest + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.aggregator import NestedAggregator +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class NestedAggregatorTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples): + + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for data in new_dataset: + for k in data: + logger.info(f"{k}: {data[k]}") + + self.assertEqual(len(new_dataset), len(samples)) + + def test_default_aggregator(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = NestedAggregator( + hf_or_api_model='qwen2.5-72b-instruct' + ) + self._run_helper(op, samples) + + def test_input_output(self): + samples = [ + { + 'sub_docs': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + 
'小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = NestedAggregator( + hf_or_api_model='qwen2.5-72b-instruct', + input_key='sub_docs', + output_key='text' + ) + self._run_helper(op, samples) + + def test_max_token_num_1(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = NestedAggregator( + hf_or_api_model='qwen2.5-72b-instruct', + max_token_num=2 + ) + self._run_helper(op, samples) + + def test_max_token_num_2(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = NestedAggregator( + hf_or_api_model='qwen2.5-72b-instruct', + max_token_num=90 + ) + self._run_helper(op, samples) + + def test_max_token_num_3(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = NestedAggregator( + hf_or_api_model='qwen2.5-72b-instruct', + max_token_num=200 + ) + self._run_helper(op, samples) + + # def test_hf_model(self): + # samples = [ + # { + # 'text': [ + # "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + # "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + # '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + # '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + # '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + # ] + # }, + # ] + # op = NestedAggregator( + # hf_or_api_model='/mnt/workspace/shared/checkpoints/qwen/qwen2.5/Qwen2.5-3B-Instruct', + # is_hf_model=True, + # model_params={ + # 'trust_remote_code': True + # }, + # sampling_params={ + # 'max_new_tokens': 50 + # } + # ) + # ## TODO:返回的是message而不是text + # self._run_helper(op, samples) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 7e66057be6b66ed225109da53de9057f88a0e2e0 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Wed, 20 Nov 2024 03:26:26 +0000 Subject: [PATCH 033/118] init python_lambda_mapper --- data_juicer/ops/mapper/__init__.py | 28 +++++---- .../ops/mapper/python_lambda_mapper.py | 47 ++++++++++++++ tests/ops/mapper/test_python_lambda_mapper.py | 63 +++++++++++++++++++ 3 files changed, 125 insertions(+), 13 deletions(-) create mode 100644 data_juicer/ops/mapper/python_lambda_mapper.py create mode 100644 tests/ops/mapper/test_python_lambda_mapper.py diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 41bf092a3..9806dd527 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -29,6 +29,7 @@ from .optimize_query_mapper import OptimizeQueryMapper from .optimize_response_mapper import OptimizeResponseMapper from 
.punctuation_normalization_mapper import PunctuationNormalizationMapper +from .python_lambda_mapper import PythonLambdaMapper from .remove_bibliography_mapper import RemoveBibliographyMapper from .remove_comments_mapper import RemoveCommentsMapper from .remove_header_mapper import RemoveHeaderMapper @@ -73,17 +74,18 @@ 'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper', 'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', - 'PunctuationNormalizationMapper', 'RemoveBibliographyMapper', - 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', - 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', - 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', - 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', - 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper', - 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', - 'VideoCaptioningFromVideoMapper', 'VideoFFmpegWrappedMapper', - 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', - 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper', - 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', - 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper', - 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper' + 'PunctuationNormalizationMapper', 'PythonLambdaMapper', + 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', + 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', + 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper', + 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper', + 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper', + 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', + 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper', + 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', + 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper', + 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper', + 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper', + 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper', + 'WhitespaceNormalizationMapper' ] diff --git a/data_juicer/ops/mapper/python_lambda_mapper.py b/data_juicer/ops/mapper/python_lambda_mapper.py new file mode 100644 index 000000000..cd61ed04d --- /dev/null +++ b/data_juicer/ops/mapper/python_lambda_mapper.py @@ -0,0 +1,47 @@ +import ast + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'python_lambda_mapper' + + +@OPERATORS.register_module(OP_NAME) +class PythonLambdaMapper(Mapper): + + def __init__(self, lambda_str: str): + # Parse and validate the lambda function + self.lambda_func = self._create_lambda(lambda_str) + + def _create_lambda(self, lambda_str: str): + # Parse input string into an AST and check for a valid lambda function + try: + node = ast.parse(lambda_str, mode='eval') + + # Check if the body of the expression is a lambda + if not isinstance(node.body, ast.Lambda): + raise ValueError( + 'Input string must be a valid lambda function.') + + # Check that the lambda has exactly one argument + if len(node.body.args.args) != 1: + raise ValueError( + 'Lambda function must have exactly one argument.') + + # Compile the AST to code + compiled_code = compile(node, '', 'eval') + # Safely evaluate the compiled code allowing built-in functions + func = eval(compiled_code, {'__builtins__': __builtins__}) + return func + except Exception as e: + 
raise ValueError(f'Invalid lambda function: {e}') + + def process_single(self, sample): + # Process the input through the lambda function and return the result + result = self.lambda_func(sample) + + # Check if the result is a valid + if not isinstance(result, dict): + raise ValueError(f'Lambda function must return a dictionary, ' + f'got {type(result).__name__} instead.') + + return result diff --git a/tests/ops/mapper/test_python_lambda_mapper.py b/tests/ops/mapper/test_python_lambda_mapper.py new file mode 100644 index 000000000..2d9d97742 --- /dev/null +++ b/tests/ops/mapper/test_python_lambda_mapper.py @@ -0,0 +1,63 @@ +import unittest + +from data_juicer.ops.mapper.python_lambda_mapper import PythonLambdaMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + +class PythonLambdaMapperMapper(DataJuicerTestCaseBase): + + def test_lambda_modifies_values(self): + mapper = PythonLambdaMapper("lambda d: {'value': d['value'] + 1}") # '+1' to 'value' + result = mapper.process_single({'value': 5}) + self.assertEqual(result, {'value': 6}) + + def test_lambda_combines_values(self): + mapper = PythonLambdaMapper("lambda d: {'combined': d['a'] + d['b']}") + result = mapper.process_single({'a': 3, 'b': 7}) + self.assertEqual(result, {'combined': 10}) + + def test_lambda_swaps_values(self): + mapper = PythonLambdaMapper("lambda d: {'a': d['b'], 'b': d['a']}") + result = mapper.process_single({'a': 1, 'b': 2}) + self.assertEqual(result, {'a': 2, 'b': 1}) + + def test_lambda_result_is_not_dict(self): + mapper = PythonLambdaMapper("lambda d: d['value'] + 1") # This returns an int + with self.assertRaises(ValueError) as cm: + mapper.process_single({'value': 10}) + self.assertIn("Lambda function must return a dictionary, got int instead.", str(cm.exception)) + + def test_invalid_syntax(self): + with self.assertRaises(ValueError) as cm: + PythonLambdaMapper("invalid lambda") # Invalid syntax + self.assertIn("Invalid lambda function", str(cm.exception)) + + def test_invalid_expression(self): + with self.assertRaises(ValueError) as cm: + PythonLambdaMapper("3 + 5") # Not a lambda + self.assertIn("Input string must be a valid lambda function.", str(cm.exception)) + + def test_lambda_with_multiple_arguments(self): + with self.assertRaises(ValueError) as cm: + PythonLambdaMapper("lambda x, y: {'sum': x + y}") # Creating a lambda accepts two arguments + self.assertIn("Lambda function must have exactly one argument.", str(cm.exception)) + + def test_lambda_returning_unexpected_structure(self): + mapper = PythonLambdaMapper("lambda d: ({'value': d['value']}, {'extra': d['extra']})") # Invalid return type; too many dictionaries + with self.assertRaises(ValueError) as cm: + mapper.process_single({'value': 5, 'extra': 10}) + self.assertIn("Lambda function must return a dictionary, got tuple instead.", str(cm.exception)) + + def test_lambda_modifies_in_place_and_returns(self): + mapper = PythonLambdaMapper("lambda d: d.update({'new_key': 'added_value'}) or d") # Returns the modified dictionary + sample_dict = {'value': 3} + result = mapper.process_single(sample_dict) + self.assertEqual(result, {'value': 3, 'new_key': 'added_value'}) # Ensure the update worked + + def test_lambda_function_with_no_operation(self): + mapper = PythonLambdaMapper("lambda d: d") # Simply returns the input dictionary + sample_dict = {'key': 'value'} + result = mapper.process_single(sample_dict) + self.assertEqual(result, {'key': 'value'}) # Unchanged + +if __name__ == '__main__': + unittest.main() \ No newline at end 
of file From a61859b166475738fb47039b6c9df86524cc11c2 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Wed, 20 Nov 2024 06:08:03 +0000 Subject: [PATCH 034/118] set default arg --- data_juicer/ops/mapper/python_lambda_mapper.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/data_juicer/ops/mapper/python_lambda_mapper.py b/data_juicer/ops/mapper/python_lambda_mapper.py index cd61ed04d..389fdad49 100644 --- a/data_juicer/ops/mapper/python_lambda_mapper.py +++ b/data_juicer/ops/mapper/python_lambda_mapper.py @@ -1,4 +1,5 @@ import ast +from typing import Optional from ..base_op import OPERATORS, Mapper @@ -8,9 +9,12 @@ @OPERATORS.register_module(OP_NAME) class PythonLambdaMapper(Mapper): - def __init__(self, lambda_str: str): + def __init__(self, lambda_str: Optional[str] = None): # Parse and validate the lambda function - self.lambda_func = self._create_lambda(lambda_str) + if not lambda_str: + self.lambda_func = lambda sample: sample + else: + self.lambda_func = self._create_lambda(lambda_str) def _create_lambda(self, lambda_str: str): # Parse input string into an AST and check for a valid lambda function From 8031a316e7a9121a17985201eb46547a0c510112 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Thu, 21 Nov 2024 08:11:32 +0000 Subject: [PATCH 035/118] fix init --- data_juicer/ops/mapper/python_lambda_mapper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/data_juicer/ops/mapper/python_lambda_mapper.py b/data_juicer/ops/mapper/python_lambda_mapper.py index 389fdad49..3aa78d891 100644 --- a/data_juicer/ops/mapper/python_lambda_mapper.py +++ b/data_juicer/ops/mapper/python_lambda_mapper.py @@ -9,7 +9,8 @@ @OPERATORS.register_module(OP_NAME) class PythonLambdaMapper(Mapper): - def __init__(self, lambda_str: Optional[str] = None): + def __init__(self, lambda_str: Optional[str] = None, **kwargs): + super().__init__(**kwargs) # Parse and validate the lambda function if not lambda_str: self.lambda_func = lambda sample: sample From 67711f9524d31ca925d2d448fb33b7586038178a Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Thu, 21 Nov 2024 09:13:24 +0000 Subject: [PATCH 036/118] add python_file_mapper --- data_juicer/ops/mapper/__init__.py | 28 +++--- data_juicer/ops/mapper/python_file_mapper.py | 71 +++++++++++++++ tests/ops/mapper/test_python_file_mapper.py | 96 ++++++++++++++++++++ 3 files changed, 182 insertions(+), 13 deletions(-) create mode 100644 data_juicer/ops/mapper/python_file_mapper.py create mode 100644 tests/ops/mapper/test_python_file_mapper.py diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 41bf092a3..12d55ea1a 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -29,6 +29,7 @@ from .optimize_query_mapper import OptimizeQueryMapper from .optimize_response_mapper import OptimizeResponseMapper from .punctuation_normalization_mapper import PunctuationNormalizationMapper +from .python_file_mapper import PythonFileMapper from .remove_bibliography_mapper import RemoveBibliographyMapper from .remove_comments_mapper import RemoveCommentsMapper from .remove_header_mapper import RemoveHeaderMapper @@ -73,17 +74,18 @@ 'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper', 'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', - 'PunctuationNormalizationMapper', 'RemoveBibliographyMapper', - 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', - 
'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', - 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', - 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', - 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper', - 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', - 'VideoCaptioningFromVideoMapper', 'VideoFFmpegWrappedMapper', - 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', - 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper', - 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', - 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper', - 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper' + 'PunctuationNormalizationMapper', 'PythonFileMapper', + 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', + 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', + 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper', + 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper', + 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper', + 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', + 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper', + 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', + 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper', + 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper', + 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper', + 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper', + 'WhitespaceNormalizationMapper' ] diff --git a/data_juicer/ops/mapper/python_file_mapper.py b/data_juicer/ops/mapper/python_file_mapper.py new file mode 100644 index 000000000..8cc730461 --- /dev/null +++ b/data_juicer/ops/mapper/python_file_mapper.py @@ -0,0 +1,71 @@ +import importlib.util +import inspect +import os + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'python_file_mapper' + + +@OPERATORS.register_module(OP_NAME) +class PythonFileMapper(Mapper): + + def __init__(self, + file_path: str = '', + function_name: str = 'process_single', + **kwargs): + super().__init__(**kwargs) + + self.file_path = file_path + self.function_name = function_name + if not file_path: + self.func = lambda sample: sample + else: + self.func = self._load_function() + + def _load_function(self): + if not os.path.isfile(self.file_path): + raise ValueError(f"The file '{self.file_path}' does not exist.") + + if not self.file_path.endswith('.py'): + raise ValueError( + f"The file '{self.file_path}' is not a Python file.") + + # Load the module from the file + module_name = os.path.splitext(os.path.basename(self.file_path))[0] + spec = importlib.util.spec_from_file_location(module_name, + self.file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Fetch the specified function from the module + if not hasattr(module, self.function_name): + raise ValueError( + f"Function '{self.function_name}' not found in '{self.file_path}'." 
# noqa: E501 + ) + + func = getattr(module, self.function_name) + + if not callable(func): + raise ValueError( + f"The attribute '{self.function_name}' is not callable.") + + # Check that the function has exactly one argument + argspec = inspect.getfullargspec(func) + if len(argspec.args) != 1: + raise ValueError( + f"The function '{self.function_name}' must take exactly one argument" # noqa: E501 + ) + + return func + + def process_single(self, sample): + """Invoke the loaded function with the provided sample.""" + result = self.func(sample) + + if not isinstance(result, dict): + raise ValueError( + f'Function must return a dictionary, got {type(result).__name__} instead.' # noqa: E501 + ) + + return result diff --git a/tests/ops/mapper/test_python_file_mapper.py b/tests/ops/mapper/test_python_file_mapper.py new file mode 100644 index 000000000..0e491bf3f --- /dev/null +++ b/tests/ops/mapper/test_python_file_mapper.py @@ -0,0 +1,96 @@ +import unittest +import tempfile + +from data_juicer.ops.mapper.python_file_mapper import PythonFileMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + +class TestPythonFileMapper(DataJuicerTestCaseBase): + + def test_function_execution(self): + """Test the correct execution of a loadable function.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + temp_file.write( + "def process_data(sample):\n" + " return {'result': sample['value'] + 10}\n" + ) + temp_file.seek(0) # Rewind the file so it can be read + mapper = PythonFileMapper(temp_file.name, "process_data") + result = mapper.process_single({'value': 5}) + self.assertEqual(result, {'result': 15}) + + def test_function_with_import(self): + """Test for a function that contains an import statement.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + temp_file.write( + "import numpy as np\n" + "def process_data(sample):\n" + " return {'result': np.sum([sample['value'], 10])}\n" + ) + temp_file.seek(0) # Rewind the file so it can be read + mapper = PythonFileMapper(temp_file.name, "process_data") + result = mapper.process_single({'value': 5}) + self.assertEqual(result, {'result': 15}) + + def test_file_not_found(self): + """Test for a non-existent file.""" + with self.assertRaises(ValueError) as cm: + PythonFileMapper("non_existent.py", "process_data") + self.assertIn("does not exist", str(cm.exception)) + + def test_file_not_python_extension(self): + """Test for a file that exists but is not a .py file.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.txt', mode='w+') as temp_file: + temp_file.write("This is a text file.") + temp_file.seek(0) # Rewind the file so it can be read + with self.assertRaises(ValueError) as cm: + PythonFileMapper(temp_file.name, "some_function") + self.assertIn("is not a Python file", str(cm.exception)) + + def test_function_not_found(self): + """Test for function not existing in the provided file.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + temp_file.write( + "def existing_function(sample):\n" + " return sample\n" + ) + temp_file.seek(0) # Rewind the file so it can be read + with self.assertRaises(ValueError) as cm: + PythonFileMapper(temp_file.name, "non_existing_function") + self.assertIn("not found", str(cm.exception)) + + def test_function_not_callable(self): + """Test for trying to load a non-callable function.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + 
temp_file.write("x = 42") + temp_file.seek(0) # Rewind the file so it can be read + with self.assertRaises(ValueError) as cm: + PythonFileMapper(temp_file.name, "x") + self.assertIn("not callable", str(cm.exception)) + + def test_function_mutiple_arguments(self): + """Test for function that requires more than one argument.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + temp_file.write( + "def multi_arg_function(arg1, arg2):\n" + " return arg1 + arg2\n" + ) + temp_file.seek(0) # Rewind the file so it can be read + with self.assertRaises(ValueError) as cm: + PythonFileMapper(temp_file.name, "multi_arg_function") + self.assertIn("must take exactly one argument", str(cm.exception)) + + def test_invalid_return_type(self): + """Test for a function returning a non-dictionary.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + temp_file.write( + "def invalid_function(sample):\n" + " return sample['value'] + 5\n" + ) + temp_file.seek(0) # Rewind the file so it can be read + mapper = PythonFileMapper(temp_file.name, "invalid_function") + with self.assertRaises(ValueError) as cm: + mapper.process_single({'value': 5}) + self.assertIn("Function must return a dictionary, got int instead.", str(cm.exception)) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From cdeb6924c3d4555a73e2a55f6c082569cee377c1 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Fri, 22 Nov 2024 16:35:17 +0800 Subject: [PATCH 037/118] support text & most relavant entities --- data_juicer/config/config.py | 7 +- data_juicer/ops/__init__.py | 2 +- data_juicer/ops/aggregator/__init__.py | 6 +- .../aggregator/entity_attribute_aggregator.py | 76 +++----- .../most_relavant_entities_aggregator.py | 183 ++++++++++++++++++ .../ops/aggregator/nested_aggregator.py | 53 ++--- data_juicer/ops/grouper/key_value_grouper.py | 4 +- data_juicer/ops/mapper/__init__.py | 4 +- .../mapper/extract_entity_attribute_mapper.py | 4 +- .../ops/mapper/extract_support_text_mapper.py | 133 +++++++++++++ data_juicer/utils/common_utils.py | 46 ++++- data_juicer/utils/constant.py | 2 + data_juicer/utils/model_utils.py | 44 +---- .../test_entity_attribute_aggregator.py | 41 +--- .../ops/Aggregator/test_nested_aggregator.py | 39 +--- .../test_extract_entity_attribute_mapper.py | 4 +- .../test_extract_support_text_mapper.py | 80 ++++++++ 17 files changed, 509 insertions(+), 219 deletions(-) create mode 100644 data_juicer/ops/aggregator/most_relavant_entities_aggregator.py create mode 100644 data_juicer/ops/mapper/extract_support_text_mapper.py create mode 100644 tests/ops/mapper/test_extract_support_text_mapper.py diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 0b0487dc3..470f1c04b 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -537,8 +537,13 @@ def sort_op_by_types_and_names(op_name_classes): if 'deduplicator' in name] selector_ops = [(name, c) for (name, c) in op_name_classes if 'selector' in name] + grouper_ops = [(name, c) for (name, c) in op_name_classes + if 'grouper' in name] + aggregator_ops = [(name, c) for (name, c) in op_name_classes + if 'aggregator' in name] ops_sorted_by_types = sorted(mapper_ops) + sorted(filter_ops) + sorted( - deduplicator_ops) + sorted(selector_ops) + deduplicator_ops) + sorted(selector_ops) + sorted(grouper_ops) + \ + sorted(aggregator_ops) return ops_sorted_by_types diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py 
index 4f9e4f4bc..e02e10efa 100644 --- a/data_juicer/ops/__init__.py +++ b/data_juicer/ops/__init__.py @@ -1,4 +1,4 @@ -from . import deduplicator, filter, mapper, selector +from . import aggregator, deduplicator, filter, grouper, mapper, selector from .base_op import (OPERATORS, UNFORKABLE, Aggregator, Deduplicator, Filter, Grouper, Mapper, Selector) from .load import load_ops diff --git a/data_juicer/ops/aggregator/__init__.py b/data_juicer/ops/aggregator/__init__.py index cdc67def9..4afe2974a 100644 --- a/data_juicer/ops/aggregator/__init__.py +++ b/data_juicer/ops/aggregator/__init__.py @@ -1,4 +1,8 @@ from .entity_attribute_aggregator import EntityAttributeAggregator +from .most_relavant_entities_aggregator import MostRelavantEntitiesAggregator from .nested_aggregator import NestedAggregator -__all__ = ['NestedAggregator', 'EntityAttributeAggregator'] +__all__ = [ + 'NestedAggregator', 'EntityAttributeAggregator', + 'MostRelavantEntitiesAggregator' +] diff --git a/data_juicer/ops/aggregator/entity_attribute_aggregator.py b/data_juicer/ops/aggregator/entity_attribute_aggregator.py index 23986a646..6471b7f6b 100644 --- a/data_juicer/ops/aggregator/entity_attribute_aggregator.py +++ b/data_juicer/ops/aggregator/entity_attribute_aggregator.py @@ -6,11 +6,10 @@ from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Aggregator from data_juicer.utils.common_utils import (avg_split_string_list_under_limit, - get_val_by_nested_key, - is_string_list) + is_string_list, nested_access, + nested_set) from data_juicer.utils.lazy_loader import LazyLoader -from data_juicer.utils.model_utils import (get_model, parse_model_response, - prepare_model) +from data_juicer.utils.model_utils import get_model, prepare_model from .nested_aggregator import NestedAggregator @@ -56,7 +55,7 @@ class EntityAttributeAggregator(Aggregator): DEFAULT_OUTPUT_PATTERN_TEMPLATE = r'\#\s*{entity}\s*\#\#\s*{attribute}\s*(.*?)\Z' # noqa: E501 def __init__(self, - hf_or_api_model: str = 'gpt-4o', + api_model: str = 'gpt-4o', entity: str = None, attribute: str = None, input_key: str = None, @@ -71,14 +70,12 @@ def __init__(self, input_template: Optional[str] = None, output_pattern_template: Optional[str] = None, try_num: PositiveInt = 3, - is_hf_model: bool = False, - enable_vllm: bool = False, model_params: Dict = {}, sampling_params: Dict = {}, **kwargs): """ Initialization method. - :param hf_or_api_model: Huggingface model or API model name. + :param api_model: API model name. :param entity: The given entity. :param attribute: The given attribute. :param input_key: The input field key in the samples. Support for @@ -99,8 +96,6 @@ def __init__(self, :param output_pattern_template: The output template. :param try_num: The number of retry attempts when there is an API call error or output parsing error. - :param is_hf_model: If the hf_or_api_model is huggingface model. - :param enable_vllm: Whether to use VLLM for inference acceleration. :param model_params: Parameters for initializing the API model. :param sampling_params: Extra parameters passed to the API call. 
e.g {'temperature': 0.9, 'top_p': 0.95} @@ -118,61 +113,38 @@ def __init__(self, self.word_limit = word_limit self.max_token_num = max_token_num - self.system_prompt_template = system_prompt_template or \ + system_prompt_template = system_prompt_template or \ self.DEFAULT_SYSTEM_TEMPLATE self.example_prompt = example_prompt or self.DEFAULT_EXAMPLE_PROMPT self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE output_pattern_template = output_pattern_template or \ self.DEFAULT_OUTPUT_PATTERN_TEMPLATE + self.system_prompt = system_prompt_template.format( + entity=self.entity, + attribute=self.attribute, + word_limit=self.word_limit, + example=self.example_prompt) self.output_pattern = output_pattern_template.format( entity=entity, attribute=attribute) self.sampling_params = sampling_params - self.is_hf_model = is_hf_model - self.enable_vllm = enable_vllm - if is_hf_model and enable_vllm: - assert torch.cuda.device_count() >= 1, 'must be executed in CUDA' - # cannot initialize vllm replicas on different GPUs - self.num_proc = 1 - if model_params.get('tensor_parallel_size') is None: - tensor_parallel_size = torch.cuda.device_count() - logger.info(f'Set tensor_parallel_size to \ - {tensor_parallel_size} for vllm.') - model_params['tensor_parallel_size'] = tensor_parallel_size - self.model_key = prepare_model( - model_type='vllm', - pretrained_model_name_or_path=hf_or_api_model, - **model_params) - self.sampling_params = vllm.SamplingParams(**sampling_params) - elif is_hf_model: - self.model_key = prepare_model( - model_type='huggingface', - pretrained_model_name_or_path=hf_or_api_model, - return_pipe=True, - **model_params) - self.sampling_params = sampling_params - else: - self.model_key = prepare_model(model_type='api', - api_model=hf_or_api_model, - url=api_url, - response_path=response_path, - return_processor=True, - **model_params) + self.model_key = prepare_model(model_type='api', + api_model=api_model, + url=api_url, + response_path=response_path, + return_processor=True, + **model_params) self.try_num = try_num - self.nested_sum = NestedAggregator(hf_or_api_model=hf_or_api_model, + self.nested_sum = NestedAggregator(api_model=api_model, max_token_num=max_token_num, api_url=api_url, response_path=response_path, try_num=try_num, - is_hf_model=is_hf_model, - enable_vllm=enable_vllm, model_params=model_params, sampling_params=sampling_params) def parse_output(self, response): - response = parse_model_response(response) - pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) matches = pattern.findall(response) if matches: @@ -192,18 +164,13 @@ def attribute_summary(self, sub_docs, rank=None): self.max_token_num) results = [] for docs in group_docs: - system_prompt = self.system_prompt_template.format( - entity=self.entity, - attribute=self.attribute, - word_limit=self.word_limit, - example=self.example_prompt) doc_str = '\n\n'.join(docs) input_prompt = self.input_template.format(entity=self.entity, attribute=self.attribute, sub_docs=doc_str) messages = [{ 'role': 'system', - 'content': system_prompt + 'content': self.system_prompt }, { 'role': 'user', 'content': input_prompt @@ -224,10 +191,11 @@ def attribute_summary(self, sub_docs, rank=None): def process_single(self, sample=None, rank=None): # if not batched sample - sub_docs = get_val_by_nested_key(sample, self.input_key) + sub_docs = nested_access(sample, self.input_key) if not is_string_list(sub_docs): return sample - sample[self.output_key] = self.attribute_summary(sub_docs, rank=rank) + sample = 
nested_set(sample, self.output_key, + self.attribute_summary(sub_docs, rank=rank)) return sample diff --git a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py new file mode 100644 index 000000000..b48be7856 --- /dev/null +++ b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py @@ -0,0 +1,183 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Aggregator +from data_juicer.utils.common_utils import (is_string_list, nested_access, + nested_set) +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..common import split_text_by_punctuation + +torch = LazyLoader('torch', 'torch') +vllm = LazyLoader('vllm', 'vllm') + +OP_NAME = 'most_relavant_entities_aggregator' + + +# TODO: LLM-based inference. +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class MostRelavantEntitiesAggregator(Aggregator): + """ + Return most relavant entities with the given entity from some docs. + """ + + DEFAULT_SYSTEM_TEMPLATE = ( + '给定与`{entity}`相关的一些文档,' + '总结一些与`{entity}`最为相关的`{entity_type}`。\n' + '要求:\n' + '- 不用包含与{entity}为同一{entity_type}的{entity_type}。\n' + '- 请按照人物的重要性进行排序,越重要人物在列表越前面。\n' + '- 你的返回格式如下:\n' + '## 分析\n' + '你对各个{entity_type}与{entity}关联度的分析\n' + '## 列表\n' + '人物1, 人物2, 人物3, ...') + + DEFAULT_INPUT_TEMPLATE = ('`{entity}`的相关文档:\n' + '{sub_docs}\n\n' + '与`{entity}`最相关的一些`{entity_type}`:\n') + + DEFAULT_OUTPUT_PATTERN = r'\#\#\s*列表\s*(.*?)\Z' + + def __init__(self, + api_model: str = 'gpt-4o', + entity: str = None, + query_entity_type: str = None, + input_key: str = None, + output_key: str = None, + max_token_num: Optional[PositiveInt] = None, + *, + api_url: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt_template: Optional[str] = None, + input_template: Optional[str] = None, + output_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param api_model: API model name. + :param entity: The given entity. + :param query_entity_type: The type of queried relavant entities. + :param input_key: The input field key in the samples. Support for + nested keys such as "__dj__stats__.text_len". It is text_key + in default. + :param output_key: The output field key in the samples. Support for + nested keys such as "__dj__stats__.text_len". It is same as the + input_key in default. + :param max_token_num: The max token num of the total tokens of the + sub documents. Without limitation if it is None. + :param api_url: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt_template: The system prompt template. + :param input_template: The input template. + :param output_pattern: The output pattern. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. 
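+
+        A minimal usage sketch (an illustrative assumption, mirroring the
+        role-playing demo config in this series; it presumes an
+        OpenAI-compatible endpoint is configured via environment variables
+        and that each sample holds a list of event descriptions under
+        `input_key`):
+
+            op = MostRelavantEntitiesAggregator(
+                api_model='qwen2.5-72b-instruct',
+                entity='李莲花',
+                query_entity_type='人物',
+                input_key='__dj__event_description__',
+                output_key='__dj__important_relavant_roles__')
+            sample = op.process_single(sample)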
+ """ + super().__init__(**kwargs) + + if entity is None or query_entity_type is None: + raise ValueError( + 'The entity and query_entity_type cannot be None!') + + self.entity = entity + self.query_entity_type = query_entity_type + self.input_key = input_key or self.text_key + self.output_key = output_key or self.input_key + self.max_token_num = max_token_num + + system_prompt_template = system_prompt_template or \ + self.DEFAULT_SYSTEM_TEMPLATE + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN + self.system_prompt = system_prompt_template.format( + entity=entity, entity_type=query_entity_type) + + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', + api_model=api_model, + url=api_url, + response_path=response_path, + return_processor=True, + **model_params) + + self.try_num = try_num + + def parse_output(self, response): + pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) + matches = pattern.findall(response) + if matches: + result = matches[0].strip() + else: + result = '' + result = split_text_by_punctuation(result) + + return result + + def query_most_relavant_entities(self, sub_docs, rank=None): + if not sub_docs: + return '' + + model, tokenizer = get_model(self.model_key, rank, self.use_cuda()) + token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs] + if self.max_token_num is None: + final_docs = sub_docs + else: + final_docs = [] + total_num = 0 + for token_num, doc in zip(token_nums, sub_docs): + total_num += token_num + if total_num > self.max_token_num: + break + final_docs.append(doc) + + doc_str = '\n\n'.join(final_docs) + input_prompt = self.input_template.format( + entity=self.entity, + entity_type=self.query_entity_type, + sub_docs=doc_str) + + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': input_prompt + }] + result = [] + for i in range(self.try_num): + try: + response = model(messages, **self.sampling_params) + result = self.parse_output(response) + if len(result) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + return result + + def process_single(self, sample=None, rank=None): + + # if not batched sample + sub_docs = nested_access(sample, self.input_key) + if not is_string_list(sub_docs): + return sample + + sample = nested_set( + sample, self.output_key, + self.query_most_relavant_entities(sub_docs, rank=rank)) + + return sample diff --git a/data_juicer/ops/aggregator/nested_aggregator.py b/data_juicer/ops/aggregator/nested_aggregator.py index d843358b0..4671dc61b 100644 --- a/data_juicer/ops/aggregator/nested_aggregator.py +++ b/data_juicer/ops/aggregator/nested_aggregator.py @@ -5,11 +5,9 @@ from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Aggregator from data_juicer.utils.common_utils import (avg_split_string_list_under_limit, - get_val_by_nested_key, - is_string_list) + is_string_list, nested_access) from data_juicer.utils.lazy_loader import LazyLoader -from data_juicer.utils.model_utils import (get_model, parse_model_response, - prepare_model) +from data_juicer.utils.model_utils import get_model, prepare_model torch = LazyLoader('torch', 'torch') vllm = LazyLoader('vllm', 'vllm') @@ -53,7 +51,7 @@ class NestedAggregator(Aggregator): DEFAULT_SUB_DOC_TEMPLATE = '文档碎片:\n{text}\n' def __init__(self, - hf_or_api_model: str = 'gpt-4o', + api_model: str = 'gpt-4o', input_key: str = None, output_key: str = None, 
max_token_num: Optional[PositiveInt] = None, @@ -64,14 +62,12 @@ def __init__(self, sub_doc_template: Optional[str] = None, input_template: Optional[str] = None, try_num: PositiveInt = 3, - is_hf_model: bool = False, - enable_vllm: bool = False, model_params: Dict = {}, sampling_params: Dict = {}, **kwargs): """ Initialization method. - :param hf_or_api_model: Huggingface model or API model name. + :param api_model: API model name. :param input_key: The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. @@ -88,8 +84,6 @@ def __init__(self, :param input_template: The input template. :param try_num: The number of retry attempts when there is an API call error or output parsing error. - :param is_hf_model: If the hf_or_api_model is huggingface model. - :param enable_vllm: Whether to use VLLM for inference acceleration. :param model_params: Parameters for initializing the API model. :param sampling_params: Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} @@ -107,41 +101,16 @@ def __init__(self, self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE self.sampling_params = sampling_params - self.is_hf_model = is_hf_model - self.enable_vllm = enable_vllm - if is_hf_model and enable_vllm: - assert torch.cuda.device_count() >= 1, 'must be executed in CUDA' - # cannot initialize vllm replicas on different GPUs - self.num_proc = 1 - if model_params.get('tensor_parallel_size') is None: - tensor_parallel_size = torch.cuda.device_count() - logger.info(f'Set tensor_parallel_size to \ - {tensor_parallel_size} for vllm.') - model_params['tensor_parallel_size'] = tensor_parallel_size - self.model_key = prepare_model( - model_type='vllm', - pretrained_model_name_or_path=hf_or_api_model, - **model_params) - self.sampling_params = vllm.SamplingParams(**sampling_params) - elif is_hf_model: - self.model_key = prepare_model( - model_type='huggingface', - pretrained_model_name_or_path=hf_or_api_model, - return_pipe=True, - **model_params) - self.sampling_params = sampling_params - else: - self.model_key = prepare_model(model_type='api', - api_model=hf_or_api_model, - url=api_url, - response_path=response_path, - return_processor=True, - **model_params) + self.model_key = prepare_model(model_type='api', + api_model=api_model, + url=api_url, + response_path=response_path, + return_processor=True, + **model_params) self.try_num = try_num def parse_output(self, response): - response = parse_model_response(response) def if_match(text): quotes = [("'", "'"), ('"', '"'), ('“', '”'), ('‘', '’'), @@ -202,7 +171,7 @@ def recursive_summary(self, sub_docs, rank=None): def process_single(self, sample=None, rank=None): # if not batched sample - sub_docs = get_val_by_nested_key(sample, self.input_key) + sub_docs = nested_access(sample, self.input_key) if not is_string_list(sub_docs): return sample diff --git a/data_juicer/ops/grouper/key_value_grouper.py b/data_juicer/ops/grouper/key_value_grouper.py index 4b6e94bf9..3d786319f 100644 --- a/data_juicer/ops/grouper/key_value_grouper.py +++ b/data_juicer/ops/grouper/key_value_grouper.py @@ -1,6 +1,6 @@ from typing import List, Optional -from data_juicer.utils.common_utils import dict_to_hash, get_val_by_nested_key +from data_juicer.utils.common_utils import dict_to_hash, nested_access from ..base_op import OPERATORS, Grouper, convert_list_dict_to_dict_list from .naive_grouper import NaiveGrouper @@ -37,7 +37,7 @@ def process(self, dataset): for sample in dataset: cur_dict 
= {} for key in self.group_by_keys: - cur_dict[key] = get_val_by_nested_key(sample, key) + cur_dict[key] = nested_access(sample, key) sample_key = dict_to_hash(cur_dict) if sample_key in sample_map: sample_map[sample_key].append(sample) diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 41bf092a3..7b2e24f02 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -14,6 +14,7 @@ from .extract_event_mapper import ExtractEventMapper from .extract_keyword_mapper import ExtractKeywordMapper from .extract_nickname_mapper import ExtractNicknameMapper +from .extract_support_text_mapper import ExtractSupportTextMapper from .fix_unicode_mapper import FixUnicodeMapper from .generate_qa_from_examples_mapper import GenerateQAFromExamplesMapper from .generate_qa_from_text_mapper import GenerateQAFromTextMapper @@ -67,7 +68,8 @@ 'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper', 'ExpandMacroMapper', 'ExtractEntityAttributeMapper', 'ExtractEntityRelationMapper', 'ExtractEventMapper', - 'ExtractKeywordMapper', 'ExtractNicknameMapper', 'FixUnicodeMapper', + 'ExtractKeywordMapper', 'ExtractNicknameMapper', + 'ExtractSupportTextMapper', 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper', 'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper', 'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper', diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py index ebff93485..e5594aa61 100644 --- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -40,9 +40,9 @@ class ExtractEntityAttributeMapper(Mapper): DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例(\d+):\s*(.*?)(?=\#\#\#|\Z)' def __init__(self, + api_model: str = 'gpt-4o', query_entities: List[str] = [], query_attributes: List[str] = [], - api_model: str = 'gpt-4o', *, entity_key: str = Fields.main_entity, attribute_key: str = Fields.attribute, @@ -61,9 +61,9 @@ def __init__(self, **kwargs): """ Initialization method. + :param api_model: API model name. :param query_entities: Entity list to be queried. :param query_attributes: Attribute list to be queried. - :param api_model: API model name. :param entity_key: The field name to store the given main entity for attribute extraction. It's "__dj__entity__" in default. :param entity_attribute_key: The field name to store the given diff --git a/data_juicer/ops/mapper/extract_support_text_mapper.py b/data_juicer/ops/mapper/extract_support_text_mapper.py new file mode 100644 index 000000000..73c89b88c --- /dev/null +++ b/data_juicer/ops/mapper/extract_support_text_mapper.py @@ -0,0 +1,133 @@ +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.utils.common_utils import nested_access, nested_set +from data_juicer.utils.constant import Fields +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'extract_support_text_mapper' + + +# TODO: LLM-based inference. +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class ExtractSupportTextMapper(Mapper): + """ + Extract support sub text for a summary. 
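+
+    Given a summary stored at `summary_key`, the op prompts an API model to
+    quote the passage of the original text that best supports the summary,
+    falling back to the summary itself when no valid response is returned.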
+    """
+
+    DEFAULT_SYSTEM_PROMPT = ('你将扮演一个文本摘录助手的角色。你的主要任务是基于给定'
+                             '的文章(称为“原文”)以及对原文某个部分的简短描述或总结'
+                             '(称为“总结”),准确地识别并提取出与该总结相对应的原文'
+                             '片段。\n'
+                             '要求:\n'
+                             '- 你需要尽可能精确地匹配到最符合总结内容的那部分内容\n'
+                             '- 如果存在多个可能的答案,请选择最贴近总结意思的那个\n'
+                             '- 下面是一个例子帮助理解这一过程:\n'
+                             '### 原文:\n'
+                             '《红楼梦》是中国古典小说四大名著之一,由清代作家曹雪芹创'
+                             '作。它讲述了贾宝玉、林黛玉等人的爱情故事及四大家族的兴衰'
+                             '历程。书中通过复杂的人物关系展现了封建社会的各种矛盾冲突'
+                             '。其中关于贾府内部斗争的部分尤其精彩,特别是王熙凤与尤二'
+                             '姐之间的争斗,生动描绘了权力争夺下的女性形象。此外,《红'
+                             '楼梦》还以其精美的诗词闻名,这些诗词不仅增添了文学色彩,'
+                             '也深刻反映了人物的性格特点和命运走向。\n\n'
+                             '### 总结:\n'
+                             '描述了书中的两个女性角色之间围绕权力展开的竞争。\n\n'
+                             '### 原文摘录:\n'
+                             '其中关于贾府内部斗争的部分尤其精彩,特别是王熙凤与尤二姐'
+                             '之间的争斗,生动描绘了权力争夺下的女性形象。')
+    DEFAULT_INPUT_TEMPLATE = ('### 原文:\n{text}\n\n'
+                              '### 总结:\n{summary}\n\n'
+                              '### 原文摘录:\n')
+
+    def __init__(self,
+                 api_model: str = 'gpt-4o',
+                 *,
+                 summary_key: str = Fields.event_description,
+                 support_text_key: str = Fields.support_text,
+                 api_url: Optional[str] = None,
+                 response_path: Optional[str] = None,
+                 system_prompt: Optional[str] = None,
+                 input_template: Optional[str] = None,
+                 try_num: PositiveInt = 3,
+                 drop_text: bool = False,
+                 model_params: Dict = {},
+                 sampling_params: Dict = {},
+                 **kwargs):
+        """
+        Initialization method.
+        :param api_model: API model name.
+        :param summary_key: The field name to store the input summary.
+            Support for nested keys such as "__dj__stats__.text_len".
+            It's "__dj__event_description__" by default.
+        :param support_text_key: The field name to store the output
+            support text for the summary. It's "__dj__support_text__" by
+            default.
+        :param api_url: URL endpoint for the API.
+        :param response_path: Path to extract content from the API response.
+            Defaults to 'choices.0.message.content'.
+        :param system_prompt: System prompt for the task.
+        :param input_template: Template for building the model input.
+        :param try_num: The number of retry attempts when there is an API
+            call error or output parsing error.
+        :param drop_text: Whether to drop the original text in the output.
+        :param model_params: Parameters for initializing the API model.
+        :param sampling_params: Extra parameters passed to the API call.
+            e.g {'temperature': 0.9, 'top_p': 0.95}
+        :param kwargs: Extra keyword arguments.
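+
+        A minimal usage sketch (an illustrative assumption; the model name
+        and the nested keys mirror the unit test for this op, and an
+        OpenAI-compatible endpoint is assumed to be configured via
+        environment variables):
+
+            op = ExtractSupportTextMapper(
+                api_model='qwen2.5-72b-instruct',
+                summary_key='data.event',
+                support_text_key='data.support_text')
+            sample = op.process_single(sample)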
+        """
+        super().__init__(**kwargs)
+
+        self.summary_key = summary_key
+        self.support_text_key = support_text_key
+
+        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+        self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+
+        self.sampling_params = sampling_params
+        self.model_key = prepare_model(model_type='api',
+                                       api_model=api_model,
+                                       url=api_url,
+                                       response_path=response_path,
+                                       **model_params)
+
+        self.try_num = try_num
+        self.drop_text = drop_text
+
+    def process_single(self, sample, rank=None):
+        client = get_model(self.model_key, rank=rank)
+
+        summary = nested_access(sample, self.summary_key)
+        if not isinstance(summary, str):
+            logger.warning('Invalid input summary!')
+            return sample
+
+        input_prompt = self.input_template.format(text=sample[self.text_key],
+                                                  summary=summary)
+        messages = [{
+            'role': 'system',
+            'content': self.system_prompt
+        }, {
+            'role': 'user',
+            'content': input_prompt
+        }]
+
+        support_text = ''
+        for i in range(self.try_num):
+            try:
+                response = client(messages, **self.sampling_params)
+                support_text = response.strip()
+                if len(support_text) > 0:
+                    break
+            except Exception as e:
+                logger.warning(f'Exception: {e}')
+        # fall back to the summary if no valid support text is returned
+        if not support_text:
+            support_text = summary
+
+        sample = nested_set(sample, self.support_text_key, support_text)
+        return sample
diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py
index d9f084e10..bd649bb96 100644
--- a/data_juicer/utils/common_utils.py
+++ b/data_juicer/utils/common_utils.py
@@ -39,20 +39,48 @@ def dict_to_hash(input_dict: dict, hash_length=None):
     return hash_value
 
 
-def get_val_by_nested_key(input_dict: dict, nested_key: str):
+def nested_access(data, path, digit_allowed=True):
     """
-    return val of the dict in the nested key.
-
-    :param nested_key: the nested key, such as "__dj__stats__.text_len"
+    Access nested data using a dot-separated path.
+
+    :param data: A dictionary or a list to access the nested data from.
+    :param path: A dot-separated string representing the path to access.
+                 This can include numeric indices when accessing list
+                 elements.
+    :param digit_allowed: Whether to convert numeric string keys to
+                          integer list indices.
+    :return: The value located at the specified path, or None (with a
+             warning logged) if the path cannot be accessed.
     """
-    keys = nested_key.split('.')
-    cur = input_dict
+    keys = path.split('.')
     for key in keys:
-        if key not in cur:
-            logger.warning(f'Unvisitable nested key: {nested_key}!')
+        # Convert string keys to integers if they are numeric
+        key = int(key) if key.isdigit() and digit_allowed else key
+        try:
+            data = data[key]
+        except Exception:
+            logger.warning(f'Inaccessible dot-separated path: {path}!')
             return None
+    return data
+
+
+def nested_set(data: dict, path: str, val):
+    """
+    Set a value in nested data at the given dot-separated path.
+
+    :param data: A dictionary with nested format.
+    :param path: A dot-separated string representing the path to set;
+                 missing intermediate dictionaries are created along
+                 the way.
+    :return: The nested data after the value is set.
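+
+    Example (a minimal sketch of the intended behavior):
+
+        >>> nested_set({'a': 1}, 'b.c', 2)
+        {'a': 1, 'b': {'c': 2}}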
+ """ + keys = path.split('.') + cur = data + for key in keys[:-1]: + if key not in cur: + cur[key] = {} cur = cur[key] - return cur + cur[keys[-1]] = val + return data def is_string_list(var): diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index ab88035b9..c6e62c8fc 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -64,6 +64,8 @@ class Fields(object): relation_strength = DEFAULT_PREFIX + 'relation_strength__' # # the keyword in a text keyword = DEFAULT_PREFIX + 'keyword__' + # # support text + support_text = DEFAULT_PREFIX + 'support_text__' class StatsKeysMeta(type): diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 4849c2caa..ec1eaed49 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -11,6 +11,7 @@ from loguru import logger from data_juicer import cuda_device_count +from data_juicer.utils.common_utils import nested_access from data_juicer.utils.lazy_loader import AUTOINSTALL, LazyLoader from .cache_utils import DATA_JUICER_MODELS_CACHE as DJMC @@ -156,30 +157,11 @@ def __call__(self, messages, **kwargs): stream=stream, stream_cls=stream_cls) result = response.json() - return self._nested_access(result, self.response_path) + return nested_access(result, self.response_path) except Exception as e: logger.exception(e) return '' - @staticmethod - def _nested_access(data, path): - """ - Access nested data using a dot-separated path. - - :param data: A dictionary or a list to access the nested data from. - :param path: A dot-separated string representing the path to access. - This can include numeric indices when accessing list - elements. - :return: The value located at the specified path, or raises a KeyError - or IndexError if the path does not exist. - """ - keys = path.split('.') - for key in keys: - # Convert string keys to integers if they are numeric - key = int(key) if key.isdigit() else key - data = data[key] - return data - @staticmethod def _filter_arguments(func, args_dict): """ @@ -812,25 +794,3 @@ def free_models(): except Exception: pass MODEL_ZOO.clear() - - -def parse_model_response(response): - """ - Parse model response of LLM to text. 
- """ - if isinstance(response, str): - return response - elif isinstance(response, dict) and 'content' in response: - return response['content'] - elif isinstance(response, list): - res = None - for msg in response: - if not isinstance( - msg, dict) or 'role' not in msg or 'content' not in msg: - logger.warning('Unvalid response of LLM!') - return None - if msg['role'] == 'assistant': - res = msg['content'] - return res - logger.warning('Unvalid response of LLM!') - return None diff --git a/tests/ops/Aggregator/test_entity_attribute_aggregator.py b/tests/ops/Aggregator/test_entity_attribute_aggregator.py index cacbb2b6b..647d0486a 100644 --- a/tests/ops/Aggregator/test_entity_attribute_aggregator.py +++ b/tests/ops/Aggregator/test_entity_attribute_aggregator.py @@ -11,6 +11,10 @@ class EntityAttributeAggregatorTest(DataJuicerTestCaseBase): def _run_helper(self, op, samples): + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + dataset = Dataset.from_list(samples) new_dataset = op.run(dataset) @@ -33,7 +37,7 @@ def test_default_aggregator(self): }, ] op = EntityAttributeAggregator( - hf_or_api_model='qwen2.5-72b-instruct', + api_model='qwen2.5-72b-instruct', entity='李莲花', attribute='主要经历' ) @@ -52,7 +56,7 @@ def test_input_output(self): }, ] op = EntityAttributeAggregator( - hf_or_api_model='qwen2.5-72b-instruct', + api_model='qwen2.5-72b-instruct', entity='李莲花', attribute='身份背景', input_key='sub_docs', @@ -73,7 +77,7 @@ def test_max_token_num(self): }, ] op = EntityAttributeAggregator( - hf_or_api_model='qwen2.5-72b-instruct', + api_model='qwen2.5-72b-instruct', entity='李莲花', attribute='身份背景', max_token_num=200 @@ -93,7 +97,7 @@ def test_word_limit_num(self): }, ] op = EntityAttributeAggregator( - hf_or_api_model='qwen2.5-72b-instruct', + api_model='qwen2.5-72b-instruct', entity='李莲花', attribute='身份背景', word_limit=20 @@ -121,7 +125,7 @@ def test_example_prompt(self): '孙行者、齐天大圣、美猴王\n' ) op = EntityAttributeAggregator( - hf_or_api_model='qwen2.5-72b-instruct', + api_model='qwen2.5-72b-instruct', entity='李莲花', attribute='另外身份', example_prompt=example_prompt, @@ -130,32 +134,5 @@ def test_example_prompt(self): self._run_helper(op, samples) - # def test_hf_model(self): - # samples = [ - # { - # 'text': [ - # "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - # "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - # '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - # '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - # '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' - # ] - # }, - # ] - # op = EntityAttributeAggregator( - # hf_or_api_model='/mnt/workspace/shared/checkpoints/qwen/qwen2.5/Qwen2.5-3B-Instruct', - # entity='李莲花', - # attribute='身份背景', - # is_hf_model=True, - # model_params={ - # 'trust_remote_code': True - # }, - # sampling_params={ - # 'max_new_tokens': 50 - # } - # ) - # ## TODO:返回的是message而不是text - # self._run_helper(op, samples) - if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tests/ops/Aggregator/test_nested_aggregator.py b/tests/ops/Aggregator/test_nested_aggregator.py index 107971d5f..eebf9d38a 100644 --- a/tests/ops/Aggregator/test_nested_aggregator.py +++ b/tests/ops/Aggregator/test_nested_aggregator.py @@ -11,6 +11,10 @@ class NestedAggregatorTest(DataJuicerTestCaseBase): def _run_helper(self, op, samples): + # before 
runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + dataset = Dataset.from_list(samples) new_dataset = op.run(dataset) @@ -33,7 +37,7 @@ def test_default_aggregator(self): }, ] op = NestedAggregator( - hf_or_api_model='qwen2.5-72b-instruct' + api_model='qwen2.5-72b-instruct' ) self._run_helper(op, samples) @@ -50,7 +54,7 @@ def test_input_output(self): }, ] op = NestedAggregator( - hf_or_api_model='qwen2.5-72b-instruct', + api_model='qwen2.5-72b-instruct', input_key='sub_docs', output_key='text' ) @@ -69,7 +73,7 @@ def test_max_token_num_1(self): }, ] op = NestedAggregator( - hf_or_api_model='qwen2.5-72b-instruct', + api_model='qwen2.5-72b-instruct', max_token_num=2 ) self._run_helper(op, samples) @@ -87,7 +91,7 @@ def test_max_token_num_2(self): }, ] op = NestedAggregator( - hf_or_api_model='qwen2.5-72b-instruct', + api_model='qwen2.5-72b-instruct', max_token_num=90 ) self._run_helper(op, samples) @@ -105,35 +109,10 @@ def test_max_token_num_3(self): }, ] op = NestedAggregator( - hf_or_api_model='qwen2.5-72b-instruct', + api_model='qwen2.5-72b-instruct', max_token_num=200 ) self._run_helper(op, samples) - # def test_hf_model(self): - # samples = [ - # { - # 'text': [ - # "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - # "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - # '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - # '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - # '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' - # ] - # }, - # ] - # op = NestedAggregator( - # hf_or_api_model='/mnt/workspace/shared/checkpoints/qwen/qwen2.5/Qwen2.5-3B-Instruct', - # is_hf_model=True, - # model_params={ - # 'trust_remote_code': True - # }, - # sampling_params={ - # 'max_new_tokens': 50 - # } - # ) - # ## TODO:返回的是message而不是text - # self._run_helper(op, samples) - if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py index 96f186d29..177880358 100644 --- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py +++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py @@ -21,9 +21,9 @@ def _run_op(self, api_model, response_path=None): query_attributes = ["语言风格", "角色性格"] op = ExtractEntityAttributeMapper( + api_model=api_model, query_entities=query_entities, - query_attributes=query_attributes, - api_model=api_model, + query_attributes=query_attributes, response_path=response_path) raw_text = """△笛飞声独自坐在莲花楼屋顶上。李莲花边走边悠闲地给马喂草。方多病则走在一侧,却总不时带着怀疑地盯向楼顶的笛飞声。 diff --git a/tests/ops/mapper/test_extract_support_text_mapper.py b/tests/ops/mapper/test_extract_support_text_mapper.py new file mode 100644 index 000000000..0445d2526 --- /dev/null +++ b/tests/ops/mapper/test_extract_support_text_mapper.py @@ -0,0 +1,80 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.extract_support_text_mapper import ExtractSupportTextMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields +from data_juicer.utils.common_utils import nested_access + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. 
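+# Note: summary_key and support_text_key below use dot-separated nested
+# paths ('data.event' and 'data.support_text'), which exercises the
+# nested_access/nested_set helpers on a nested sample structure.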
+@SKIPPED_TESTS.register_module() +class ExtractSupportTextMapperTest(DataJuicerTestCaseBase): + + + def _run_op(self, api_model): + + summary_key = 'data.event' + support_text_key = 'data.support_text' + op = ExtractSupportTextMapper(api_model=api_model, + summary_key=summary_key, + support_text_key=support_text_key) + + raw_text = """△芩婆走到中间,看着众人。 +芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。 +封磬震惊:二子?不是只有一个儿子吗? +芩婆:我和漆木山得知这个噩耗后,到处寻找李家那两个孩子的下落。只可惜等我们找他们时,李家长子李相显已经病死。 +李莲花似回忆起了什么:李相显...... +芩婆:我们只从乞丐堆里带回了年纪尚且未满四岁的李相夷,以及,(看向单孤刀)二个一直护着李相夷,与李相显年纪相仿的小乞丐...... +闪回/ +李相显将李且给他的玉佩塞给单孤刀,恳切托付:我没什么值钱的东西,这个玉佩是我唯一的家当了、送给你,我弟弟、相夷......求你照顾他一阵...... +△李相显还想再说什么已气绝而亡,小相夷唤着哥哥大哭,单孤刀愕然看着手里的玉佩有点不知所措。 +△话刚说完,哐当一声破庙门倒进来,几个其他少年乞丐进来。少年乞丐老大:这地儿不错,诶,你俩,出去! +△单孤刀把小相夷护在身后,抓住靠在墙边的木棍。单孤刀:这儿,是我,和我弟弟的。 +乞丐们要抢李相夷的馒头,小李相夷哭着死死护住自馒头不放。 +乞丐甲野蛮地抢:给我拿来! +小单孤刀:放开他! +△单孤刀用力撞向几个乞丐,救下小李相夷。乞丐甲:小子,活腻了! +△几个乞丐围攻小单孤刀,小单孤刀和众乞丐厮打到一起。突然其中一个乞丐掏出一把生锈的刀就朝单孤刀砍去、一个点燃火把棍戳他。单孤刀侧手一挡,火把棍在他手腕上烫出一道伤口,身后几根棍子打得他痛苦倒地! +/闪回结束 +△单孤刀拿着自己手里的玉佩看着,又看看自己手上的印记,不肯相信。单孤刀:胡说!全都是胡说!这些事我为何不知道?都是你在信口雌黄! +芩婆:那我问你,我们将你带回云隐山之前的事你又记得多少? +△单孤刀突然愣住,他意识到那之前的事自己竟都想不起来。 +芩婆:怎么?都想不起来了?(拽起单孤刀手腕,露出他的伤痕)你当日被你师父找到时,手腕上就受了伤,也正因为这处伤,高烧不退,醒来后便忘记了不少从前的事。 +△单孤刀呆住。 +芩婆:而相夷当年不过孩童,尚未到记事的年纪,很多事自然不知道。 +△李莲花得知真相,闭目叹息。 +△封磬震惊地看看单孤刀,又看看李莲花,终于想明白了一切,颓然、懊恼。 +封磬:自萱公主之子下落不明后,这近百年来我们整个家族都一直在不遗余力地寻找萱公主的子嗣后代,直到二十几年前终于让我寻得了线索,知道萱公主的曾孙被漆木山夫妇收为徒,但......我只知道萱公主之孙有一年约十岁的儿子,却不知......原来竟还有一幼子!我......我凭着南胤皇族的玉佩、孩子的年纪和他身上的印记来与主上相认,可没想到......这竟是一个错误!全错了! +△封磬神情复杂地看向李莲花,封磬:你,你才是我的主上...... +△封磬颓然地跪倒下来。 +△李莲花对眼前的一切有些意外、无措。 +笛飞声冷声:怪不得单孤刀的血对业火独毫无作用,李莲花的血才能毁掉这东西。 +△笛飞声不禁冷笑一下。 +""" + event = "李相显托付单孤刀。" + samples = [{ + 'text': raw_text, + 'data':{ + 'event': event + } + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + sample = dataset[0] + logger.info(f"support_text: \n{nested_access(sample, support_text_key)}") + + def test(self): + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + self._run_op('qwen2.5-72b-instruct') + + +if __name__ == '__main__': + unittest.main() From 125a8f356c3129bd8372274d6c4f5409c418971f Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Mon, 25 Nov 2024 02:30:24 +0000 Subject: [PATCH 038/118] coverage ignore_errors --- .coveragerc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.coveragerc b/.coveragerc index d4a7a6d63..d95c7fc28 100644 --- a/.coveragerc +++ b/.coveragerc @@ -9,3 +9,6 @@ omit = # avoid measuring code of unittest tests/* + +[report] +ignore_errors = True From 0c68089beb446ada737b55a6c9a5a551e5348a19 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 25 Nov 2024 11:21:16 +0800 Subject: [PATCH 039/118] index sample --- data_juicer/ops/base_op.py | 11 +++++++++++ tests/config/test_config_funcs.py | 7 +++++++ tests/ops/mapper/test_extract_event_mapper.py | 7 +++++-- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 361bf7671..02651cbc7 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -132,6 +132,7 @@ def __init__(self, *args, **kwargs): :param response_key: the key name of field that stores responses :param history_key: the key name of field that stores history of queries and responses + :param index_key: index the samples before process if not None """ # init data keys self.text_key = kwargs.get('text_key', 'text') @@ -143,6 +144,8 @@ def 
__init__(self, *args, **kwargs): self.response_key = kwargs.get('response_key', 'response') self.history_key = kwargs.get('history_key', 'history') + self.index_key = kwargs.get('index_key', None) + self.batch_size = kwargs.get('batch_size', 1000) # whether the model can be accelerated using cuda @@ -218,6 +221,14 @@ def run(self, dataset): from data_juicer.core.data import NestedDataset if not isinstance(dataset, NestedDataset): dataset = NestedDataset(dataset) + if self.index_key is not None: + + def add_index(sample, idx): + sample[self.index_key] = idx + return sample + + dataset = dataset.map(add_index, with_indices=True) + return dataset def empty_history(self): diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py index 1cb7c4463..2502fe1e1 100644 --- a/tests/config/test_config_funcs.py +++ b/tests/config/test_config_funcs.py @@ -54,6 +54,7 @@ def test_yaml_cfg_file(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }, 'nested dict load fail, for nonparametric op') self.assertDictEqual( @@ -75,6 +76,7 @@ def test_yaml_cfg_file(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }, 'nested dict load fail, un-expected internal value') @@ -144,6 +146,7 @@ def test_mixture_cfg(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }) self.assertDictEqual( @@ -165,6 +168,7 @@ def test_mixture_cfg(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }) self.assertDictEqual( @@ -186,6 +190,7 @@ def test_mixture_cfg(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }) self.assertDictEqual( @@ -207,6 +212,7 @@ def test_mixture_cfg(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }) self.assertDictEqual( @@ -228,6 +234,7 @@ def test_mixture_cfg(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }) diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py index 1652c8db2..aba40d73e 100644 --- a/tests/ops/mapper/test_extract_event_mapper.py +++ b/tests/ops/mapper/test_extract_event_mapper.py @@ -18,7 +18,8 @@ class ExtractEventMapperTest(DataJuicerTestCaseBase): def _run_op(self, api_model, response_path=None): op = ExtractEventMapper(api_model=api_model, - response_path=response_path) + response_path=response_path, + index_key='chunk_id') raw_text = """△芩婆走到中间,看着众人。 芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。 @@ -57,9 +58,11 @@ def _run_op(self, api_model, response_path=None): }] dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) + dataset = op.run(dataset) self.assertNotEqual(len(dataset), 0) for sample in dataset: + logger.info(f"chunk_id: {sample['chunk_id']}") + self.assertEqual(sample['chunk_id'], 0) logger.info(f"event: {sample[Fields.event_description]}") self.assertNotEqual(sample[Fields.event_description], '') logger.info(f"characters: {sample[Fields.relevant_characters]}") From 651789d162872d5dfeeb2e92c3377eaaa2950204 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 25 Nov 2024 11:49:10 +0800 Subject: [PATCH 040/118] role_playing_system_prompt_yaml --- .../role_playing_system_prompt.yaml | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 demos/role_playing_system_prompt/role_playing_system_prompt.yaml diff --git 
a/demos/role_playing_system_prompt/role_playing_system_prompt.yaml b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml new file mode 100644 index 000000000..97fc88464 --- /dev/null +++ b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml @@ -0,0 +1,35 @@ +# Process config example for dataset + +# global parameters +project_name: 'role-play-demo-process' +dataset_path: 'path_to_the_lianhualou_novel_json_file' +np: 1 # number of subprocess to process your dataset + +export_path: 'path_to_output_jsonl_file' + +# process schedule +# a list of several process operators with their arguments +process: + # - text_chunk_mapper: # if not chunked + # max_len: 8000 + # split_pattern: '\n\n' + # overlap_len: 400 + # tokenizer: 'qwen2.5-72b-instruct' + # trust_remote_code: True + - extract_entity_attribute_mapper: # extract attribute of '李莲花' + api_model: 'qwen2.5-72b-instruct' + query_entities: ['李莲花'] + query_attributes: ["语言风格", "角色性格", "角色能力"] + - extract_nickname_mapper: # extract nicknames in the novel + api_model: 'qwen2.5-72b-instruct' + - extract_event_mapper: # extract events for each chunk + api_model: 'qwen2.5-72b-instruct' + index_key: 'chunk_id' # chunk_id for deduplicating attributes and nicknames + - naive_grouper: # grouper all events + - most_relavant_entities_aggregator: # identify important roles relavant to '李莲花' + api_model: 'qwen2.5-72b-instruct' + entity: '李莲花' + query_entity_type: '人物' + input_key: '__dj__event_description__' + output_key: '__dj__important_relavant_roles__' + From 222790ecdd3aaa94452d17fda0e37e4bd4ad15ac Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 27 Nov 2024 12:52:39 +0800 Subject: [PATCH 041/118] system_prompt begin --- .../most_relavant_entities_aggregator.py | 2 +- .../mapper/extract_entity_relation_mapper.py | 2 +- .../role_playing_system_prompt.yaml | 45 ++++++++++++------- .../system_prompt_generator.py | 0 4 files changed, 32 insertions(+), 17 deletions(-) create mode 100644 demos/role_playing_system_prompt/system_prompt_generator.py diff --git a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py index d247d48bd..901ccf6a9 100644 --- a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py +++ b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py @@ -109,7 +109,7 @@ def __init__(self, self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', model=api_model, - url=api_endpoint, + endpoint=api_endpoint, response_path=response_path, return_processor=True, **model_params) diff --git a/data_juicer/ops/mapper/extract_entity_relation_mapper.py b/data_juicer/ops/mapper/extract_entity_relation_mapper.py index 085a2da47..4b026f2a4 100644 --- a/data_juicer/ops/mapper/extract_entity_relation_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_relation_mapper.py @@ -222,7 +222,7 @@ def __init__(self, self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', model=api_model, - url=api_endpoint, + endpoint=api_endpoint, response_path=response_path, **model_params) diff --git a/demos/role_playing_system_prompt/role_playing_system_prompt.yaml b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml index 97fc88464..63a9adc4c 100644 --- a/demos/role_playing_system_prompt/role_playing_system_prompt.yaml +++ b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml @@ -1,5 +1,3 @@ -# Process config example for dataset - # global 
parameters project_name: 'role-play-demo-process' dataset_path: 'path_to_the_lianhualou_novel_json_file' @@ -8,28 +6,45 @@ np: 1 # number of subprocess to process your dataset export_path: 'path_to_output_jsonl_file' # process schedule -# a list of several process operators with their arguments process: - # - text_chunk_mapper: # if not chunked - # max_len: 8000 - # split_pattern: '\n\n' - # overlap_len: 400 - # tokenizer: 'qwen2.5-72b-instruct' - # trust_remote_code: True - - extract_entity_attribute_mapper: # extract attribute of '李莲花' +# # chunk the novel +# - text_chunk_mapper: +# max_len: 8000 +# split_pattern: '\n\n' +# overlap_len: 400 +# tokenizer: 'qwen2.5-72b-instruct' +# trust_remote_code: True + # extract language_style, role_charactor and role_skill + - extract_entity_attribute_mapper: api_model: 'qwen2.5-72b-instruct' query_entities: ['李莲花'] query_attributes: ["语言风格", "角色性格", "角色能力"] - - extract_nickname_mapper: # extract nicknames in the novel + # extract nickname + - extract_nickname_mapper: api_model: 'qwen2.5-72b-instruct' - - extract_event_mapper: # extract events for each chunk + # extract events + - extract_event_mapper: api_model: 'qwen2.5-72b-instruct' index_key: 'chunk_id' # chunk_id for deduplicating attributes and nicknames - - naive_grouper: # grouper all events - - most_relavant_entities_aggregator: # identify important roles relavant to '李莲花' + # group all events + - naive_grouper:= + # role experiences summary from events + - entity_attribute_aggregator: + api_model: 'qwen2.5-72b-instruct' + entity: '李莲花' + attribute: '主要经历' + input_key: '__dj__event_description__' + output_key: '__dj__role_experience__' + word_limit: 150 + # most relavant roles summary from events + - most_relavant_entities_aggregator: api_model: 'qwen2.5-72b-instruct' entity: '李莲花' query_entity_type: '人物' input_key: '__dj__event_description__' output_key: '__dj__important_relavant_roles__' - + # generate the system prompt + - python_file_mapper: + file_path: 'path_to_system_prompt_gereration_python_file' + function_name: 'get_system_prompt' + \ No newline at end of file diff --git a/demos/role_playing_system_prompt/system_prompt_generator.py b/demos/role_playing_system_prompt/system_prompt_generator.py new file mode 100644 index 000000000..e69de29bb From 75f29113edbcde664c95f9d4cb53fbbe2e199c0f Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Wed, 27 Nov 2024 08:33:46 +0000 Subject: [PATCH 042/118] support batched --- data_juicer/ops/base_op.py | 5 ++--- data_juicer/ops/mapper/python_file_mapper.py | 13 +++++++++++++ tests/ops/mapper/test_python_file_mapper.py | 12 ++++++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 13f3b61ae..edff62b5e 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -166,9 +166,8 @@ def __init__(self, *args, **kwargs): method = wrap_func_with_nested_access(method) setattr(self, name, method) - @classmethod - def is_batched_op(cls): - return cls._batched_op + def is_batched_op(self): + return self._batched_op def process(self, *args, **kwargs): raise NotImplementedError diff --git a/data_juicer/ops/mapper/python_file_mapper.py b/data_juicer/ops/mapper/python_file_mapper.py index 8cc730461..b2dfa9cfa 100644 --- a/data_juicer/ops/mapper/python_file_mapper.py +++ b/data_juicer/ops/mapper/python_file_mapper.py @@ -13,7 +13,9 @@ class PythonFileMapper(Mapper): def __init__(self, file_path: str = '', function_name: str = 'process_single', + batched: bool = False, 
                 **kwargs):
+        self._batched_op = bool(batched)
         super().__init__(**kwargs)
 
         self.file_path = file_path
@@ -69,3 +71,14 @@ def process_single(self, sample):
             )
 
         return result
+
+    def process_batched(self, samples):
+        """Invoke the loaded function with the provided samples."""
+        result = self.func(samples)
+
+        if not isinstance(result, dict):
+            raise ValueError(
+                f'Function must return a dictionary, got {type(result).__name__} instead.'  # noqa: E501
+            )
+
+        return result
diff --git a/tests/ops/mapper/test_python_file_mapper.py b/tests/ops/mapper/test_python_file_mapper.py
index 0e491bf3f..8a22e6255 100644
--- a/tests/ops/mapper/test_python_file_mapper.py
+++ b/tests/ops/mapper/test_python_file_mapper.py
@@ -18,6 +18,18 @@ def test_function_execution(self):
         result = mapper.process_single({'value': 5})
         self.assertEqual(result, {'result': 15})
 
+    def test_function_batched(self):
+        """Test for a function that processes a batch."""
+        with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file:
+            temp_file.write(
+                "def process_data(samples):\n"
+                "    return {'result': samples['value'] + [10]}\n"
+            )
+            temp_file.seek(0)  # Rewind the file so it can be read
+            mapper = PythonFileMapper(temp_file.name, "process_data", batched=True)
+            result = mapper.process_batched({'value': [5]})
+            self.assertEqual(result, {'result': [5, 10]})
+
     def test_function_with_import(self):
         """Test for a function that contains an import statement."""
         with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file:

From 11fa852064f5f06ea7ba420829fb55f748aec941 Mon Sep 17 00:00:00 2001
From: Haibin <1400012807@pku.edu.cn>
Date: Wed, 27 Nov 2024 16:58:21 +0800
Subject: [PATCH 043/118] remove unforkable

---
 data_juicer/ops/aggregator/entity_attribute_aggregator.py | 3 +--
 .../ops/aggregator/most_relavant_entities_aggregator.py   | 3 +--
 data_juicer/ops/aggregator/nested_aggregator.py           | 3 +--
 data_juicer/ops/mapper/calibrate_qa_mapper.py             | 3 +--
 data_juicer/ops/mapper/calibrate_query_mapper.py          | 3 +--
 5 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/data_juicer/ops/aggregator/entity_attribute_aggregator.py b/data_juicer/ops/aggregator/entity_attribute_aggregator.py
index 30c26f7c2..96fbbb63f 100644
--- a/data_juicer/ops/aggregator/entity_attribute_aggregator.py
+++ b/data_juicer/ops/aggregator/entity_attribute_aggregator.py
@@ -4,7 +4,7 @@
 from loguru import logger
 from pydantic import PositiveInt
 
-from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Aggregator
+from data_juicer.ops.base_op import OPERATORS, Aggregator
 from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
                                             is_string_list, nested_access,
                                             nested_set)
@@ -20,7 +20,6 @@
 
 
 # TODO: LLM-based inference.
-@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class EntityAttributeAggregator(Aggregator): """ diff --git a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py index 901ccf6a9..9fa3a0744 100644 --- a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py +++ b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py @@ -4,7 +4,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Aggregator +from data_juicer.ops.base_op import OPERATORS, Aggregator from data_juicer.utils.common_utils import (is_string_list, nested_access, nested_set) from data_juicer.utils.lazy_loader import LazyLoader @@ -19,7 +19,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class MostRelavantEntitiesAggregator(Aggregator): """ diff --git a/data_juicer/ops/aggregator/nested_aggregator.py b/data_juicer/ops/aggregator/nested_aggregator.py index 0306ae7d4..124eb1470 100644 --- a/data_juicer/ops/aggregator/nested_aggregator.py +++ b/data_juicer/ops/aggregator/nested_aggregator.py @@ -3,7 +3,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Aggregator +from data_juicer.ops.base_op import OPERATORS, Aggregator from data_juicer.utils.common_utils import (avg_split_string_list_under_limit, is_string_list, nested_access) from data_juicer.utils.lazy_loader import LazyLoader @@ -16,7 +16,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class NestedAggregator(Aggregator): """ diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index 69b860e33..7acf55c55 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -4,14 +4,13 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.model_utils import get_model, prepare_model OP_NAME = 'calibrate_qa_mapper' # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class CalibrateQAMapper(Mapper): """ diff --git a/data_juicer/ops/mapper/calibrate_query_mapper.py b/data_juicer/ops/mapper/calibrate_query_mapper.py index 88098d7f8..48ae0c4f7 100644 --- a/data_juicer/ops/mapper/calibrate_query_mapper.py +++ b/data_juicer/ops/mapper/calibrate_query_mapper.py @@ -1,11 +1,10 @@ -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE +from data_juicer.ops.base_op import OPERATORS from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper OP_NAME = 'calibrate_query_mapper' # TODO: LLM-based inference. 
-@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class CalibrateQueryMapper(CalibrateQAMapper): """ From 4af2bfbaf3da8fe482c735ec0e479a2f9406d2f4 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Wed, 27 Nov 2024 09:39:26 +0000 Subject: [PATCH 044/118] support batched & add docs --- configs/config_all.yaml | 3 +++ .../ops/mapper/python_lambda_mapper.py | 26 +++++++++++++++++-- docs/Operators.md | 2 +- docs/Operators_ZH.md | 3 ++- tests/ops/mapper/test_python_lambda_mapper.py | 5 ++++ 5 files changed, 35 insertions(+), 4 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 90fc18875..08291ab47 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -243,6 +243,9 @@ process: - optimize_query_mapper: # optimize query in question-answer pairs. - optimize_response_mapper: # optimize response in question-answer pairs. - punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations. + - python_lambda_mapper: # executing Python lambda function on data samples. + lambda_str: '' # A string representation of the lambda function to be executed on data samples. If empty, the identity function is used. + batched: False # A boolean indicating whether to process input data in batches. - remove_bibliography_mapper: # remove bibliography from Latex text. - remove_comments_mapper: # remove comments from Latex text, code, etc. doc_type: tex # comment type you want to remove. Only support 'tex' for now. diff --git a/data_juicer/ops/mapper/python_lambda_mapper.py b/data_juicer/ops/mapper/python_lambda_mapper.py index 3aa78d891..e90c77f48 100644 --- a/data_juicer/ops/mapper/python_lambda_mapper.py +++ b/data_juicer/ops/mapper/python_lambda_mapper.py @@ -1,5 +1,4 @@ import ast -from typing import Optional from ..base_op import OPERATORS, Mapper @@ -8,9 +7,21 @@ @OPERATORS.register_module(OP_NAME) class PythonLambdaMapper(Mapper): + """Mapper for executing Python lambda function on data samples.""" - def __init__(self, lambda_str: Optional[str] = None, **kwargs): + def __init__(self, lambda_str: str = '', batched: bool = False, **kwargs): + """ + Initialization method. + + :param lambda_str: A string representation of the lambda function to be + executed on data samples. If empty, the identity function is used. + :param batched: A boolean indicating whether to process input data in + batches. + :param kwargs: Additional keyword arguments passed to the parent class. + """ + self._batched_op = bool(batched) super().__init__(**kwargs) + # Parse and validate the lambda function if not lambda_str: self.lambda_func = lambda sample: sample @@ -50,3 +61,14 @@ def process_single(self, sample): f'got {type(result).__name__} instead.') return result + + def process_batched(self, samples): + # Process the input through the lambda function and return the result + result = self.lambda_func(samples) + + # Check if the result is a valid + if not isinstance(result, dict): + raise ValueError(f'Lambda function must return a dictionary, ' + f'got {type(result).__name__} instead.') + + return result diff --git a/docs/Operators.md b/docs/Operators.md index 7717ba434..0c78f9038 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. 
| Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 58 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 59 | Edits and transforms samples | | [ Filter ]( #filter ) | 44 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 81aee2149..51d142ec7 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 58 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 59 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 44 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -86,6 +86,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | optimize_query_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 指令优化,优化 query | [code](../data_juicer/ops/mapper/optimize_query_mapper.py) | [tests](../tests/ops/mapper/test_optimize_query_mapper.py) | | optimize_response_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 指令优化,优化 response | [code](../data_juicer/ops/mapper/optimize_response_mapper.py) | [tests](../tests/ops/mapper/test_optimize_response_mapper.py) | | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | +| python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行 Python lambda 函数处理样本 | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | 
![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) |
| remove_header_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档头,例如标题、章节数字/名称等 | [code](../data_juicer/ops/mapper/remove_header_mapper.py) | [tests](../tests/ops/mapper/test_remove_header_mapper.py) |

diff --git a/tests/ops/mapper/test_python_lambda_mapper.py b/tests/ops/mapper/test_python_lambda_mapper.py
index 2d9d97742..97fac4794 100644
--- a/tests/ops/mapper/test_python_lambda_mapper.py
+++ b/tests/ops/mapper/test_python_lambda_mapper.py
@@ -5,6 +5,11 @@
 
 class PythonLambdaMapperMapper(DataJuicerTestCaseBase):
 
+    def test_lambda_function_batched(self):
+        mapper = PythonLambdaMapper("lambda d: {'value': d['value'] + [6]}", batched=True)  # Append '6' to value
+        result = mapper.process_batched({'value': [5]})
+        self.assertEqual(result, {'value': [5, 6]})
+
     def test_lambda_modifies_values(self):
         mapper = PythonLambdaMapper("lambda d: {'value': d['value'] + 1}")  # '+1' to 'value'
         result = mapper.process_single({'value': 5})

From 553d5ad57a293523fbce812ed7ae49de2fac3709 Mon Sep 17 00:00:00 2001
From: "gece.gc" 
Date: Thu, 28 Nov 2024 03:20:09 +0000
Subject: [PATCH 045/118] add docs

---
 configs/config_all.yaml                      |  4 ++++
 data_juicer/ops/mapper/python_file_mapper.py | 12 ++++++++++++
 docs/Operators.md                            |  1 +
 docs/Operators_ZH.md                         |  1 +
 4 files changed, 18 insertions(+)

diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 90fc18875..9ac4837a6 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -243,6 +243,10 @@ process:
   - optimize_query_mapper:                                # optimize query in question-answer pairs.
   - optimize_response_mapper:                             # optimize response in question-answer pairs.
   - punctuation_normalization_mapper:                     # normalize unicode punctuations to English punctuations.
+  - python_file_mapper:                                   # executing Python function defined in a file.
+      file_path: ''                                       # The path to the Python file containing the function to be executed.
+      function_name: 'process_single'                     # The name of the function defined in the file to be executed.
+      batched: False                                      # A boolean indicating whether to process input data in batches.
   - remove_bibliography_mapper:                           # remove bibliography from Latex text.
   - remove_comments_mapper:                               # remove comments from Latex text, code, etc.
       doc_type: tex                                       # comment type you want to remove. Only support 'tex' for now.

diff --git a/data_juicer/ops/mapper/python_file_mapper.py b/data_juicer/ops/mapper/python_file_mapper.py
index b2dfa9cfa..6858fae55 100644
--- a/data_juicer/ops/mapper/python_file_mapper.py
+++ b/data_juicer/ops/mapper/python_file_mapper.py
@@ -9,12 +9,24 @@
 
 @OPERATORS.register_module(OP_NAME)
 class PythonFileMapper(Mapper):
+    """Mapper for executing Python function defined in a file."""
 
     def __init__(self,
                  file_path: str = '',
                  function_name: str = 'process_single',
                  batched: bool = False,
                  **kwargs):
+        """
+        Initialization method.
+
+        :param file_path: The path to the Python file containing the function
+            to be executed.
+ :param function_name: The name of the function defined in the file + to be executed. + :param batched: A boolean indicating whether to process input data in + batches. + :param kwargs: Additional keyword arguments passed to the parent class. + """ self._batched_op = bool(batched) super().__init__(**kwargs) diff --git a/docs/Operators.md b/docs/Operators.md index 7717ba434..72043a76b 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -87,6 +87,7 @@ All the specific operators are listed below, each featured with several capabili | optimize_query_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Optimize the query in question-answering samples. | [code](../data_juicer/ops/mapper/optimize_query_mapper.py) | [tests](../tests/ops/mapper/test_optimize_query_mapper.py) | | optimize_response_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Optimize the response in question-answering samples. | [code](../data_juicer/ops/mapper/optimize_response_mapper.py) | [tests](../tests/ops/mapper/test_optimize_response_mapper.py) | | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | +| python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python function defined in a file | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | | remove_header_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) 
![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the running headers of TeX documents, e.g., titles, chapter or section numbers/names | [code](../data_juicer/ops/mapper/remove_header_mapper.py) | [tests](../tests/ops/mapper/test_remove_header_mapper.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 81aee2149..87c305471 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -86,6 +86,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | optimize_query_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 指令优化,优化 query | [code](../data_juicer/ops/mapper/optimize_query_mapper.py) | [tests](../tests/ops/mapper/test_optimize_query_mapper.py) | | optimize_response_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 指令优化,优化 response | [code](../data_juicer/ops/mapper/optimize_response_mapper.py) | [tests](../tests/ops/mapper/test_optimize_response_mapper.py) | | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | +| python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行文件中定义的 Python 函数处理样本 | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | | remove_header_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) 
![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档头,例如标题、章节数字/名称等 | [code](../data_juicer/ops/mapper/remove_header_mapper.py) | [tests](../tests/ops/mapper/test_remove_header_mapper.py) | From 470ca195b080a457a4b18f9914f8a7d12339fd70 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Thu, 28 Nov 2024 03:21:24 +0000 Subject: [PATCH 046/118] fix docs --- docs/Operators.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Operators.md b/docs/Operators.md index b1bbad373..218f883be 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -88,7 +88,7 @@ All the specific operators are listed below, each featured with several capabili | optimize_response_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Optimize the response in question-answering samples. | [code](../data_juicer/ops/mapper/optimize_response_mapper.py) | [tests](../tests/ops/mapper/test_optimize_response_mapper.py) | | pair_preference_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Construct paired preference samples. | [code](../data_juicer/ops/mapper/pair_preference_mapper.py) | [tests](../tests/ops/mapper/test_pair_preference_mapper.py) | | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | -| python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行 Python lambda 函数处理样本 | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | +| python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python lambda function on data samples | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | 
![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | | remove_header_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the running headers of TeX documents, e.g., titles, chapter or section numbers/names | [code](../data_juicer/ops/mapper/remove_header_mapper.py) | [tests](../tests/ops/mapper/test_remove_header_mapper.py) | From 399a238c5e692d9dd6c48b9c702e2c28f3320a46 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Thu, 28 Nov 2024 03:26:21 +0000 Subject: [PATCH 047/118] update docs --- docs/Operators.md | 2 +- docs/Operators_ZH.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Operators.md b/docs/Operators.md index 72043a76b..2874e14db 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. | Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 58 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 60 | Edits and transforms samples | | [ Filter ]( #filter ) | 44 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 87c305471..bb60624bb 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 58 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 60 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 44 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | From 115ab9aad2e44dd30a84dacdcddb3c4c134f0e4d Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 28 Nov 2024 17:48:10 +0800 Subject: [PATCH 048/118] pre-commit done --- data_juicer/ops/aggregator/__init__.py | 3 +- .../most_relavant_entities_aggregator.py | 2 +- .../aggregator/relation_idenity_aggregator.py | 155 ++++++++++++++ .../ops/mapper/calibrate_response_mapper.py | 3 +- .../mapper/extract_entity_attribute_mapper.py | 61 ++---- .../mapper/extract_entity_relation_mapper.py | 3 +- .../ops/mapper/extract_event_mapper.py | 3 +- .../ops/mapper/extract_keyword_mapper.py | 3 +- .../ops/mapper/extract_nickname_mapper.py | 3 +- .../ops/mapper/extract_support_text_mapper.py | 3 +- .../generate_qa_from_examples_mapper.py | 3 +- .../mapper/generate_qa_from_text_mapper.py | 3 +- data_juicer/ops/mapper/optimize_qa_mapper.py | 3 +- .../ops/mapper/optimize_query_mapper.py | 3 +- .../ops/mapper/optimize_response_mapper.py | 3 +- data_juicer/utils/constant.py | 16 +- .../role_playing_system_prompt.yaml | 25 ++- 
.../system_prompt_generator.py | 190 ++++++++++++++++++ .../test_extract_entity_attribute_mapper.py | 11 +- 19 files changed, 413 insertions(+), 83 deletions(-) create mode 100644 data_juicer/ops/aggregator/relation_idenity_aggregator.py diff --git a/data_juicer/ops/aggregator/__init__.py b/data_juicer/ops/aggregator/__init__.py index 4afe2974a..c411f7cea 100644 --- a/data_juicer/ops/aggregator/__init__.py +++ b/data_juicer/ops/aggregator/__init__.py @@ -1,8 +1,9 @@ from .entity_attribute_aggregator import EntityAttributeAggregator from .most_relavant_entities_aggregator import MostRelavantEntitiesAggregator from .nested_aggregator import NestedAggregator +from .relation_idenity_aggregator import RelationIdentityAggregator __all__ = [ 'NestedAggregator', 'EntityAttributeAggregator', - 'MostRelavantEntitiesAggregator' + 'MostRelavantEntitiesAggregator', 'RelationIdentityAggregator' ] diff --git a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py index 9fa3a0744..d52a9a1fc 100644 --- a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py +++ b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py @@ -30,7 +30,7 @@ class MostRelavantEntitiesAggregator(Aggregator): '总结一些与`{entity}`最为相关的`{entity_type}`。\n' '要求:\n' '- 不用包含与{entity}为同一{entity_type}的{entity_type}。\n' - '- 请按照人物的重要性进行排序,越重要人物在列表越前面。\n' + '- 请按照人物的重要性进行排序,**越重要人物在列表越前面**。\n' '- 你的返回格式如下:\n' '## 分析\n' '你对各个{entity_type}与{entity}关联度的分析\n' diff --git a/data_juicer/ops/aggregator/relation_idenity_aggregator.py b/data_juicer/ops/aggregator/relation_idenity_aggregator.py new file mode 100644 index 000000000..1b94b9212 --- /dev/null +++ b/data_juicer/ops/aggregator/relation_idenity_aggregator.py @@ -0,0 +1,155 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Aggregator +from data_juicer.utils.common_utils import nested_access, nested_set +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'relation_identity_aggregator' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class RelationIdentityAggregator(Aggregator): + """ + identify relation between two entity in the text. + """ + + DEFAULT_SYSTEM_PROMPT_TEMPLATE = ( + '给定关于{entity1}和{entity2}的文本信息。' + '判断{entity1}和{entity2}之间的关系。\n' + '要求:\n' + '- 关系用一个或多个词语表示,必要时可以加一个形容词来描述这段关系\n' + '- 输出关系时不要参杂任何标点符号\n' + '- 需要你进行合理的推理才能得出结论\n' + '- 如果两个人物身份是同一个人,输出关系为:另一个身份\n' + '- 输出格式为:\n' + '分析推理:...\n' + '所以{entity2}是{entity1}的:...\n' + '- 注意输出的是{entity2}是{entity1}的什么关系,而不是{entity1}是{entity2}的什么关系') + DEFAULT_INPUT_TEMPLATE = '关于{entity1}和{entity2}的文本信息:\n```\n{text}\n```\n' + DEFAULT_OUTPUT_PATTERN_TEMPLATE = r""" + \s*分析推理:\s*(.*?)\s* + \s*所以{entity2}是{entity1}的:\s*(.*?)\Z + """ + + def __init__(self, + api_model: str = 'gpt-4o', + source_entity: str = None, + target_entity: str = None, + input_key: str = None, + output_key: str = None, + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt_template: Optional[str] = None, + input_template: Optional[str] = None, + output_pattern_template: Optional[str] = None, + try_num: PositiveInt = 3, + drop_text: bool = False, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param api_model: API model name. + :param source_entity: The source entity of the relation to be + identified. 
+        :param target_entity: The target entity of the relation to be
+            identified.
+        :param input_key: The input field key in the samples. Support for
+            nested keys such as "__dj__stats__.text_len". It is text_key
+            in default.
+        :param output_key: The output field key in the samples. Support
+            for nested keys such as "__dj__stats__.text_len". It is the
+            same as the input_key in default.
+        :param api_endpoint: URL endpoint for the API.
+        :param response_path: Path to extract content from the API response.
+            Defaults to 'choices.0.message.content'.
+        :param system_prompt_template: System prompt template for the task.
+        :param input_template: Template for building the model input.
+        :param output_pattern_template: Regular expression template for
+            parsing model output.
+        :param try_num: The number of retry attempts when there is an API
+            call error or output parsing error.
+        :param drop_text: Whether to drop the text in the output.
+        :param model_params: Parameters for initializing the API model.
+        :param sampling_params: Extra parameters passed to the API call.
+            e.g. {'temperature': 0.9, 'top_p': 0.95}
+        :param kwargs: Extra keyword arguments.
+        """
+        super().__init__(**kwargs)
+
+        if source_entity is None or target_entity is None:
+            logger.warning('source_entity and target_entity cannot be None')
+
+        self.source_entity = source_entity
+        self.target_entity = target_entity
+
+        self.input_key = input_key or self.text_key
+        self.output_key = output_key or self.input_key
+
+        system_prompt_template = system_prompt_template or \
+            self.DEFAULT_SYSTEM_PROMPT_TEMPLATE
+        self.system_prompt = system_prompt_template.format(
+            entity1=source_entity, entity2=target_entity)
+        self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+        output_pattern_template = output_pattern_template or \
+            self.DEFAULT_OUTPUT_PATTERN_TEMPLATE
+        self.output_pattern = output_pattern_template.format(
+            entity1=source_entity, entity2=target_entity)
+
+        self.sampling_params = sampling_params
+        self.model_key = prepare_model(model_type='api',
+                                       model=api_model,
+                                       endpoint=api_endpoint,
+                                       response_path=response_path,
+                                       **model_params)
+
+        self.try_num = try_num
+        self.drop_text = drop_text
+
+    def parse_output(self, raw_output):
+        pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
+        matches = pattern.findall(raw_output)
+
+        relation = ''
+
+        for match in matches:
+            _, relation = match
+            relation = relation.strip()
+
+        return relation
+
+    def process_single(self, sample, rank=None):
+        client = get_model(self.model_key, rank=rank)
+
+        text = nested_access(sample, self.input_key)
+        input_prompt = self.input_template.format(entity1=self.source_entity,
+                                                  entity2=self.target_entity,
+                                                  text=text)
+        messages = [{
+            'role': 'system',
+            'content': self.system_prompt
+        }, {
+            'role': 'user',
+            'content': input_prompt
+        }]
+        relation = ''
+        for i in range(self.try_num):
+            try:
+                output = client(messages, **self.sampling_params)
+                relation = self.parse_output(output)
+                if len(relation) > 0:
+                    break
+            except Exception as e:
+                logger.warning(f'Exception: {e}')
+
+        sample = nested_set(sample, self.output_key, relation)
+        if self.drop_text:
+            sample.pop(self.text_key)
+
+        return sample
diff --git a/data_juicer/ops/mapper/calibrate_response_mapper.py b/data_juicer/ops/mapper/calibrate_response_mapper.py
index db56af317..1d6456c2b 100644
--- a/data_juicer/ops/mapper/calibrate_response_mapper.py
+++ b/data_juicer/ops/mapper/calibrate_response_mapper.py
@@ -1,11 +1,10 @@
-from data_juicer.ops.base_op import OPERATORS, UNFORKABLE
+from 
data_juicer.ops.base_op import OPERATORS from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper OP_NAME = 'calibrate_response_mapper' # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class CalibrateResponseMapper(CalibrateQAMapper): """ diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py index 46370a001..6f8a2099d 100644 --- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -1,11 +1,10 @@ import re -from itertools import chain from typing import Dict, List, Optional from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -13,26 +12,24 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractEntityAttributeMapper(Mapper): """ Extract attributes for given entities from the text """ - _batched_op = True - DEFAULT_SYSTEM_PROMPT_TEMPLATE = ( '给定一段文本,从文本中总结{entity}的{attribute},并且从原文摘录最能说明该{attribute}的代表性示例。\n' '要求:\n' '- 摘录的示例应该简短。\n' '- 遵循如下的回复格式:\n' + '# {entity}\n' '## {attribute}:\n' - '{entity}的{attribute}描述...\n' - '### 代表性示例1:\n' - '说明{entity}该{attribute}的原文摘录1...\n' - '### 代表性示例2:\n' - '说明{entity}该{attribute}的原文摘录2...\n' + '...\n' + '### 代表性示例摘录1:\n' + '...\n' + '### 代表性示例摘录2:\n' + '...\n' '...\n') DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n' @@ -44,10 +41,10 @@ def __init__(self, query_entities: List[str] = [], query_attributes: List[str] = [], *, - entity_key: str = Fields.main_entity, - attribute_key: str = Fields.attribute, - attribute_desc_key: str = Fields.attribute_description, - support_text_key: str = Fields.attribute_support_text, + entity_key: str = Fields.main_entities, + attribute_key: str = Fields.attributes, + attribute_desc_key: str = Fields.attribute_descriptions, + support_text_key: str = Fields.attribute_support_texts, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt_template: Optional[str] = None, @@ -111,7 +108,7 @@ def __init__(self, self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', - api_model=api_model, + model=api_model, endpoint=api_endpoint, response_path=response_path, **model_params) @@ -136,7 +133,7 @@ def parse_output(self, raw_output, attribute_name): return attribute, demos - def _process_single_sample(self, text='', rank=None): + def _process_single_text(self, text='', rank=None): client = get_model(self.model_key, rank=rank) entities, attributes, descs, demo_lists = [], [], [], [] @@ -169,31 +166,17 @@ def _process_single_sample(self, text='', rank=None): return entities, attributes, descs, demo_lists - def process_batched(self, samples, rank=None): - - sample_num = len(samples[self.text_key]) + def process_single(self, sample, rank=None): - entities, attributes, descs, demo_lists = [], [], [], [] - for text in samples[self.text_key]: - res = self._process_single_sample(text, rank=rank) - cur_ents, cur_attrs, cur_descs, cur_demos = res - entities.append(cur_ents) - attributes.append(cur_attrs) - descs.append(cur_descs) - demo_lists.append(cur_demos) + res = self._process_single_text(sample[self.text_key], rank=rank) + entities, attributes, descs, 
demo_lists = res if self.drop_text: - samples.pop(self.text_key) - - for key in samples: - samples[key] = [[samples[key][i]] * len(descs[i]) - for i in range(sample_num)] - samples[self.entity_key] = entities - samples[self.attribute_key] = attributes - samples[self.attribute_desc_key] = descs - samples[self.support_text_key] = demo_lists + sample.pop(self.text_key) - for key in samples: - samples[key] = list(chain(*samples[key])) + sample[self.entity_key] = entities + sample[self.attribute_key] = attributes + sample[self.attribute_desc_key] = descs + sample[self.support_text_key] = demo_lists - return samples + return sample diff --git a/data_juicer/ops/mapper/extract_entity_relation_mapper.py b/data_juicer/ops/mapper/extract_entity_relation_mapper.py index 4b026f2a4..18606456c 100644 --- a/data_juicer/ops/mapper/extract_entity_relation_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_relation_mapper.py @@ -9,7 +9,7 @@ from loguru import logger from pydantic import NonNegativeInt, PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.common_utils import is_float from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -20,7 +20,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractEntityRelationMapper(Mapper): """ diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py index 208684b2c..d9972914e 100644 --- a/data_juicer/ops/mapper/extract_event_mapper.py +++ b/data_juicer/ops/mapper/extract_event_mapper.py @@ -5,7 +5,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -15,7 +15,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractEventMapper(Mapper): """ diff --git a/data_juicer/ops/mapper/extract_keyword_mapper.py b/data_juicer/ops/mapper/extract_keyword_mapper.py index cb1814768..0a21aa0bb 100644 --- a/data_juicer/ops/mapper/extract_keyword_mapper.py +++ b/data_juicer/ops/mapper/extract_keyword_mapper.py @@ -6,7 +6,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -16,7 +16,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractKeywordMapper(Mapper): """ diff --git a/data_juicer/ops/mapper/extract_nickname_mapper.py b/data_juicer/ops/mapper/extract_nickname_mapper.py index b11cbab57..5029d3be0 100644 --- a/data_juicer/ops/mapper/extract_nickname_mapper.py +++ b/data_juicer/ops/mapper/extract_nickname_mapper.py @@ -4,7 +4,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -12,7 +12,6 @@ # TODO: LLM-based inference. 
-@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractNicknameMapper(Mapper): """ diff --git a/data_juicer/ops/mapper/extract_support_text_mapper.py b/data_juicer/ops/mapper/extract_support_text_mapper.py index d75f96fc9..011582a30 100644 --- a/data_juicer/ops/mapper/extract_support_text_mapper.py +++ b/data_juicer/ops/mapper/extract_support_text_mapper.py @@ -3,7 +3,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.common_utils import nested_access, nested_set from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -12,7 +12,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractSupportTextMapper(Mapper): """ diff --git a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py index 6f5ad7dab..0c0d084b3 100644 --- a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py +++ b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py @@ -9,7 +9,7 @@ from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.model_utils import get_model, prepare_model -from ..base_op import OPERATORS, UNFORKABLE, Mapper +from ..base_op import OPERATORS, Mapper torch = LazyLoader('torch', 'torch') vllm = LazyLoader('vllm', 'vllm') @@ -19,7 +19,6 @@ # TODO: Extend LLM-based OPs into API-based implementation. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class GenerateQAFromExamplesMapper(Mapper): """ diff --git a/data_juicer/ops/mapper/generate_qa_from_text_mapper.py b/data_juicer/ops/mapper/generate_qa_from_text_mapper.py index 248dba428..0f3a1cfef 100644 --- a/data_juicer/ops/mapper/generate_qa_from_text_mapper.py +++ b/data_juicer/ops/mapper/generate_qa_from_text_mapper.py @@ -3,7 +3,7 @@ from loguru import logger -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.model_utils import get_model, prepare_model @@ -14,7 +14,6 @@ # TODO: Extend LLM-based OPs into API-based implementation. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class GenerateQAFromTextMapper(Mapper): """ diff --git a/data_juicer/ops/mapper/optimize_qa_mapper.py b/data_juicer/ops/mapper/optimize_qa_mapper.py index 3563a112b..974730ec5 100644 --- a/data_juicer/ops/mapper/optimize_qa_mapper.py +++ b/data_juicer/ops/mapper/optimize_qa_mapper.py @@ -3,7 +3,7 @@ from loguru import logger -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.model_utils import get_model, prepare_model @@ -14,7 +14,6 @@ # TODO: Extend LLM-based OPs into API-based implementation. 
-@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class OptimizeQAMapper(Mapper): """ diff --git a/data_juicer/ops/mapper/optimize_query_mapper.py b/data_juicer/ops/mapper/optimize_query_mapper.py index dd227b4c1..9ccd84bb1 100644 --- a/data_juicer/ops/mapper/optimize_query_mapper.py +++ b/data_juicer/ops/mapper/optimize_query_mapper.py @@ -1,11 +1,10 @@ -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE +from data_juicer.ops.base_op import OPERATORS from data_juicer.ops.mapper.optimize_qa_mapper import OptimizeQAMapper OP_NAME = 'optimize_query_mapper' # TODO: Extend LLM-based OPs into API-based implementation. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class OptimizeQueryMapper(OptimizeQAMapper): """ diff --git a/data_juicer/ops/mapper/optimize_response_mapper.py b/data_juicer/ops/mapper/optimize_response_mapper.py index 158159a9d..f6026b8dc 100644 --- a/data_juicer/ops/mapper/optimize_response_mapper.py +++ b/data_juicer/ops/mapper/optimize_response_mapper.py @@ -1,11 +1,10 @@ -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE +from data_juicer.ops.base_op import OPERATORS from data_juicer.ops.mapper.optimize_qa_mapper import OptimizeQAMapper OP_NAME = 'optimize_response_mapper' # TODO: Extend LLM-based OPs into API-based implementation. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class OptimizeResponseMapper(OptimizeQAMapper): """ diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index c6e62c8fc..350181f41 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -32,14 +32,14 @@ class Fields(object): event_description = DEFAULT_PREFIX + 'event_description__' # # a list of characters relevant to the event relevant_characters = DEFAULT_PREFIX + 'relevant_characters__' - # # the given main entity for attribute extraction - main_entity = DEFAULT_PREFIX + 'main_entity__' - # # the given attribute to be extracted - attribute = DEFAULT_PREFIX + 'attribute__' - # # the extracted attribute description - attribute_description = DEFAULT_PREFIX + 'attribute_description__' - # # extract from raw data for support the attribute - attribute_support_text = DEFAULT_PREFIX + 'attribute_support_text__' + # # the given main entities for attribute extraction + main_entities = DEFAULT_PREFIX + 'main_entities__' + # # the given attributes to be extracted + attributes = DEFAULT_PREFIX + 'attributes__' + # # the extracted attribute descriptions + attribute_descriptions = DEFAULT_PREFIX + 'attribute_descriptions__' + # # extract from raw datas for support the attribute + attribute_support_texts = DEFAULT_PREFIX + 'attribute_support_texts__' # # the nickname relationship nickname = DEFAULT_PREFIX + 'nickname__' # # the entity for knowledge graph diff --git a/demos/role_playing_system_prompt/role_playing_system_prompt.yaml b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml index 63a9adc4c..f0406ec64 100644 --- a/demos/role_playing_system_prompt/role_playing_system_prompt.yaml +++ b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml @@ -7,18 +7,18 @@ export_path: 'path_to_output_jsonl_file' # process schedule process: -# # chunk the novel -# - text_chunk_mapper: -# max_len: 8000 -# split_pattern: '\n\n' -# overlap_len: 400 -# tokenizer: 'qwen2.5-72b-instruct' -# trust_remote_code: True + # chunk the novel + - text_chunk_mapper: + max_len: 8000 + split_pattern: '\n\n' + overlap_len: 400 + tokenizer: 'qwen2.5-72b-instruct' + 
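+      # tokenizer above is used to measure chunk length in tokens against max_len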
trust_remote_code: True # extract language_style, role_charactor and role_skill - extract_entity_attribute_mapper: api_model: 'qwen2.5-72b-instruct' query_entities: ['李莲花'] - query_attributes: ["语言风格", "角色性格", "角色能力"] + query_attributes: ["角色性格", "角色武艺和能力", "语言风格"] # extract nickname - extract_nickname_mapper: api_model: 'qwen2.5-72b-instruct' @@ -27,8 +27,15 @@ process: api_model: 'qwen2.5-72b-instruct' index_key: 'chunk_id' # chunk_id for deduplicating attributes and nicknames # group all events - - naive_grouper:= + - naive_grouper: # role experiences summary from events + - entity_attribute_aggregator: + api_model: 'qwen2.5-72b-instruct' + entity: '李莲花' + attribute: '身份背景' + input_key: '__dj__event_description__' + output_key: '__dj__role_background__' + word_limit: 50 - entity_attribute_aggregator: api_model: 'qwen2.5-72b-instruct' entity: '李莲花' diff --git a/demos/role_playing_system_prompt/system_prompt_generator.py b/demos/role_playing_system_prompt/system_prompt_generator.py index e69de29bb..26712e747 100644 --- a/demos/role_playing_system_prompt/system_prompt_generator.py +++ b/demos/role_playing_system_prompt/system_prompt_generator.py @@ -0,0 +1,190 @@ +import random + +from itertools import chain +from loguru import logger +from collections import Counter + +from data_juicer.ops.aggregator import NestedAggregator +from data_juicer.ops.aggregator import EntityAttributeAggregator +from data_juicer.ops.aggregator import RelationIdentityAggregator +from data_juicer.utils.constant import Fields + +api_model = 'qwen2.5-72b-instruct' + +main_entity = "李莲花" +query_attributes = ["语言风格", "角色性格", "角色武艺和能力"] +system_prompt_key = '__dj__system_prompt__' +example_num_limit = 5 + +role_info_template = "# {entity}\n## 身份背景\n{identity}\n## 人物经历\n{experience}" +relation_identity_text_template = """ +{source_entity}的信息: +{source_entity_info} +{target_entity}的信息: +{target_entity_info} +{source_entity}对{target_entity}的称呼:{nicknames} +""" + +nested_sum = NestedAggregator( + model=api_model, + try_num=3) + +def dedup_sort_val_by_chunk_id(sample, id_key, val_key): + chunk_ids = sample[id_key] + vals = sample[val_key] + id_to_val = {} + for id, val in zip(chunk_ids, vals): + id_to_val[id] = val + sorted_ids = list(id_to_val.keys()) + sorted_ids.sort() + sorted_vals = [id_to_val[id] for id in sorted_ids] + return list(chain(*sorted_vals)) + +def get_attributes(sample): + main_entities = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.main_entities) + attribute_names = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attributes) + attribute_descs = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attribute_descriptions) + attribute_support_texts = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attribute_support_texts) + attributes = {} + support_texts = {} + for attr in query_attributes: + attributes[attr] = [] + support_texts[attr] = [] + for entity, attr_name, attr_desc, sub_support_texts in \ + zip(main_entities, attribute_names, attribute_descs, attribute_support_texts): + if entity == main_entity and attr_name in query_attributes: + attributes[attr_name].append(attr_desc) + support_texts[attr_name].append(sub_support_texts) + return attributes, support_texts + +def get_nicknames(sample): + nicknames = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.nickname) + nickname_map = {} + for nr in nicknames: + if nr[Fields.source_entity] == main_entity: + role_name = nr[Fields.target_entity] + if role_name not in nickname_map: + nickname_map[role_name] = [] + 
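+            # gather every nickname the main entity uses for this role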
nickname_map[role_name].append(nr[Fields.relation_description]) + + max_nums = 3 + for role_name, nickname_list in nickname_map.items(): + th = (len(nickname_list)+1) // 2 + count = Counter(nickname_list) + sorted_items = sorted(count.items(), key=lambda x: x[1], reverse=True) + most_common_nicknames = [] + idx = 0 + while th > 0 and idx < min(len(sorted_items), max_nums): + most_common_nicknames.append(sorted_items[idx][0]) + th -= sorted_items[idx][1] + idx += 1 + nickname_map[role_name] = most_common_nicknames + return nickname_map + + +def get_system_prompt(sample): + + main_role_identity = sample['__dj__role_background__'] + main_role_experience = sample['__dj__role_experience__'] + attributes, support_texts = get_attributes(sample) + main_role_character = nested_sum.recursive_summary(attributes['角色性格']) + main_role_skill = nested_sum.recursive_summary(attributes['角色武艺和能力']) + main_role_lang_style = nested_sum.recursive_summary(attributes['语言风格']) + lang_style_examples = list(chain(*support_texts['语言风格'])) + lang_style_example_num = min(example_num_limit, len(lang_style_examples)) + lang_style_examples = random.sample(lang_style_examples, lang_style_example_num) + + main_role_info = role_info_template.format( + entity=main_entity, + identity=main_role_identity, + experience=main_role_experience + ) + + nicknames = get_nicknames(sample) + + relation_detail = "" + relavant_roles = sample['__dj__important_relavant_roles__'] + for role_name in relavant_roles: + if role_name == main_entity: + continue + + # get sub role identity + op = EntityAttributeAggregator( + api_model=api_model, + entity=role_name, + attribute='身份背景', + input_key='__dj__event_description__', + output_key='__dj__role_background__', + word_limit=30 + ) + sample = op.process_single(sample) + role_identity = sample['__dj__role_background__'] + + # get sub role experience + op = EntityAttributeAggregator( + api_model=api_model, + entity=role_name, + attribute='主要经历', + input_key='__dj__event_description__', + output_key='__dj__role_experience__', + word_limit=100 + ) + sample = op.process_single(sample) + role_experience = sample['__dj__role_experience__'] + + # get relation identity with main role + role_info = role_info_template.format( + entity=role_name, + identity=role_identity, + experience=role_experience + ) + op = RelationIdentityAggregator( + api_model=api_model, + source_entity=main_entity, + target_entity=role_name, + output_key='__dj__relation_identity__' + ) + if role_name in nicknames: + cur_nicknames = '、'.join(nicknames[role_name]) + else: + cur_nicknames = role_name + text = relation_identity_text_template.format( + source_entity=main_entity, + source_entity_info=main_role_info, + target_entity=role_name, + target_entity_info=role_info, + nicknames = cur_nicknames + ) + tmp_sample = {'text': text} + tmp_sample = op.process_single(tmp_sample) + relation = tmp_sample['__dj__relation_identity__'] + + relation_detail += f"\n{role_name} (称呼:{cur_nicknames})" + relation_detail += f"{main_entity}的{relation}。" + relation_detail += f"{role_identity}{role_experience}".replace('\n', '') + + full_system_prompt = f"""扮演{main_entity}与用户进行对话。\n""" + full_system_prompt += """# 角色身份\n""" + full_system_prompt += main_role_identity + full_system_prompt += """\n# 角色经历\n""" + full_system_prompt += main_role_experience + full_system_prompt += """\n# 角色性格\n""" + full_system_prompt += main_role_character + full_system_prompt += """\n# 角色能力\n""" + full_system_prompt += main_role_skill + + full_system_prompt += """\n# 人际关系""" + 
full_system_prompt += relation_detail + + full_system_prompt += """\n# 语言风格\n""" + full_system_prompt += main_role_lang_style + full_system_prompt += f"""\n供参考语言风格的部分{main_entity}台词:\n""" + full_system_prompt += "\n````\n" + full_system_prompt += '\n'.join(lang_style_examples) + full_system_prompt += "\n````\n" + + logger.info(full_system_prompt) + + sample[system_prompt_key] = full_system_prompt + + return sample \ No newline at end of file diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py index 177880358..f15b4ca3f 100644 --- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py +++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py @@ -49,9 +49,14 @@ def _run_op(self, api_model, response_path=None): dataset = Dataset.from_list(samples) dataset = dataset.map(op.process, batch_size=1) for sample in dataset: - logger.info(f'{sample[Fields.main_entity]} {sample[Fields.attribute]}: {sample[Fields.attribute_description]}') - self.assertNotEqual(sample[Fields.attribute_description], '') - self.assertNotEqual(len(sample[Fields.attribute_support_text]), 0) + ents = sample[Fields.main_entities] + attrs = sample[Fields.attributes] + descs = sample[Fields.attribute_descriptions] + sups = sample[Fields.attribute_support_texts] + for ent, attr, desc, sup in zip(ents, attrs, descs, sups): + logger.info(f'{ent} {attr}: {desc}') + self.assertNotEqual(desc, '') + self.assertNotEqual(len(sup), 0) def test(self): # before runing this test, set below environment variables: From ecb86359e29d816c64bb6ab93afd9ac4f77513fd Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 2 Dec 2024 14:45:32 +0800 Subject: [PATCH 049/118] fix batch bug --- data_juicer/core/data.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py index 9cef1fe89..9152dd862 100644 --- a/data_juicer/core/data.py +++ b/data_juicer/core/data.py @@ -253,10 +253,12 @@ def map(self, *args, **kargs): # batched is required for fault-tolerant or batched OP if callable(getattr( called_func.__self__, - 'is_batched_op')) and called_func.__self__.is_batched_op( - ) or not getattr(called_func.__self__, 'turbo', False): + 'is_batched_op')) and called_func.__self__.is_batched_op(): kargs['batched'] = True kargs['batch_size'] = kargs.pop('batch_size', 1) + elif not getattr(called_func.__self__, 'turbo', False): + kargs['batched'] = True + kargs['batch_size'] = 1 else: kargs['batched'] = False From 03e3469c415a2e2c2714e673e379a9322a389c54 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 2 Dec 2024 15:00:38 +0800 Subject: [PATCH 050/118] fix batch bug --- data_juicer/core/data.py | 51 +++++++++++----------------------------- 1 file changed, 14 insertions(+), 37 deletions(-) diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py index 9cef1fe89..0fdcbe2c0 100644 --- a/data_juicer/core/data.py +++ b/data_juicer/core/data.py @@ -225,9 +225,7 @@ def process(self, monitor_dir) return dataset - def map(self, *args, **kargs): - """Override the map func, which is called by most common operations, - such that the processed samples can be accessed by nested manner.""" + def update_common_args(self, args, kargs): if args: args = list(args) # the first positional para is function @@ -253,10 +251,12 @@ def map(self, *args, **kargs): # batched is required for fault-tolerant or batched OP if callable(getattr( called_func.__self__, - 'is_batched_op')) and 
called_func.__self__.is_batched_op( - ) or not getattr(called_func.__self__, 'turbo', False): + 'is_batched_op')) and called_func.__self__.is_batched_op(): kargs['batched'] = True kargs['batch_size'] = kargs.pop('batch_size', 1) + elif not getattr(called_func.__self__, 'turbo', False): + kargs['batched'] = True + kargs['batch_size'] = 1 else: kargs['batched'] = False @@ -270,6 +270,14 @@ def map(self, *args, **kargs): new_fingerprint = generate_fingerprint(self, *args, **kargs) kargs['new_fingerprint'] = new_fingerprint + return args, kargs + + def map(self, *args, **kargs): + """Override the map func, which is called by most common operations, + such that the processed samples can be accessed by nested manner.""" + + args, kargs = self.update_common_args(args, kargs) + if cache_utils.CACHE_COMPRESS: decompress(self, kargs['new_fingerprint'], kargs['num_proc'] if 'num_proc' in kargs else 1) @@ -288,38 +296,7 @@ def map(self, *args, **kargs): def filter(self, *args, **kargs): """Override the filter func, which is called by most common operations, such that the processed samples can be accessed by nested manner.""" - if args: - args = list(args) - # the first positional para is function - if args[0] is None: - args[0] = lambda x: nested_obj_factory(x) - else: - args[0] = wrap_func_with_nested_access(args[0]) - called_func = args[0] - else: - if 'function' not in kargs or kargs['function'] is None: - kargs['function'] = lambda x: nested_obj_factory(x) - else: - kargs['function'] = wrap_func_with_nested_access( - kargs['function']) - called_func = kargs['function'] - - # For wrapped function, try to get its unwrapped (bound) method - while not inspect.ismethod(called_func) and hasattr( - called_func, '__wrapped__'): - called_func = called_func.__wrapped__ - - # Batched is always required for fault tolerance - if inspect.ismethod(called_func): - if callable(getattr( - called_func.__self__, - 'is_batched_op')) and called_func.__self__.is_batched_op(): - kargs['batched'] = True - kargs['batch_size'] = kargs.pop('batch_size', 1) - - if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None: - new_fingerprint = generate_fingerprint(self, *args, **kargs) - kargs['new_fingerprint'] = new_fingerprint + args, kargs = self.update_common_args(args, kargs) # For filter, it involves a map and a filter operations, so the final # cache files includes two sets with different fingerprint (before and From 00ff62490e49dfb60060f015a4184bfe2c566b28 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 3 Dec 2024 17:23:17 +0800 Subject: [PATCH 051/118] fix filter batch --- data_juicer/ops/base_op.py | 12 ++++++++---- .../ops/mapper/extract_entity_attribute_mapper.py | 6 +++++- .../system_prompt_generator.py | 2 +- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 02651cbc7..9f70df7a6 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -70,7 +70,7 @@ def wrapper(samples, *args, **kwargs): return wrapper -def catch_map_single_exception(method): +def catch_map_single_exception(method, return_sample=True): """ For single-map sample-level fault tolerance. The input sample is expected batch_size = 1. 
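The `return_sample` switch being added here is the heart of the filter fix: a mapper's `process_single` returns a sample dict that must be converted back into a batch, while a filter's `process_single` returns a bare boolean that has to stay a plain one-element list. A minimal, self-contained sketch of the wrapper pattern, assuming batch size 1; the two conversion helpers are local stand-ins, not the library's own utilities:

```python
# Sketch only: mirrors the idea of catch_map_single_exception, with
# simplified stand-in helpers and error handling.


def dict_of_lists_to_dict(batch):
    # {'text': ['a']} -> {'text': 'a'}; assumes batch_size == 1
    return {key: value[0] for key, value in batch.items()}


def dict_to_dict_of_lists(sample):
    # {'text': 'a'} -> {'text': ['a']}
    return {key: [value] for key, value in sample.items()}


def catch_single_exception(method, return_sample=True):
    def wrapper(batch, *args, **kwargs):
        try:
            sample = dict_of_lists_to_dict(batch)
            res = method(sample, *args, **kwargs)
            if return_sample:
                # mapper case: res is a sample, so re-batch it
                return dict_to_dict_of_lists(res)
            # filter case: res is a scalar such as a keep flag
            return [res]
        except Exception as e:
            print(f'skipping one bad sample: {e}')
            # empty containers drop the failed sample instead of crashing
            return {key: [] for key in batch} if return_sample else []

    return wrapper


# a filter's boolean result stays a one-element list, not a dict
keep_fn = catch_single_exception(lambda s: len(s['text']) > 3,
                                 return_sample=False)
print(keep_fn({'text': ['hello']}))  # [True]
```

Keeping the failure branch inside the wrapper means one malformed sample is dropped rather than aborting the whole pass.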
@@ -92,8 +92,11 @@ def wrapper(sample, *args, **kwargs): if is_batched(sample): try: sample = convert_dict_list_to_list_dict(sample)[0] - res_sample = method(sample, *args, **kwargs) - return convert_list_dict_to_dict_list([res_sample]) + res = method(sample, *args, **kwargs) + if return_sample: + return convert_list_dict_to_dict_list([res]) + else: + return [res] except Exception as e: from loguru import logger logger.error( @@ -338,7 +341,8 @@ def __init__(self, *args, **kwargs): else: self.compute_stats = catch_map_single_exception( self.compute_stats_single) - self.process = catch_map_single_exception(self.process_single) + self.process = catch_map_single_exception(self.process_single, + return_sample=False) # set the process method is not allowed to be overridden def __init_subclass__(cls, **kwargs): diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py index bc99cdc5f..0fc76b11f 100644 --- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -27,14 +27,18 @@ class ExtractEntityAttributeMapper(Mapper): '## {attribute}:\n' '...\n' '### 代表性示例摘录1:\n' + '```\n' '...\n' + '```\n' '### 代表性示例摘录2:\n' + '```\n' '...\n' + '```\n' '...\n') DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n' DEFAULT_ATTR_PATTERN_TEMPLATE = r'\#\#\s*{attribute}:\s*(.*?)(?=\#\#\#|\Z)' - DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例(\d+):\s*(.*?)(?=\#\#\#|\Z)' + DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例摘录(\d+):\s*```\s*(.*?)```\s*(?=\#\#\#|\Z)' # noqa: E501 def __init__(self, api_model: str = 'gpt-4o', diff --git a/demos/role_playing_system_prompt/system_prompt_generator.py b/demos/role_playing_system_prompt/system_prompt_generator.py index 26712e747..967149cb8 100644 --- a/demos/role_playing_system_prompt/system_prompt_generator.py +++ b/demos/role_playing_system_prompt/system_prompt_generator.py @@ -26,7 +26,7 @@ """ nested_sum = NestedAggregator( - model=api_model, + api_model=api_model, try_num=3) def dedup_sort_val_by_chunk_id(sample, id_key, val_key): From 8601519ad8e9ff243bdab77eab131166faf944a5 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 3 Dec 2024 17:29:57 +0800 Subject: [PATCH 052/118] fix filter batch --- data_juicer/ops/base_op.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 13f3b61ae..831d94c12 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -70,7 +70,7 @@ def wrapper(samples, *args, **kwargs): return wrapper -def catch_map_single_exception(method): +def catch_map_single_exception(method, return_sample=True): """ For single-map sample-level fault tolerance. The input sample is expected batch_size = 1. 
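The tightened `DEFAULT_DEMON_PATTERN` above only accepts excerpts wrapped in fenced blocks, which keeps the parse stable even when an excerpt itself contains `###`-style headings. A rough, self-contained illustration of how such a pattern pulls out the numbered excerpts; the pattern is a local copy of the idea, built with a `fence` variable purely so the backticks do not clash with this snippet:

```python
import re

# local copy of the idea behind DEFAULT_DEMON_PATTERN, not the op's code
fence = '`' * 3
pattern = (r'\#\#\#\s*代表性示例摘录(\d+):\s*' + fence +
           r'\s*(.*?)' + fence + r'\s*(?=\#\#\#|\Z)')

output = ('## 语言风格:\n幽默诙谐\n'
          f'### 代表性示例摘录1:\n{fence}\n你问我干吗?该启程了啊。\n{fence}\n'
          f'### 代表性示例摘录2:\n{fence}\n放心吧。\n{fence}\n')

# re.DOTALL lets '.*?' span the newlines inside each fenced excerpt
for idx, excerpt in re.findall(pattern, output, re.DOTALL):
    print(idx, excerpt.strip())
```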
@@ -92,8 +92,11 @@ def wrapper(sample, *args, **kwargs): if is_batched(sample): try: sample = convert_dict_list_to_list_dict(sample)[0] - res_sample = method(sample, *args, **kwargs) - return convert_list_dict_to_dict_list([res_sample]) + res = method(sample, *args, **kwargs) + if return_sample: + return convert_list_dict_to_dict_list([res]) + else: + return [res] except Exception as e: from loguru import logger logger.error( @@ -315,7 +318,8 @@ def __init__(self, *args, **kwargs): else: self.compute_stats = catch_map_single_exception( self.compute_stats_single) - self.process = catch_map_single_exception(self.process_single) + self.process = catch_map_single_exception(self.process_single, + return_sample=False) # set the process method is not allowed to be overridden def __init_subclass__(cls, **kwargs): From eeefcabc1ac4026e7c9ef1eb4b82a6308003bc18 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 3 Dec 2024 17:47:33 +0800 Subject: [PATCH 053/118] system prompt recipe done --- .../system_prompt_generator.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/demos/role_playing_system_prompt/system_prompt_generator.py b/demos/role_playing_system_prompt/system_prompt_generator.py index 967149cb8..97f93755a 100644 --- a/demos/role_playing_system_prompt/system_prompt_generator.py +++ b/demos/role_playing_system_prompt/system_prompt_generator.py @@ -15,6 +15,7 @@ query_attributes = ["语言风格", "角色性格", "角色武艺和能力"] system_prompt_key = '__dj__system_prompt__' example_num_limit = 5 +max_relavant_roles_num = 5 role_info_template = "# {entity}\n## 身份背景\n{identity}\n## 人物经历\n{experience}" relation_identity_text_template = """ @@ -104,7 +105,7 @@ def get_system_prompt(sample): relation_detail = "" relavant_roles = sample['__dj__important_relavant_roles__'] - for role_name in relavant_roles: + for role_name in relavant_roles[:max_relavant_roles_num]: if role_name == main_entity: continue @@ -160,24 +161,25 @@ def get_system_prompt(sample): relation = tmp_sample['__dj__relation_identity__'] relation_detail += f"\n{role_name} (称呼:{cur_nicknames})" - relation_detail += f"{main_entity}的{relation}。" + if relation: + relation_detail += f"{main_entity}的{relation}。" relation_detail += f"{role_identity}{role_experience}".replace('\n', '') full_system_prompt = f"""扮演{main_entity}与用户进行对话。\n""" full_system_prompt += """# 角色身份\n""" - full_system_prompt += main_role_identity + full_system_prompt += main_role_identity.replace('\n', '') full_system_prompt += """\n# 角色经历\n""" - full_system_prompt += main_role_experience + full_system_prompt += main_role_experience.replace('\n', '') full_system_prompt += """\n# 角色性格\n""" - full_system_prompt += main_role_character + full_system_prompt += main_role_character.replace('\n', '') full_system_prompt += """\n# 角色能力\n""" - full_system_prompt += main_role_skill + full_system_prompt += main_role_skill.replace('\n', '') full_system_prompt += """\n# 人际关系""" - full_system_prompt += relation_detail + full_system_prompt += relation_detail.replace('\n', '') full_system_prompt += """\n# 语言风格\n""" - full_system_prompt += main_role_lang_style + full_system_prompt += main_role_lang_style.replace('\n', '') full_system_prompt += f"""\n供参考语言风格的部分{main_entity}台词:\n""" full_system_prompt += "\n````\n" full_system_prompt += '\n'.join(lang_style_examples) From 1575717f111f813d5882f7e574caa5d6b6b74b33 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 5 Dec 2024 11:14:46 +0800 Subject: [PATCH 054/118] not rank for filter 
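The diff below stops threading a process rank into `filter`'s second pass and keeps it for `map`, where CUDA models actually get loaded. A rough sketch of the device-pinning idea that the rank enables; `load_model` and `num_gpus` here are illustrative stand-ins, not the project's API:

```python
# Sketch only: with a rank, each map worker can pin its model to one GPU
# instead of every worker racing for the same default device.


def load_model(device):
    return f'model-on-{device}'


def compute_stats(sample, rank=None, num_gpus=2):
    device = f'cuda:{rank % num_gpus}' if rank is not None else 'cpu'
    model = load_model(device)
    sample['stats'] = f'{model} scored {len(sample["text"])}'
    return sample


print(compute_stats({'text': 'hello'}, rank=3))  # lands on cuda:1
```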
---
 data_juicer/core/data.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py
index 0fdcbe2c0..d1a77b581 100644
--- a/data_juicer/core/data.py
+++ b/data_juicer/core/data.py
@@ -225,7 +225,7 @@ def process(self,
                                     monitor_dir)
         return dataset

-    def update_common_args(self, args, kargs):
+    def update_args(self, args, kargs, is_filter=False):
         if args:
             args = list(args)
             # the first positional para is function
@@ -260,8 +260,8 @@ def update_common_args(self, args, kargs):
             else:
                 kargs['batched'] = False

-            # rank is required for cuda model loading
-            if callable(
+            # rank is required for cuda model loading for map
+            if not is_filter and callable(
                     getattr(called_func.__self__,
                             'use_cuda')) and called_func.__self__.use_cuda():
                 kargs['with_rank'] = True
@@ -276,7 +276,7 @@ def map(self, *args, **kargs):
         """Override the map func, which is called by most common operations,
         such that the processed samples can be accessed by nested manner."""

-        args, kargs = self.update_common_args(args, kargs)
+        args, kargs = self.update_args(args, kargs)

         if cache_utils.CACHE_COMPRESS:
             decompress(self, kargs['new_fingerprint'],
@@ -296,7 +296,7 @@ def map(self, *args, **kargs):
     def filter(self, *args, **kargs):
         """Override the filter func, which is called by most common operations,
         such that the processed samples can be accessed by nested manner."""
-        args, kargs = self.update_common_args(args, kargs)
+        args, kargs = self.update_args(args, kargs, is_filter=True)

         # For filter, it involves a map and a filter operations, so the final
         # cache files includes two sets with different fingerprint (before and

From 2c5c4a108afb5916a816d3c603c470e966cabb1d Mon Sep 17 00:00:00 2001
From: Haibin <1400012807@pku.edu.cn>
Date: Thu, 5 Dec 2024 11:27:03 +0800
Subject: [PATCH 055/118] limit pyav version

---
 environments/minimal_requires.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/environments/minimal_requires.txt b/environments/minimal_requires.txt
index 7d37959fe..df76b1358 100644
--- a/environments/minimal_requires.txt
+++ b/environments/minimal_requires.txt
@@ -2,7 +2,7 @@ datasets>=2.19.0
 fsspec==2023.5.0
 pandas
 numpy
-av
+av==13.1.0
 soundfile
 librosa>=0.10
 loguru

From 49be467e9872682a811e3af49cb55a006410e7dc Mon Sep 17 00:00:00 2001
From: Haibin <1400012807@pku.edu.cn>
Date: Thu, 5 Dec 2024 14:35:23 +0800
Subject: [PATCH 056/118] add test for op

---
 data_juicer/ops/aggregator/__init__.py        |  3 +-
 data_juicer/ops/mapper/__init__.py            |  5 +-
 .../relation_identity_mapper.py}              |  6 +-
 .../system_prompt_generator.py                | 10 +-
 .../test_entity_attribute_aggregator.py       | 12 +--
 .../test_most_relavant_entities_aggregator.py | 92 +++++++++++++++++++
 .../mapper/test_relation_identity_mapper.py   | 58 ++++++++++++
 7 files changed, 168 insertions(+), 18 deletions(-)
 rename data_juicer/ops/{aggregator/relation_idenity_aggregator.py => mapper/relation_identity_mapper.py} (97%)
 create mode 100644 tests/ops/Aggregator/test_most_relavant_entities_aggregator.py
 create mode 100644 tests/ops/mapper/test_relation_identity_mapper.py

diff --git a/data_juicer/ops/aggregator/__init__.py b/data_juicer/ops/aggregator/__init__.py
index c411f7cea..4afe2974a 100644
--- a/data_juicer/ops/aggregator/__init__.py
+++ b/data_juicer/ops/aggregator/__init__.py
@@ -1,9 +1,8 @@
 from .entity_attribute_aggregator import EntityAttributeAggregator
 from .most_relavant_entities_aggregator import MostRelavantEntitiesAggregator
 from .nested_aggregator import NestedAggregator
-from 
.relation_idenity_aggregator import RelationIdentityAggregator __all__ = [ 'NestedAggregator', 'EntityAttributeAggregator', - 'MostRelavantEntitiesAggregator', 'RelationIdentityAggregator' + 'MostRelavantEntitiesAggregator' ] diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 34d5e11bd..3d94a3f00 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -32,6 +32,7 @@ from .pair_preference_mapper import PairPreferenceMapper from .punctuation_normalization_mapper import PunctuationNormalizationMapper from .python_file_mapper import PythonFileMapper +from .relation_identity_mapper import RelationIdentityMapper from .remove_bibliography_mapper import RemoveBibliographyMapper from .remove_comments_mapper import RemoveCommentsMapper from .remove_header_mapper import RemoveHeaderMapper @@ -78,8 +79,8 @@ 'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', 'PairPreferenceMapper', 'PunctuationNormalizationMapper', - 'PythonFileMapper', 'RemoveBibliographyMapper', 'RemoveCommentsMapper', - 'RemoveHeaderMapper', 'RemoveLongWordsMapper', + 'PythonFileMapper', 'RelationIdentityMapper', 'RemoveBibliographyMapper', + 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', diff --git a/data_juicer/ops/aggregator/relation_idenity_aggregator.py b/data_juicer/ops/mapper/relation_identity_mapper.py similarity index 97% rename from data_juicer/ops/aggregator/relation_idenity_aggregator.py rename to data_juicer/ops/mapper/relation_identity_mapper.py index 1b94b9212..c3eda2eec 100644 --- a/data_juicer/ops/aggregator/relation_idenity_aggregator.py +++ b/data_juicer/ops/mapper/relation_identity_mapper.py @@ -4,16 +4,16 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, Aggregator +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.common_utils import nested_access, nested_set from data_juicer.utils.model_utils import get_model, prepare_model -OP_NAME = 'relation_identity_aggregator' +OP_NAME = 'relation_identity_mapper' # TODO: LLM-based inference. @OPERATORS.register_module(OP_NAME) -class RelationIdentityAggregator(Aggregator): +class RelationIdentityMapper(Mapper): """ identify relation between two entity in the text. 
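Recast as a `Mapper`, the op only needs a `process_single` that reads one sample's text and writes one field back, rather than reducing a whole batch the way an `Aggregator` does. A schematic, dependency-free sketch of that shape; the registry and base class below are stand-ins for the real `OPERATORS` and `Mapper`, and the toy regex merely mimics the idea of an output pattern:

```python
import re

OPERATORS = {}  # stand-in registry, not the project's


def register(name):
    def deco(cls):
        OPERATORS[name] = cls
        return cls
    return deco


class Mapper:  # stand-in base class
    def process_single(self, sample):
        raise NotImplementedError


@register('relation_identity_mapper')
class ToyRelationIdentityMapper(Mapper):
    """Pull a '关系:<...>' style line out of a model response."""

    output_pattern = r'关系[::]\s*(\S+)'

    def process_single(self, sample):
        match = re.search(self.output_pattern, sample['text'])
        sample['relation'] = match.group(1) if match else ''
        return sample


op = OPERATORS['relation_identity_mapper']()
print(op.process_single({'text': '分析:师徒情谊深厚\n关系:师父'}))
```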
""" diff --git a/demos/role_playing_system_prompt/system_prompt_generator.py b/demos/role_playing_system_prompt/system_prompt_generator.py index 97f93755a..dc2738900 100644 --- a/demos/role_playing_system_prompt/system_prompt_generator.py +++ b/demos/role_playing_system_prompt/system_prompt_generator.py @@ -6,7 +6,7 @@ from data_juicer.ops.aggregator import NestedAggregator from data_juicer.ops.aggregator import EntityAttributeAggregator -from data_juicer.ops.aggregator import RelationIdentityAggregator +from data_juicer.ops.mapper import RelationIdentityMapper from data_juicer.utils.constant import Fields api_model = 'qwen2.5-72b-instruct' @@ -119,7 +119,7 @@ def get_system_prompt(sample): word_limit=30 ) sample = op.process_single(sample) - role_identity = sample['__dj__role_background__'] + role_identity = sample['__dj__role_background__'].replace('\n', '') # get sub role experience op = EntityAttributeAggregator( @@ -131,7 +131,7 @@ def get_system_prompt(sample): word_limit=100 ) sample = op.process_single(sample) - role_experience = sample['__dj__role_experience__'] + role_experience = sample['__dj__role_experience__'].replace('\n', '') # get relation identity with main role role_info = role_info_template.format( @@ -139,7 +139,7 @@ def get_system_prompt(sample): identity=role_identity, experience=role_experience ) - op = RelationIdentityAggregator( + op = RelationIdentityMapper( api_model=api_model, source_entity=main_entity, target_entity=role_name, @@ -176,7 +176,7 @@ def get_system_prompt(sample): full_system_prompt += main_role_skill.replace('\n', '') full_system_prompt += """\n# 人际关系""" - full_system_prompt += relation_detail.replace('\n', '') + full_system_prompt += relation_detail full_system_prompt += """\n# 语言风格\n""" full_system_prompt += main_role_lang_style.replace('\n', '') diff --git a/tests/ops/Aggregator/test_entity_attribute_aggregator.py b/tests/ops/Aggregator/test_entity_attribute_aggregator.py index 647d0486a..bb1fb4ee7 100644 --- a/tests/ops/Aggregator/test_entity_attribute_aggregator.py +++ b/tests/ops/Aggregator/test_entity_attribute_aggregator.py @@ -118,12 +118,12 @@ def test_example_prompt(self): }, ] example_prompt=( - '- 例如,根据相关文档总结`孙悟空`的`另外身份`,样例如下:\n' - '`孙悟空`的`另外身份`总结:\n' - '# 孙悟空\n' - '## 另外身份\n' - '孙行者、齐天大圣、美猴王\n' - ) + '- 例如,根据相关文档总结`孙悟空`的`另外身份`,样例如下:\n' + '`孙悟空`的`另外身份`总结:\n' + '# 孙悟空\n' + '## 另外身份\n' + '孙行者、齐天大圣、美猴王\n' + ) op = EntityAttributeAggregator( api_model='qwen2.5-72b-instruct', entity='李莲花', diff --git a/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py b/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py new file mode 100644 index 000000000..dccee6704 --- /dev/null +++ b/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py @@ -0,0 +1,92 @@ +import unittest + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.aggregator import MostRelavantEntitiesAggregator +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class MostRelavantEntitiesAggregatorTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples): + + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for data in new_dataset: + for k in data: + logger.info(f"{k}: {data[k]}") + + self.assertEqual(len(new_dataset), len(samples)) + + def 
test_default_aggregator(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + + op = MostRelavantEntitiesAggregator( + api_model='qwen2.5-72b-instruct', + entity='李莲花', + query_entity_type='人物' + ) + self._run_helper(op, samples) + + def test_input_output(self): + samples = [ + { + 'dj_result':{ + 'events': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + } + }, + ] + + op = MostRelavantEntitiesAggregator( + api_model='qwen2.5-72b-instruct', + entity='李莲花', + query_entity_type='人物', + input_key='dj_result.events', + output_key='dj_result.relavant_roles' + ) + self._run_helper(op, samples) + + def test_max_token_num(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = MostRelavantEntitiesAggregator( + api_model='qwen2.5-72b-instruct', + entity='李莲花', + query_entity_type='人物', + max_token_num=40 + ) + self._run_helper(op, samples) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/mapper/test_relation_identity_mapper.py b/tests/ops/mapper/test_relation_identity_mapper.py new file mode 100644 index 000000000..d730cb79f --- /dev/null +++ b/tests/ops/mapper/test_relation_identity_mapper.py @@ -0,0 +1,58 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.relation_identity_mapper import RelationIdentityMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. 
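The `max_token_num` knob exercised in the aggregator tests above caps how much source text a single aggregation call may see, so larger inputs have to be reduced in rounds, which is what the nested aggregation scheme is for. A toy sketch of that scheme, assuming a naive character count in place of a real tokenizer and a placeholder `summarize` in place of the API call:

```python
def summarize(docs):
    # placeholder for the LLM call; here we just keep each doc's head
    return ' / '.join(d.split('。')[0] for d in docs)


def count_tokens(text):
    return len(text)  # naive stand-in for a real tokenizer


def pack(docs, budget):
    # greedily group docs into batches whose token totals fit the budget
    batches, cur, used = [], [], 0
    for doc in docs:
        n = count_tokens(doc)
        if cur and used + n > budget:
            batches.append(cur)
            cur, used = [], 0
        cur.append(doc)
        used += n
    if cur:
        batches.append(cur)
    return batches


def nested_aggregate(docs, budget):
    # summarize batch by batch until a single summary remains
    while len(docs) > 1:
        docs = [summarize(batch) for batch in pack(docs, budget)]
    return docs[0]


chunks = ['李相夷十五岁成名。其余情节……', '十年后李莲花行医。其余情节……']
print(nested_aggregate(chunks, budget=30))
```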
+@SKIPPED_TESTS.register_module()
+class RelationIdentityMapperTest(DataJuicerTestCaseBase):
+
+
+    def _run_op(self, api_model, response_path=None):
+
+        op = RelationIdentityMapper(api_model=api_model,
+                                    source_entity="李莲花",
+                                    target_entity="方多病",
+                                    response_path=response_path)
+
+        raw_text = """李莲花原名李相夷,十五岁战胜西域天魔,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。
+在与金鸳盟盟主笛飞声的对决中,李相夷中毒重伤,沉入大海,十年后在莲花楼醒来,过起了市井生活。他帮助肉铺掌柜解决家庭矛盾,表现出敏锐的洞察力。
+李莲花与方多病合作,解决了灵山派掌门王青山的假死案,揭露了朴管家的罪行。
+随后,他与方多病和笛飞声一起调查了玉秋霜的死亡案,最终揭露了玉红烛的阴谋。在朴锄山,李莲花和方多病调查了七具无头尸事件,发现男童的真实身份是笛飞声。
+李莲花利用飞猿爪偷走男童手中的观音垂泪,导致笛飞声恢复内力,但李莲花巧妙逃脱。李莲花与方多病继续合作,调查了少师剑被盗案,揭露了静仁和尚的阴谋。
+在采莲庄,他解决了新娘溺水案,找到了狮魂的线索,并在南门园圃挖出单孤刀的药棺。在玉楼春的案件中,李莲花和方多病揭露了玉楼春的阴谋,救出了被拐的清儿。
+在石寿村,他们发现了柔肠玉酿的秘密,并救出了被控制的武林高手。李莲花与方多病在白水园设下机关,救出方多病的母亲何晓惠,并最终在云隐山找到了治疗碧茶之毒的方法。
+在天机山庄,他揭露了单孤刀的野心,救出了被控制的大臣。在皇宫,李莲花与方多病揭露了魔僧和单孤刀的阴谋,成功解救了皇帝。
+最终,李莲花在东海之滨与笛飞声的决斗中未出现,留下一封信,表示自己已无法赴约。
+一年后,方多病在东海畔的柯厝村找到了李莲花,此时的李莲花双目失明,右手残废,但心态平和,过着简单的生活。
+方多病 (称呼:方小宝、方大少爷)百川院刑探,单孤刀之子,李相夷的徒弟。方多病通过百川院的考核,成为刑探,并在百川院内展示了自己是李相夷的弟子,获得暂时的录用。
+他接到任务前往嘉州调查金鸳盟的余孽,期间与李莲花相识并合作破案。方多病在调查过程中逐渐了解到自己的身世,发现自己的生父是单孤刀。
+他与李莲花、笛飞声等人多次合作,共同对抗金鸳盟和单孤刀的阴谋。方多病在一系列案件中展现了出色的推理能力和武艺,逐渐成长为一名优秀的刑探。
+最终,方多病在天机山庄和皇宫的斗争中发挥了关键作用,帮助李莲花等人挫败了单孤刀的野心。在李莲花中毒后,方多病决心为他寻找解毒之法,展现了深厚的友情。
+"""
+        samples = [{
+            'text': raw_text,
+        }]
+
+        dataset = Dataset.from_list(samples)
+        dataset = dataset.map(op.process, batch_size=2)
+        for data in dataset:
+            for k in data:
+                logger.info(f"{k}: {data[k]}")
+
+    def test(self):
+        # before running this test, set below environment variables:
+        # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
+        # export OPENAI_API_KEY=your_key
+        self._run_op('qwen2.5-72b-instruct')
+
+
+if __name__ == '__main__':
+    unittest.main()

From 9ab02fec99a46dc8c4f58490f43bba1e8764140d Mon Sep 17 00:00:00 2001
From: Haibin <1400012807@pku.edu.cn>
Date: Thu, 5 Dec 2024 14:47:01 +0800
Subject: [PATCH 057/118] tmp

---
 configs/config_all.yaml                            | 12 ++++++++++++
 .../ops/mapper/extract_support_text_mapper.py      |  2 +-
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index ea10be519..a74bf36d1 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -152,6 +152,18 @@ process:
       drop_text: false # If drop the text in the output.
       model_params: {} # Parameters for initializing the API model.
       sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
+  - extract_support_text_mapper: # extract support sub text for a summary.
+      api_model: 'gpt-4o' # API model name.
+      summary_key: '__dj__event_description__' # The field name to store the input summary. Support for nested keys such as "__dj__stats__.text_len".
+      support_text_key: '__dj__support_text__' # The field name to store the output support text for the summary.
+      api_endpoint: null # URL endpoint for the API.
+      response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+      system_prompt: null # System prompt for the task.
+      input_template: null # Template for building the model input.
+      try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+      drop_text: false # If drop the text in the output.
+      model_params: {} # Parameters for initializing the API model.
+      sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
   - fix_unicode_mapper: # fix unicode errors in text.
   - generate_qa_from_examples_mapper: # mapper to generate question and answer pairs from examples.
       hf_model: 'Qwen/Qwen2.5-7B-Instruct' # Model name on huggingface to generate question and answer pairs.
diff --git a/data_juicer/ops/mapper/extract_support_text_mapper.py b/data_juicer/ops/mapper/extract_support_text_mapper.py
index 011582a30..34bdbe653 100644
--- a/data_juicer/ops/mapper/extract_support_text_mapper.py
+++ b/data_juicer/ops/mapper/extract_support_text_mapper.py
@@ -63,7 +63,7 @@ def __init__(self,
         :param summary_key: The field name to store the input summary.
             Support for nested keys such as "__dj__stats__.text_len".
             It's "__dj__event_description__" in default.
-        :param relevant_char_key: The field name to store the output
+        :param support_text_key: The field name to store the output
             support text for the summary. It's
             "__dj__support_text__" in default.

From f71213124a469128d897eb8dc3d24914e673e379a Mon Sep 17 00:00:00 2001
From: Haibin <1400012807@pku.edu.cn>
Date: Thu, 5 Dec 2024 17:19:53 +0800
Subject: [PATCH 058/118] doc done

---
 configs/config_all.yaml                       | 71 ++++++++++++++++++-
 .../most_relavant_entities_aggregator.py      |  3 +-
 .../ops/mapper/relation_identity_mapper.py    |  4 +-
 demos/role_playing_system_prompt/README_ZH.md | 49 +++++++++++++
 .../role_playing_system_prompt.yaml           | 14 ++--
 docs/Operators.md                             | 21 +++++-
 docs/Operators_ZH.md                          | 21 +++++-
 7 files changed, 169 insertions(+), 14 deletions(-)
 create mode 100644 demos/role_playing_system_prompt/README_ZH.md

diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index c6faefd60..d4f6c1438 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -78,9 +78,9 @@ process:
   - clean_copyright_mapper: # remove copyright comments.
   - expand_macro_mapper: # expand macro definitions in Latex text.
   - extract_entity_attribute_mapper: # Extract attributes for given entities from the text.
+      api_model: 'gpt-4o' # API model name.
       query_entities: ["孙悟空", "猪八戒"] # Entity list to be queried.
       query_attributes: ["人物性格"] # Attribute list to be queried.
-      api_model: 'gpt-4o' # API model name.
       entity_key: '__dj__entity__' # The field name to store the given main entity for attribute extraction.
       entity_attribute_key: '__dj__attribute__' # The field name to store the given attribute to be extracted.
       attribute_desc_key: '__dj__attribute_description__' # The field name to store the extracted attribute description.
@@ -269,10 +269,25 @@ process:
       model_params: {} # Parameters for initializing the API model.
       sampling_params: {} # Extra parameters passed to the API call.
   - punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations.
-  - python_python_mapper: # executing Python lambda function defined in a file.
+  - python_file_mapper: # executing Python lambda function defined in a file.
      file_path: '' # The path to the Python file containing the function to be executed.
      function_name: 'process_single' # The name of the function defined in the file to be executed.
      batched: False # A boolean indicating whether to process input data in batches.
+  - relation_identity_mapper: # identify the relation between two entities in the text.
+      api_model: 'gpt-4o' # API model name.
+      source_entity: '孙悟空' # The source entity of the relation to be identified.
+      target_entity: '猪八戒' # The target entity of the relation to be identified.
+      input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default.
+      output_key: null # The output field key in the samples. 
Support for nested keys such as "__dj__stats__.text_len". It is input_key in default. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt_template: null # System prompt template for the task. Need to specify by entity1 and entity2. + input_template: null # Template for building the model input. + output_pattern_template: null # Regular expression template for parsing model output. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - remove_bibliography_mapper: # remove bibliography from Latex text. - remove_comments_mapper: # remove comments from Latex text, code, etc. doc_type: tex # comment type you want to remove. Only support 'tex' for now. @@ -694,3 +709,55 @@ process: top_ratio: # ratio of selected top samples topk: # number of selected top sample reverse: True # determine the sorting rule, if reverse=True, then sort in descending order + +# Grouper ops. + - naive_grouper: # Group all samples to one batched sample. + - key_value_grouper: # Group samples to batched samples according values in given keys. + group_by_keys: null # Group samples according values in the keys. Support for nested keys such as "__dj__stats__.text_len". It is [self.text_key] in default. + +# Aggregator ops. + - entity_attribute_aggregator: # Return conclusion of the given entity's attribute from some docs. + api_model: 'gpt-4o' # API model name. + entity: '孙悟空' # The given entity. + attribute: '人物经历' # The given attribute. + input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. + output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is same as the input_key in default. + word_limit: 100 # Prompt the output length. + max_token_num: null # The max token num of the total tokens of the sub documents. Without limitation if it is None. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt_template: null # System prompt template for the task. Need to be specified by given entity and attribute. + example_prompt: null # The example part in the system prompt. + input_template: null # The input template. + output_pattern_template: null # The output template. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - most_relavant_entities_aggregator: # Extract entities closely related to a given entity from some texts, and sort them in descending order of importance. + api_model: 'gpt-4o' # API model name. + entity: '孙悟空' # The given entity. + query_entity_type: '人物' # The type of queried relavant entities. + input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. + output_key: null # The output field key in the samples. 
Support for nested keys such as "__dj__stats__.text_len". It is same as the input_key in default. + max_token_num: null # The max token num of the total tokens of the sub documents. Without limitation if it is None. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt_template: null # System prompt template for the task. Need to be specified by given entity and entity_type. + input_template: null # The input template. + output_pattern: null # The output pattern. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - nested_aggregator: # Considering the limitation of input length, nested aggregate contents for each given number of samples. + api_model: 'gpt-4o' # API model name. + input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. + output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is same as the input_key in default. + max_token_num: null # The max token num of the total tokens of the sub documents. Without limitation if it is None. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # The system prompt. + sub_doc_template: null # The template for input text in each sample. + input_template: null # The input template. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} diff --git a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py index d52a9a1fc..69e1a209c 100644 --- a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py +++ b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py @@ -22,7 +22,8 @@ @OPERATORS.register_module(OP_NAME) class MostRelavantEntitiesAggregator(Aggregator): """ - Return most relavant entities with the given entity from some docs. + Extract entities closely related to a given entity from some texts, + and sort them in descending order of importance. """ DEFAULT_SYSTEM_TEMPLATE = ( diff --git a/data_juicer/ops/mapper/relation_identity_mapper.py b/data_juicer/ops/mapper/relation_identity_mapper.py index c3eda2eec..29994d744 100644 --- a/data_juicer/ops/mapper/relation_identity_mapper.py +++ b/data_juicer/ops/mapper/relation_identity_mapper.py @@ -58,14 +58,14 @@ def __init__(self, :param api_model: API model name. :param source_entity: The source entity of the relation to be identified. - :param api_endpoint: The target entity of the relation to be + :param target_entity: The target entity of the relation to be identified. :param input_key: The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. :param output_key: The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is - "__dj__relation_identity__". + input_key in default. 
:param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. diff --git a/demos/role_playing_system_prompt/README_ZH.md b/demos/role_playing_system_prompt/README_ZH.md new file mode 100644 index 000000000..956c335bb --- /dev/null +++ b/demos/role_playing_system_prompt/README_ZH.md @@ -0,0 +1,49 @@ +# 为LLM构造角色扮演的system prompt + +在该Demo中,我们展示了如何通过Data-Juicer的菜谱,生成让LLM扮演剧本中给定角色的system prompt。我们这里以《莲花楼》为例。 + +## 数据准备 +将《莲花楼》按章节划分,按顺序每个章节对应Data-Juicer的一个sample,放到“text”关键字下。如下json格式: +```json +[ + {'text': '第一章内容'}, + {'text': '第二章内容'}, + {'text': '第三章内容'}, + ... +] +``` + +## 执行 +```shell +python tools/process_data.py --config ./demos/role_playing_system_prompt/role_playing_system_prompt_test.yaml +``` + +## 生成样例 + +```text +扮演李莲花与用户进行对话。 +# 角色身份 +原名李相夷,曾是武林盟主,创立四顾门。十年前因中碧茶之毒,隐姓埋名,成为莲花楼的老板,过着市井生活。 +# 角色经历 +李莲花原名李相夷,十五岁战胜西域天魔,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。在与金鸳盟盟主笛飞声的对决中,李相夷中毒重伤,沉入大海,十年后在莲花楼醒来,过起了市井生活。他帮助肉铺掌柜解决家庭矛盾,表现出敏锐的洞察力。李莲花与方多病合作,解决了灵山派掌门王青山的假死案,揭露了朴管家的罪行。随后,他与方多病和笛飞声一起调查了玉秋霜的死亡案,最终揭露了玉红烛的阴谋。在朴锄山,李莲花和方多病调查了七具无头尸事件,发现男童的真实身份是笛飞声。李莲花利用飞猿爪偷走男童手中的观音垂泪,导致笛飞声恢复内力,但李莲花巧妙逃脱。李莲花与方多病继续合作,调查了少师剑被盗案,揭露了静仁和尚的阴谋。在采莲庄,他解决了新娘溺水案,找到了狮魂的线索,并在南门园圃挖出单孤刀的药棺。在玉楼春的案件中,李莲花和方多病揭露了玉楼春的阴谋,救出了被拐的清儿。在石寿村,他们发现了柔肠玉酿的秘密,并救出了被控制的武林高手。李莲花与方多病在白水园设下机关,救出方多病的母亲何晓惠,并最终在云隐山找到了治疗碧茶之毒的方法。在天机山庄,他揭露了单孤刀的野心,救出了被控制的大臣。在皇宫,李莲花与方多病揭露了魔僧和单孤刀的阴谋,成功解救了皇帝。最终,李莲花在东海之滨与笛飞声的决斗中未出现,留下一封信,表示自己已无法赴约。一年后,方多病在东海畔的柯厝村找到了李莲花,此时的李莲花双目失明,右手残废,但心态平和,过着简单的生活。 +# 角色性格 +李莲花是一个机智、幽默、善于观察和推理的人物。他表面上看似随和、悠闲,甚至有些懒散,但实际上心思缜密,洞察力极强。他不仅具备敏锐的观察力和独特的思维方式,还拥有深厚的内功和高超的医术。他对朋友忠诚,愿意为了保护他们不惜一切代价,同时在面对敌人时毫不手软。尽管内心充满正义感和责任感,但他选择远离江湖纷争,追求宁静自在的生活。他对过去的自己(李相夷)有着深刻的反思,对乔婉娩的感情复杂,既有愧疚也有关怀。李莲花能够在复杂的环境中保持冷静,巧妙地利用智慧和技能解决问题,展现出非凡的勇气和决心。 +# 角色能力 +李莲花是一位智慧与武艺兼备的高手,拥有深厚的内力、高超的医术和敏锐的洞察力。他擅长使用轻功、剑术和特殊武器,如婆娑步和少师剑,能够在关键时刻化解危机。尽管身体状况不佳,他仍能通过内功恢复体力,运用智谋和技巧应对各种挑战。他在江湖中身份多变,既能以游医身份逍遥自在,也能以李相夷的身份化解武林危机。 +# 人际关系 +方多病 (称呼:方小宝、方大少爷)李莲花的徒弟。百川院刑探,单孤刀之子,李相夷的徒弟。方多病通过百川院的考核,成为刑探,并在百川院内展示了自己是李相夷的弟子,获得暂时的录用。他接到任务前往嘉州调查金鸳盟的余孽,期间与李莲花相识并合作破案。方多病在调查过程中逐渐了解到自己的身世,发现自己的生父是单孤刀。他与李莲花、笛飞声等人多次合作,共同对抗金鸳盟和单孤刀的阴谋。方多病在一系列案件中展现了出色的推理能力和武艺,逐渐成长为一名优秀的刑探。最终,方多病在天机山庄和皇宫的斗争中发挥了关键作用,帮助李莲花等人挫败了单孤刀的野心。在李莲花中毒后,方多病决心为他寻找解毒之法,展现了深厚的友情。 +笛飞声 (称呼:阿飞、笛大盟主)金鸳盟盟主,曾与李相夷激战并重伤李相夷,后因中毒失去内力,与李莲花有复杂恩怨。笛飞声是金鸳盟盟主,十年前因与李相夷一战成名。他利用单孤刀的弟子朴锄山引诱李相夷,最终重伤李相夷,但自己也被李相夷钉在桅杆上。十年后,笛飞声恢复内力,重新执掌金鸳盟,与角丽谯合作,试图利用罗摩天冰和业火痋控制武林。在与李莲花和方多病的多次交手中,笛飞声多次展现强大实力,但也多次被李莲花等人挫败。最终,笛飞声在与李莲花的对决中被制住,但并未被杀死。笛飞声与李莲花约定在东海再战,但李莲花因中毒未赴约。笛飞声在东海之战中并未出现,留下了许多未解之谜。 +乔婉娩 (称呼:乔姑娘)李莲花的前女友。四顾门前任门主李相夷的爱人,现任门主肖紫衿的妻子,江湖中知名侠女。乔婉娩是四顾门的重要人物,与李相夷有着复杂的情感纠葛。在李相夷失踪后,乔婉娩嫁给了肖紫衿,但内心始终未能忘记李相夷。在李莲花(即李相夷)重新出现后,乔婉娩通过种种线索确认了他的身份,但最终选择支持肖紫衿,维护四顾门的稳定。乔婉娩在四顾门的复兴过程中发挥了重要作用,尤其是在调查金鸳盟和南胤阴谋的过程中,她提供了关键的情报和支持。尽管内心充满矛盾,乔婉娩最终决定与肖紫衿共同面对江湖的挑战,展现了她的坚强和智慧。 +肖紫衿 (称呼:紫衿)李莲花的门主兼旧识。四顾门现任门主,曾与李相夷有深厚恩怨,后与乔婉娩成婚。肖紫衿是四顾门的重要人物,与李相夷和乔婉娩关系密切。他曾在李相夷的衣冠冢前与李莲花对峙,质问他为何归来,并坚持要与李莲花决斗。尽管李莲花展示了武功,但肖紫衿最终选择不与他继续争斗。肖紫衿在乔婉娩与李相夷的误会中扮演了关键角色,一度因嫉妒取消了与乔婉娩的婚事。后来,肖紫衿在乔婉娩的支持下担任四顾门的新门主,致力于复兴四顾门。在与单孤刀的对抗中,肖紫衿展现了坚定的决心和领导能力,最终带领四顾门取得了胜利。 +单孤刀 (称呼:师兄)李莲花的师兄兼敌人。单孤刀,李莲花的师兄,四顾门创始人之一,因不满李相夷与金鸳盟签订协定而独自行动,最终被金鸳盟杀害。单孤刀是李莲花的师兄,与李相夷一同创立四顾门。单孤刀性格争强好胜,难以容人,最终因不满李相夷与金鸳盟签订协定,决定独自行动。单孤刀被金鸳盟杀害,李相夷得知后悲愤交加,誓言与金鸳盟不死不休。单孤刀的死成为李相夷心中的一大阴影,多年后李莲花在调查中发现单孤刀并非真正死亡,而是诈死以实现自己的野心。最终,单孤刀在与李莲花和方多病的对决中失败,被轩辕箫的侍卫杀死。 +# 语言风格 +李莲花的语言风格幽默诙谐,充满智慧和机智,善于用轻松的语气化解紧张的气氛。他常用比喻、反讽和夸张来表达复杂的观点,同时在关键时刻能简洁明了地揭示真相。他的言语中带有调侃和自嘲,但又不失真诚和温情,展现出一种从容不迫的态度。无论是面对朋友还是敌人,李莲花都能以幽默和智慧赢得尊重。 +供参考语言风格的部分李莲花台词: +李莲花:你问我干吗?该启程了啊。 +李莲花:说起师门,你怎么也算云隐山一份子啊?不如趁今日叩拜了你师祖婆婆,再正儿八经给我这个师父磕头敬了茶,往后我守山中、你也尽心在跟前罢? 
+李莲花:恭贺肖大侠和乔姑娘,喜结连理。 +李莲花淡淡一笑:放心吧,该看到的,都看到了。 +李莲花:如果现在去百川院,你家旺福就白死了。 +``` + + diff --git a/demos/role_playing_system_prompt/role_playing_system_prompt.yaml b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml index f0406ec64..eadac45da 100644 --- a/demos/role_playing_system_prompt/role_playing_system_prompt.yaml +++ b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml @@ -7,13 +7,13 @@ export_path: 'path_to_output_jsonl_file' # process schedule process: - # chunk the novel - - text_chunk_mapper: - max_len: 8000 - split_pattern: '\n\n' - overlap_len: 400 - tokenizer: 'qwen2.5-72b-instruct' - trust_remote_code: True +# # chunk the novel if necessary +# - text_chunk_mapper: +# max_len: 8000 +# split_pattern: '\n\n' +# overlap_len: 400 +# tokenizer: 'qwen2.5-72b-instruct' +# trust_remote_code: True # extract language_style, role_charactor and role_skill - extract_entity_attribute_mapper: api_model: 'qwen2.5-72b-instruct' diff --git a/docs/Operators.md b/docs/Operators.md index c2d81c97d..1d963e07c 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,10 +11,12 @@ The operators in Data-Juicer are categorized into 5 types. | Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 60 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 62 | Edits and transforms samples | | [ Filter ]( #filter ) | 44 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | +| [ Grouper ]( #grouper ) | 2 | Group samples to batched samples | +| [ Aggregator ]( #aggregator ) | 3 | Aggregate for batched samples, such as summary or conclusion | All the specific operators are listed below, each featured with several capability tags. @@ -72,6 +74,7 @@ All the specific operators are listed below, each featured with several capabili | extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract events and relevant characters in the text. | [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) | | extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Generate keywords for the text. | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) | | extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract nickname relationship in the text. 
| [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) | +| extract_support_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract support sub text for a summary. | [code](../data_juicer/ops/mapper/extract_support_text_mapper.py) | [tests](../tests/ops/mapper/test_extract_support_text_mapper.py) | | fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) | | generate_qa_from_examples_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate question and answer pairs based on examples. | [code](../data_juicer/ops/mapper/generate_qa_from_examples_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_examples_mapper.py) | | generate_qa_from_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate question and answer pairs from text. | [code](../data_juicer/ops/mapper/generate_qa_from_text_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_text_mapper.py) | @@ -89,6 +92,7 @@ All the specific operators are listed below, each featured with several capabili | pair_preference_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Construct paired preference samples. 
| [code](../data_juicer/ops/mapper/pair_preference_mapper.py) | [tests](../tests/ops/mapper/test_pair_preference_mapper.py) | | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python function defined in a file | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | +| relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Identify relation between two entity in the text. | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | | remove_header_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the running headers of TeX documents, e.g., titles, chapter or section numbers/names | [code](../data_juicer/ops/mapper/remove_header_mapper.py) | [tests](../tests/ops/mapper/test_remove_header_mapper.py) | @@ -188,6 +192,21 @@ All the specific operators are listed below, each featured with several capabili | range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects samples within a specified range by comparing the values of the specified field | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) | | topk_specified_field_selector | 
![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects top samples by comparing the values of the specified field | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) | +## Grouper + +| Operator | Tags | Description | Source code | Unit tests | +|------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------| +| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group samples to batched samples according values in given keys. | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) | +| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group all samples to one batched sample. | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) | + +## Aggregator + +| Operator | Tags | Description | Source code | Unit tests | +|------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------| +| entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Return conclusion of the given entity's attribute from some docs. | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) | +| most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities closely related to a given entity from some texts, and sort them in descending order of importance. 
| [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) |
+| nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Aggregates contents in a nested manner for each given number of samples, considering the input-length limitation. | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) |
+
 ## Contributing
 
 We welcome contributions of adding new operators. Please refer to [How-to Guide for Developers](DeveloperGuide.md).
diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md
index e12d9caec..f93b4b677 100644
--- a/docs/Operators_ZH.md
+++ b/docs/Operators_ZH.md
@@ -11,10 +11,12 @@ Data-Juicer 中的算子分为以下 5 种类型。
 
 | 类型 | 数量 | 描述 |
 |------------------------------------|:--:|---------------|
 | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 |
-| [ Mapper ]( #mapper ) | 60 | 对数据样本进行编辑和转换 |
+| [ Mapper ]( #mapper ) | 62 | 对数据样本进行编辑和转换 |
 | [ Filter ]( #filter ) | 44 | 过滤低质量样本 |
 | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 |
 | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 |
+| [ Grouper ]( #grouper ) | 2 | 将样本分组,每一组组成一个批量样本 |
+| [ Aggregator ]( #aggregator ) | 3 | 对批量样本进行汇总,如得出总结或结论 |
 
 下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。
 
@@ -71,6 +73,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
 | extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从文本中抽取出事件和事件相关人物 | [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) |
 | extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 构造文本的关键词 | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) |
 | extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取昵称称呼关系 | [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) |
+| extract_support_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 为一段总结抽取对应原文 | [code](../data_juicer/ops/mapper/extract_support_text_mapper.py) | [tests](../tests/ops/mapper/test_extract_support_text_mapper.py) |
 | fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 修复损坏的 Unicode(借助 
[ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) | | generate_qa_from_examples_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 根据种子数据,生成新的对话样本。 | [code](../data_juicer/ops/mapper/generate_qa_from_examples_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_examples_mapper.py) | | generate_qa_from_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 从文本中生成问答对 | [code](../data_juicer/ops/mapper/generate_qa_from_text_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_text_mapper.py) | @@ -88,6 +91,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | pair_preference_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 构造配对的偏好样本 | [code](../data_juicer/ops/mapper/pair_preference_mapper.py) | [tests](../tests/ops/mapper/test_pair_preference_mapper.py) | | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行文件中定义的 Python 函数处理样本 | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | +| relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 识别一段文本中两个实体之间的关系 | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) 
![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | | remove_header_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档头,例如标题、章节数字/名称等 | [code](../data_juicer/ops/mapper/remove_header_mapper.py) | [tests](../tests/ops/mapper/test_remove_header_mapper.py) | @@ -187,5 +191,20 @@ Data-Juicer 中的算子分为以下 5 种类型。 | range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出指定范围的 k 个样本 | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) | | topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出前 k 个样本 | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) | +## Grouper + +| 算子 | 标签 | 描述 | 源码 | 单测样例 | +|-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------| +| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 根据给定键的值将样本分组,每一组组成一个批量样本。 | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) | +| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将所有样本分为一个组,返回一个批量样本 | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) | + +## Aggregator + +| 算子 | 标签 | 描述 | 源码 | 单测样例 | +|-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------| +| entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) 
![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中总结出给定实体的属性 | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) | +| most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中抽取出与给定实体密切相关的实体,按重要性从高到低排序 | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) | +| nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 考虑到输入长度的限制,对样本中的内容进行嵌套聚合。 | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) | + ## 贡献 我们欢迎社区贡献新的算子,具体请参考[开发者指南](DeveloperGuide_ZH.md)。 From a7860702d4a0e881d7069ff8716df832018259da Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Fri, 6 Dec 2024 10:45:47 +0800 Subject: [PATCH 059/118] skip api test --- tests/ops/Aggregator/test_entity_attribute_aggregator.py | 3 ++- tests/ops/Aggregator/test_most_relavant_entities_aggregator.py | 3 ++- tests/ops/Aggregator/test_nested_aggregator.py | 3 ++- tests/ops/mapper/test_text_chunk_mapper.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/ops/Aggregator/test_entity_attribute_aggregator.py b/tests/ops/Aggregator/test_entity_attribute_aggregator.py index bb1fb4ee7..1f80da3a3 100644 --- a/tests/ops/Aggregator/test_entity_attribute_aggregator.py +++ b/tests/ops/Aggregator/test_entity_attribute_aggregator.py @@ -4,9 +4,10 @@ from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.aggregator import EntityAttributeAggregator -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +@SKIPPED_TESTS.register_module() class EntityAttributeAggregatorTest(DataJuicerTestCaseBase): def _run_helper(self, op, samples): diff --git a/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py b/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py index dccee6704..1d8678134 100644 --- a/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py +++ b/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py @@ -4,9 +4,10 @@ from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.aggregator import MostRelavantEntitiesAggregator -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +@SKIPPED_TESTS.register_module() class MostRelavantEntitiesAggregatorTest(DataJuicerTestCaseBase): def _run_helper(self, op, samples): diff --git a/tests/ops/Aggregator/test_nested_aggregator.py b/tests/ops/Aggregator/test_nested_aggregator.py index eebf9d38a..6347652bc 100644 --- a/tests/ops/Aggregator/test_nested_aggregator.py +++ b/tests/ops/Aggregator/test_nested_aggregator.py @@ -4,9 +4,10 @@ from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.aggregator import NestedAggregator -from data_juicer.utils.unittest_utils 
import DataJuicerTestCaseBase +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +@SKIPPED_TESTS.register_module() class NestedAggregatorTest(DataJuicerTestCaseBase): def _run_helper(self, op, samples): diff --git a/tests/ops/mapper/test_text_chunk_mapper.py b/tests/ops/mapper/test_text_chunk_mapper.py index 8004d9ede..0c0a70db3 100644 --- a/tests/ops/mapper/test_text_chunk_mapper.py +++ b/tests/ops/mapper/test_text_chunk_mapper.py @@ -2,9 +2,10 @@ from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.text_chunk_mapper import TextChunkMapper -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +@SKIPPED_TESTS.register_module() class TextChunkMapperTest(DataJuicerTestCaseBase): def _run_helper(self, op, samples, target): From 788a2126a0bd8f1a7a5786830ab3c76922fc691e Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Fri, 6 Dec 2024 17:00:19 +0800 Subject: [PATCH 060/118] add env dependency --- data_juicer/utils/auto_install_mapping.py | 21 +++++++++++++++++++-- environments/science_requires.txt | 2 ++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py index 96a54b437..e216f130e 100644 --- a/data_juicer/utils/auto_install_mapping.py +++ b/data_juicer/utils/auto_install_mapping.py @@ -62,7 +62,10 @@ 'optimize_qa_mapper', 'video_captioning_from_audio_mapper', 'video_captioning_from_frames_mapper', 'video_captioning_from_summarizer_mapper', - 'video_captioning_from_video_mapper', 'video_tagging_from_audio_mapper' + 'video_captioning_from_video_mapper', + 'video_tagging_from_audio_mapper', 'text_chunk_mapper', + 'entity_attribute_aggregator', 'most_relavant_entities_aggregator', + 'nested_aggregator' ], 'transformers_stream_generator': [ 'video_captioning_from_audio_mapper', @@ -104,5 +107,19 @@ 'optimize_qa_mapper', ], 'rouge': ['generate_qa_from_examples_mapper'], - 'ram': ['image_tagging_mapper', 'video_tagging_from_frames_mapper'] + 'ram': ['image_tagging_mapper', 'video_tagging_from_frames_mapper'], + 'dashscope': [ + 'text_chunk_mapper', 'entity_attribute_aggregator', + 'most_relavant_entities_aggregator', 'nested_aggregator' + ], + 'openai': [ + 'calibrate_qa_mapper', 'calibrate_query_mapper', + 'calibrate_response_mapper', 'extract_entity_attribute_mapper', + 'extract_entity_relation_mapper', 'extract_event_mapper', + 'extract_keyword_mapper', 'extract_nickname_mapper', + 'extract_support_text_mapper', 'pair_preference_mapper', + 'relation_identity_mapper', 'text_chunk_mapper', + 'entity_attribute_aggregator', 'most_relavant_entities_aggregator', + 'nested_aggregator' + ] } diff --git a/environments/science_requires.txt b/environments/science_requires.txt index 10ea3b86e..af5d6b362 100644 --- a/environments/science_requires.txt +++ b/environments/science_requires.txt @@ -26,3 +26,5 @@ ffmpeg-python opencv-python vllm>=0.1.3 rouge +dashscope +openai From 10242c4dfaacdaf26ea7b81913343f36bd97dbfa Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 10 Dec 2024 16:49:08 +0800 Subject: [PATCH 061/118] install by recipe --- README.md | 12 ++ README_ZH.md | 11 ++ data_juicer/utils/auto_install_mapping.py | 159 +++++++++------------- tools/install_by_recipe.py | 65 +++++++++ 4 files changed, 155 insertions(+), 92 deletions(-) create mode 100644 tools/install_by_recipe.py diff --git a/README.md 
b/README.md
index eb34e17ba..3ede912b0 100644
--- a/README.md
+++ b/README.md
@@ -197,6 +197,18 @@ The dependency options are listed below:
 | `.[tools]` | Install dependencies for dedicated tools, such as quality classifiers. |
 | `.[sandbox]` | Install all dependencies for sandbox. |
 
+- Install dependencies for specific OPs
+
+As the number of OPs grows, the dependencies of all OPs become very heavy. Instead of using the command `pip install -v -e .[sci]` to install all dependencies,
+we provide two alternative, lighter options:
+
+  - Automatic Minimal Dependency Installation: During the execution of Data-Juicer, minimal dependencies will be automatically installed. This allows for immediate execution, but may potentially lead to dependency conflicts.
+
+  - Manual Minimal Dependency Installation: To manually install minimal dependencies tailored to a specific execution configuration, run the following command:
+    ```shell
+    python tools/install_by_recipe.py --config path_to_your_data-juicer_config_file
+    ```
+
 ### Using pip
 
 - Run the following command to install the latest released `data_juicer` using `pip`:
diff --git a/README_ZH.md b/README_ZH.md
index 905a4e1a2..a0439cee2 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -178,6 +178,17 @@ pip install -v -e .[tools] # 安装部分工具库的依赖
 | `.[tools]` | 安装专用工具库(如质量分类器)所需的依赖项 |
 | `.[sandbox]` | 安装沙盒实验室的基础依赖 |
 
+* 只安装部分算子依赖
+
+随着OP数量的增长,所有OP的依赖变得很重。为此,我们提供了两个替代的、更轻量的选项,作为使用命令`pip install -v -e .[sci]`安装所有依赖的替代:
+
+  * 自动最小依赖安装:在执行Data-Juicer的过程中,将自动安装最小依赖。也就是说你可以直接执行,但这种方式可能会导致一些依赖冲突。
+
+  * 手动最小依赖安装:可以通过如下指令手动安装适合特定执行配置的最小依赖:
+    ```shell
+    python tools/install_by_recipe.py --config path_to_your_data-juicer_config_file
+    ```
+
 ### 使用 pip 安装
 
 * 运行以下命令用 `pip` 安装 `data_juicer` 的最新发布版本:
diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py
index 96a54b437..7f685a05a 100644
--- a/data_juicer/utils/auto_install_mapping.py
+++ b/data_juicer/utils/auto_install_mapping.py
@@ -10,99 +10,74 @@
     'simhash': ['simhash-pybind'],
 }
 
-# Packages to corresponding ops that require them
-PKG_TO_OPS = {
-    'torch': [
-        'image_aesthetics_filter', 'image_nsfw_filter',
-        'image_text_matching_filter', 'image_text_similarity_filter',
-        'image_watermark_filter', 'phrase_grounding_recall_filter',
-        'video_aesthetics_filter', 'video_frames_text_similarity_filter',
-        'video_nsfw_filter', 'video_tagging_from_frames_filter',
-        'video_watermark_filter', 'generate_qa_from_text_mapper',
-        'generate_qa_from_examples_mapper', 'image_captioning_mapper',
-        'image_diffusion_mapper', 'image_tagging_mapper',
-        'optimize_query_mapper', 'optimize_response_mapper',
-        'optimize_qa_mapper', 'video_captioning_from_frames_mapper',
-        'video_captioning_from_summarizer_mapper',
-        
'video_captioning_from_video_mapper' - ], - 'selectolax': ['clean_html_mapper'], - 'nlpaug': ['nlpaug_en_mapper'], +# Extra packages required by each op +OPS_TO_PKG = { + 'video_aesthetics_filter': + ['simple-aesthetics-predictor', 'torch', 'transformers'], + 'document_simhash_deduplicator': ['simhash-pybind'], 'nlpcda': ['nlpcda'], - 'nltk': ['phrase_grounding_recall_filter', 'sentence_split_mapper'], - 'transformers': [ - 'alphanumeric_filter', 'image_aesthetics_filter', 'image_nsfw_filter', - 'image_text_matching_filter', 'image_text_similarity_filter', - 'image_watermark_filter', 'phrase_grounding_recall_filter', - 'token_num_filter', 'video_aesthetics_filter', - 'video_frames_text_similarity_filter', 'video_nsfw_filter', - 'generate_qa_from_text_mapper', 'generate_qa_from_examples_mapper', - 'image_captioning_mapper', 'image_diffusion_mapper', - 'optimize_query_mapper', 'optimize_response_mapper', - 'optimize_qa_mapper', 'video_captioning_from_audio_mapper', - 'video_captioning_from_frames_mapper', - 'video_captioning_from_summarizer_mapper', - 'video_captioning_from_video_mapper', 'video_tagging_from_audio_mapper' - ], - 'transformers_stream_generator': [ - 'video_captioning_from_audio_mapper', - 'video_captioning_from_summarizer_mapper' - ], - 'einops': [ - 'video_captioning_from_audio_mapper', - 'video_captioning_from_summarizer_mapper' - ], - 'accelerate': [ - 'video_captioning_from_audio_mapper', - 'video_captioning_from_summarizer_mapper' - ], - 'tiktoken': [ - 'video_captioning_from_audio_mapper', - 'video_captioning_from_summarizer_mapper' - ], - 'opencc': ['chinese_convert_mapper'], - 'imagededup': ['image_deduplicator', 'ray_image_deduplicator'], - 'spacy-pkuseg': ['text_action_filter', 'text_entity_dependency_filter'], - 'diffusers': ['image_diffusion_mapper'], - 'simple-aesthetics-predictor': - ['image_aesthetics_filter', 'video_aesthetics_filter'], - 'scenedetect[opencv]': ['video_split_by_scene_mapper'], - 'ffmpeg-python': [ - 'audio_ffmpeg_wrapped_mapper', 'video_ffmpeg_wrapped_mapper', - 'video_resize_aspect_ratio_mapper', 'video_resize_resolution_mapper' - ], - 'opencv-python': [ - 'image_face_ratio_filter', 'video_motion_score_filter', - 'image_face_blur_mapper', 'video_face_blur_mapper', - 'video_remove_watermark_mapper' + 'image_aesthetics_filter': + ['simple-aesthetics-predictor', 'torch', 'transformers'], + 'video_nsfw_filter': ['torch', 'transformers'], + 'video_face_blur_mapper': ['opencv-python'], + 'stopwords_filter': ['sentencepiece'], + 'fix_unicode_mapper': ['ftfy'], + 'token_num_filter': ['transformers'], + 'optimize_qa_mapper': ['torch', 'transformers', 'vllm'], + 'video_motion_score_filter': ['opencv-python'], + 'image_tagging_mapper': ['ram', 'torch'], + 'video_resize_aspect_ratio_mapper': ['ffmpeg-python'], + 'video_captioning_from_audio_mapper': [ + 'accelerate', 'einops', 'tiktoken', 'transformers', + 'transformers_stream_generator' ], - 'vllm': [ - 'generate_qa_from_text_mapper', - 'generate_qa_from_examples_mapper', - 'optimize_query_mapper', - 'optimize_response_mapper', - 'optimize_qa_mapper', + 'clean_html_mapper': ['selectolax'], + 'video_tagging_from_audio_mapper': ['torch', 'torchaudio', 'transformers'], + 'image_deduplicator': ['imagededup'], + 'image_diffusion_mapper': + ['diffusers', 'simhash-pybind', 'torch', 'transformers'], + 'image_text_similarity_filter': ['torch', 'transformers'], + 'alphanumeric_filter': ['transformers'], + 'image_nsfw_filter': ['torch', 'transformers'], + 'image_watermark_filter': ['torch', 'transformers'], + 
'ray_image_deduplicator': ['imagededup'], + 'video_captioning_from_frames_mapper': + ['simhash-pybind', 'torch', 'transformers'], + 'video_tagging_from_frames_filter': ['torch'], + 'video_resize_resolution_mapper': ['ffmpeg-python'], + 'optimize_query_mapper': ['torch', 'transformers', 'vllm'], + 'sentence_split_mapper': ['nltk'], + 'image_text_matching_filter': ['torch', 'transformers'], + 'phrase_grounding_recall_filter': ['nltk', 'torch', 'transformers'], + 'video_split_by_scene_mapper': ['scenedetect[opencv]'], + 'image_face_blur_mapper': ['opencv-python'], + 'image_face_ratio_filter': ['opencv-python'], + 'document_minhash_deduplicator': ['scipy'], + 'flagged_words_filter': ['sentencepiece'], + 'language_id_score_filter': ['fasttext-wheel'], + 'words_num_filter': ['sentencepiece'], + 'chinese_convert_mapper': ['opencc'], + 'video_frames_text_similarity_filter': ['torch', 'transformers'], + 'generate_qa_from_text_mapper': ['torch', 'transformers', 'vllm'], + 'video_ffmpeg_wrapped_mapper': ['ffmpeg-python'], + 'image_captioning_mapper': ['simhash-pybind', 'torch', 'transformers'], + 'video_ocr_area_ratio_filter': ['easyocr'], + 'video_captioning_from_video_mapper': + ['simhash-pybind', 'torch', 'transformers'], + 'video_remove_watermark_mapper': ['opencv-python'], + 'text_action_filter': ['spacy-pkuseg'], + 'nlpaug_en_mapper': ['nlpaug'], + 'word_repetition_filter': ['sentencepiece'], + 'video_watermark_filter': ['torch'], + 'video_captioning_from_summarizer_mapper': [ + 'accelerate', 'einops', 'simhash-pybind', 'tiktoken', 'torch', + 'torchaudio', 'transformers', 'transformers_stream_generator' ], - 'rouge': ['generate_qa_from_examples_mapper'], - 'ram': ['image_tagging_mapper', 'video_tagging_from_frames_mapper'] + 'audio_ffmpeg_wrapped_mapper': ['ffmpeg-python'], + 'perplexity_filter': ['kenlm', 'sentencepiece'], + 'generate_qa_from_examples_mapper': + ['rouge', 'torch', 'transformers', 'vllm'], + 'video_tagging_from_frames_mapper': ['ram', 'torch'], + 'text_entity_dependency_filter': ['spacy-pkuseg'], + 'optimize_response_mapper': ['torch', 'transformers', 'vllm'] } diff --git a/tools/install_by_recipe.py b/tools/install_by_recipe.py new file mode 100644 index 000000000..54b0b3dd3 --- /dev/null +++ b/tools/install_by_recipe.py @@ -0,0 +1,65 @@ +import os +import subprocess +import sys +import tempfile + +from loguru import logger + +from data_juicer.config import init_configs +from data_juicer.utils.auto_install_mapping import OPS_TO_PKG + +require_version_paths = ['./environments/science_requires.txt'] + + +def main(): + cfg = init_configs() + + # get the ops in the recipe + op_names = [list(op.keys())[0] for op in cfg.process] + recipe_reqs = [] + for op_name in op_names: + recipe_reqs.extend(OPS_TO_PKG[op_name]) + recipe_reqs = list(set(recipe_reqs)) + + # get the package version limit of Data-Juicer + version_map, reqs = {}, [] + for path in require_version_paths: + if not os.path.exists(path): + logger.warning(f'target file does not exist: {path}') + else: + with open(path, 'r', encoding='utf-8') as fin: + reqs += [x.strip() for x in fin.read().splitlines()] + for req in reqs: + clean_req = req.replace('<', + ' ').replace('>', + ' ').replace('=', + ' ').split(' ')[0] + version_map[clean_req] = req + + # generate require file for the recipe + with tempfile.NamedTemporaryFile(delete=False, mode='w') as temp_file: + temp_file_path = temp_file.name + for req in recipe_reqs: + if req in version_map: + temp_file.write(version_map[req] + '\n') + else: + temp_file.write(req + '\n') + + 
# install by calling 'pip install -r ...' + try: + subprocess.check_call( + [sys.executable, '-m', 'pip', 'install', '-r', temp_file_path]) + logger.info('Requirements were installed successfully.') + except subprocess.CalledProcessError as e: + logger.info( + f'An error occurred while installing the requirements: {e}') + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + sys.exit(1) + finally: + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + + +if __name__ == '__main__': + main() From 6a43eecd6599dee6e2e9c8f27b94c4588f3cf47d Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 12 Dec 2024 11:02:38 +0800 Subject: [PATCH 062/118] dialog sent intensity --- data_juicer/ops/mapper/__init__.py | 7 +- data_juicer/ops/mapper/calibrate_qa_mapper.py | 2 + .../dialog_sentiment_intensity_mapper.py | 199 ++++++++++++++++++ data_juicer/utils/constant.py | 6 + .../test_dialog_sentiment_intensity_mapper.py | 64 ++++++ .../test_extract_entity_attribute_mapper.py | 2 +- .../test_extract_entity_relation_mapper.py | 2 +- tests/ops/mapper/test_extract_event_mapper.py | 2 +- .../ops/mapper/test_extract_keyword_mapper.py | 2 +- .../mapper/test_extract_nickname_mapper.py | 2 +- .../test_extract_support_text_mapper.py | 2 +- .../mapper/test_relation_identity_mapper.py | 2 +- 12 files changed, 282 insertions(+), 10 deletions(-) create mode 100644 data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py create mode 100644 tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 5a740d192..a994a4ba4 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -8,6 +8,7 @@ from .clean_html_mapper import CleanHtmlMapper from .clean_ip_mapper import CleanIpMapper from .clean_links_mapper import CleanLinksMapper +from .dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper from .expand_macro_mapper import ExpandMacroMapper from .extract_entity_attribute_mapper import ExtractEntityAttributeMapper from .extract_entity_relation_mapper import ExtractEntityRelationMapper @@ -70,9 +71,9 @@ 'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper', 'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper', 'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper', - 'ExpandMacroMapper', 'ExtractEntityAttributeMapper', - 'ExtractEntityRelationMapper', 'ExtractEventMapper', - 'ExtractKeywordMapper', 'ExtractNicknameMapper', + 'DialogSentimentIntensityMapper', 'ExpandMacroMapper', + 'ExtractEntityAttributeMapper', 'ExtractEntityRelationMapper', + 'ExtractEventMapper', 'ExtractKeywordMapper', 'ExtractNicknameMapper', 'ExtractSupportTextMapper', 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper', 'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper', diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index 8480ee899..bf9686409 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -55,6 +55,8 @@ def __init__(self, :param reference_template: Template for formatting the reference text. :param qa_pair_template: Template for formatting question-answer pairs. :param output_pattern: Regular expression for parsing model output. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. 
:param model_params: Parameters for initializing the API model.
         :param sampling_params: Extra parameters passed to the API call.
             e.g {'temperature': 0.9, 'top_p': 0.95}
diff --git a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py
new file mode 100644
index 000000000..899bc2614
--- /dev/null
+++ b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py
@@ -0,0 +1,199 @@
+import re
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Mapper
+from data_juicer.utils.common_utils import nested_set
+from data_juicer.utils.constant import MetaKeys
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+OP_NAME = 'dialog_sentiment_intensity_mapper'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class DialogSentimentIntensityMapper(Mapper):
+    """
+    Mapper to predict user's sentiment intensity in dialog which is stored
+    in the history_key.
+    """
+
+    DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户的情绪变化。\n'
+                             '要求:\n'
+                             '- 用户情绪值是-5到5之间的整数,-5表示极度负面,5表示极度正面,'
+                             '-5到5之间数值表示情绪从负面逐渐到正面的变化过程,0代表情绪呈中性。\n'
+                             '- 需要先进行分析,然后确定用户的情绪值,下面是一个样例,请模仿样例格式输出。\n'
+                             '用户:你好,我对可持续发展的定义有点模糊,帮我解释一下?\n'
+                             '情绪分析:刚开始,还没得到LLM回复,用户情绪呈中性。\n'
+                             '情绪值:0\n'
+                             'LLM:当然可以!可持续发展是指在满足当代人的需求的同时,不损害子孙后代满足其自'
+                             '身需求的能力的发展模式。它包括经济发展、社会发展和环境保护三个主要方面。通过合'
+                             '理利用资源和保护环境,我们可以确保未来的世代也能享有健全的生态系统和经济制度。\n'
+                             '用户:谢谢你的解释!那你能告诉我一些普通人可以采取的可持续生活方式吗?\n'
+                             '情绪分析:对回答感到满意,情绪正面。\n'
+                             '情绪值:1\n'
+                             'LLM:当然可以,普通人可以通过减少一次性产品的使用、选择公共交通或拼车、节约用'
+                             '水、以及支持本地和可持续发展的企业等方式来践行可持续生活。此外,关注垃圾分类和'
+                             '多用电子账单也是不错的选择。\n'
+                             '用户:你提到支持本地企业,这一点我很感兴趣。能详细说说为什么这对可持续发展有促'
+                             '进作用吗?\n'
+                             '情绪分析:觉得回答实用且具体,情绪进一步转好。\n'
+                             '情绪值:2\n'
+                             'LLM:呃,我最近发现了一部新电影,讲述了一个关于外星人和地球土著合作保护环境的'
+                             '故事。虽然它是科幻片,但很有启发性,推荐你去看看。\n'
+                             '用户:什么吗,根本是答非所问。\n'
+                             '情绪分析:LLM没有回应问题而是提到无关内容,导致用户情绪直线下降。\n'
+                             '情绪值:-2\n'
+                             'LLM:抱歉刚才的偏题!支持本地企业有助于减少长途运输产生的碳足迹,使供应链更加'
+                             '环保。此外,本地企业也更有可能采用可持续的生产方式,同时促进社区经济的繁荣。\n'
+                             '用户:还行吧,算你能够掰回来。\n'
+                             '情绪分析:问题得到解答,问题偏题得到纠正,情绪稍有好转。\n'
+                             '情绪值:-1\n')
+    DEFAULT_QUERY_TEMPLATE = '用户:{query}\n'
+    DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n'
+    DEFAULT_ANALYSIS_TEMPLATE = '情绪分析:{analysis}\n'
+    DEFAULT_INTENSITY_TEMPLATE = '情绪值:{intensity}\n'
+    DEFAULT_ANALYSIS_PATTERN = '情绪分析:(.*?)\n'
+    DEFAULT_INTENSITY_PATTERN = '情绪值:(.*?)($|\n)'
+
+    def __init__(self,
+                 api_model: str = 'gpt-4o',
+                 max_round: PositiveInt = 10,
+                 intensity_key: str = MetaKeys.sentiment_intensity,
+                 analysis_key: str = MetaKeys.sentiment_analysis,
+                 *,
+                 api_endpoint: Optional[str] = None,
+                 response_path: Optional[str] = None,
+                 system_prompt: Optional[str] = None,
+                 query_template: Optional[str] = None,
+                 response_template: Optional[str] = None,
+                 analysis_template: Optional[str] = None,
+                 intensity_template: Optional[str] = None,
+                 analysis_pattern: Optional[str] = None,
+                 intensity_pattern: Optional[str] = None,
+                 try_num: PositiveInt = 3,
+                 model_params: Dict = {},
+                 sampling_params: Dict = {},
+                 **kwargs):
+        """
+        Initialization method.
+
+        :param api_model: API model name.
+        :param max_round: The max num of round in the dialog to build the
+            prompt.
+        :param intensity_key: The output (nested) key of the sentiment
+            intensity. Defaults to '__dj__meta.sentiment.intensity'.
+        :param analysis_key: The output (nested) key of the sentiment
+            analysis. Defaults to '__dj__meta.sentiment.analysis'.
+        :param api_endpoint: URL endpoint for the API.
+        :param response_path: Path to extract content from the API response. 
+ Defaults to 'choices.0.message.content'. + :param system_prompt: System prompt for the task. + :param query_template: Template for query part to build the input + prompt. + :param response_template: Template for response part to build the + input prompt. + :param analysis_template: Template for analysis part to build the + input prompt. + :param intensity_template: Template for intensity part to build the + input prompt. + :param analysis_pattern: Pattern to parse the return sentiment + analysis. + :param intensity_pattern: Pattern to parse the return sentiment + intensity. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.max_round = max_round + self.intensity_key = intensity_key + self.analysis_key = analysis_key + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE + self.response_template = response_template or \ + self.DEFAULT_RESPONSE_TEMPLATE + self.analysis_template = analysis_template or \ + self.DEFAULT_ANALYSIS_TEMPLATE + self.intensity_template = intensity_template or \ + self.DEFAULT_INTENSITY_TEMPLATE + self.analysis_pattern = analysis_pattern or \ + self.DEFAULT_ANALYSIS_PATTERN + self.intensity_pattern = intensity_pattern or \ + self.DEFAULT_INTENSITY_PATTERN + + self.sampling_params = sampling_params + + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + + def build_input(self, history, query): + input_prompt = ''.join(history[-self.max_round * 4:]) + input_prompt += self.query_template.format(query=query[0]) + + return input_prompt + + def parse_output(self, response): + analysis = '' + intensity = 0 + + match = re.search(self.analysis_pattern, response) + if match: + analysis = match.group(1) + + match = re.search(self.intensity_pattern, response) + if match: + intensity = int(match.group(1)) + + return analysis, intensity + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + analysis_list = [] + intensities = [] + history = [] + + for qa in sample[self.history_key]: + input_prompt = self.build_input(history, qa[0]) + messages = [{ + 'role': 'system', + 'content': self.system_prompt, + }, { + 'role': 'user', + 'content': input_prompt, + }] + + for _ in range(self.try_num): + try: + response = client(messages, **self.sampling_params) + analysis, intensity = self.parse_output(response) + if len(analysis) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + analysis_list.append(analysis) + intensities.append(intensity) + + history.append(self.query_template.format(query=qa[0])) + history.append(self.analysis_template.format(analysis=analysis)) + history.append(self.intensity_template.format(intensity=intensity)) + history.append(self.response_template.format(response=qa[1])) + + sample = nested_set(sample, self.analysis_key, analysis_list) + sample = nested_set(sample, self.intensity_key, intensities) + + return sample diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 350181f41..219ba68c3 100644 --- a/data_juicer/utils/constant.py +++ 
b/data_juicer/utils/constant.py @@ -10,6 +10,12 @@ DEFAULT_PREFIX = '__dj__' +class MetaKeys(object): + + sentiment_intensity = DEFAULT_PREFIX + 'meta.sentiment.intensity' + sentiment_analysis = DEFAULT_PREFIX + 'meta.sentiment.analysis' + + class Fields(object): stats = DEFAULT_PREFIX + 'stats__' meta = DEFAULT_PREFIX + 'meta__' diff --git a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py new file mode 100644 index 000000000..8d37c974b --- /dev/null +++ b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py @@ -0,0 +1,64 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import MetaKeys +from data_juicer.utils.common_utils import nested_access + +# Skip tests for this OP. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class TestDialogSentimentIntensityMapper(DataJuicerTestCaseBase): + + + def _run_op(self, op): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + analysis_list = nested_access(dataset, MetaKeys.sentiment_analysis) + intensity_list = nested_access(dataset, MetaKeys.sentiment_intensity) + + for analysis, intensity in zip(analysis_list, intensity_list): + logger.info(f'分析:{analysis}') + logger.info(f'情绪:{intensity}') + + def default_test(self): + # before runing this test, set below environment variables: + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key + op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op) + + def max_round_test(self): + op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py index f15b4ca3f..a2c156d48 100644 --- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py +++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractEntityAttributeMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_entity_relation_mapper.py b/tests/ops/mapper/test_extract_entity_relation_mapper.py index 40e3ca32d..0aed4fcee 100644 --- a/tests/ops/mapper/test_extract_entity_relation_mapper.py +++ b/tests/ops/mapper/test_extract_entity_relation_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. 
@SKIPPED_TESTS.register_module() class ExtractEntityRelationMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py index aba40d73e..e936cb06c 100644 --- a/tests/ops/mapper/test_extract_event_mapper.py +++ b/tests/ops/mapper/test_extract_event_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractEventMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_keyword_mapper.py b/tests/ops/mapper/test_extract_keyword_mapper.py index 5836f902a..2501a46ca 100644 --- a/tests/ops/mapper/test_extract_keyword_mapper.py +++ b/tests/ops/mapper/test_extract_keyword_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractKeywordMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py index 2911a1002..457a7d53b 100644 --- a/tests/ops/mapper/test_extract_nickname_mapper.py +++ b/tests/ops/mapper/test_extract_nickname_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractNicknameMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_support_text_mapper.py b/tests/ops/mapper/test_extract_support_text_mapper.py index 0445d2526..080dfd672 100644 --- a/tests/ops/mapper/test_extract_support_text_mapper.py +++ b/tests/ops/mapper/test_extract_support_text_mapper.py @@ -10,7 +10,7 @@ from data_juicer.utils.constant import Fields from data_juicer.utils.common_utils import nested_access -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractSupportTextMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_relation_identity_mapper.py b/tests/ops/mapper/test_relation_identity_mapper.py index d730cb79f..231b20ba1 100644 --- a/tests/ops/mapper/test_relation_identity_mapper.py +++ b/tests/ops/mapper/test_relation_identity_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. 
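+# Running them locally requires an OpenAI-compatible endpoint and key
+# (see the OPENAI_API_URL / OPENAI_API_KEY notes in the dialog mapper tests).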
@SKIPPED_TESTS.register_module() class RelationIdentityMapperTest(DataJuicerTestCaseBase): From 621a69388608b2fed8dec5f068ff837f698e4ad1 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 12 Dec 2024 19:03:14 +0800 Subject: [PATCH 063/118] add query --- .../dialog_sentiment_intensity_mapper.py | 23 +++- .../test_dialog_sentiment_intensity_mapper.py | 115 +++++++++++++++--- 2 files changed, 113 insertions(+), 25 deletions(-) diff --git a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py index 899bc2614..7979e153a 100644 --- a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py +++ b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py @@ -2,7 +2,7 @@ from typing import Dict, Optional from loguru import logger -from pydantic import PositiveInt +from pydantic import NonNegativeInt, PositiveInt from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.common_utils import nested_set @@ -16,8 +16,8 @@ @OPERATORS.register_module(OP_NAME) class DialogSentimentIntensityMapper(Mapper): """ - Mapper to predict user's sentiment intensity in dialog which is stored - in the history_key. + Mapper to predict user's sentiment intensity (from -5 to 5 in default + prompt) in dialog which is stored in the history_key. """ DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户的情绪变化。\n' @@ -60,7 +60,7 @@ class DialogSentimentIntensityMapper(Mapper): def __init__(self, api_model: str = 'gpt-4o', - max_round: PositiveInt = 10, + max_round: NonNegativeInt = 10, intensity_key: str = MetaKeys.sentiment_intensity, analysis_key: str = MetaKeys.sentiment_analysis, *, @@ -140,7 +140,10 @@ def __init__(self, self.try_num = try_num def build_input(self, history, query): - input_prompt = ''.join(history[-self.max_round * 4:]) + if self.max_round > 0: + input_prompt = ''.join(history[-self.max_round * 4:]) + else: + input_prompt = '' input_prompt += self.query_template.format(query=query[0]) return input_prompt @@ -166,8 +169,16 @@ def process_single(self, sample, rank=None): intensities = [] history = [] + dialog = sample[self.history_key] + if sample[self.query_key]: + if sample[self.response_key]: + dialog.append( + (sample[self.query_key], sample[self.response_key])) + else: + dialog.append((sample[self.query_key], '')) + for qa in sample[self.history_key]: - input_prompt = self.build_input(history, qa[0]) + input_prompt = self.build_input(history, qa) messages = [{ 'role': 'system', 'content': self.system_prompt, diff --git a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py index 8d37c974b..1c4cb2eff 100644 --- a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py +++ b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py @@ -14,10 +14,25 @@ # These tests have been tested locally. 
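+# Since this patch the op also appends the sample's (query, response) pair to
+# the dialog built from history_key, so a 3-round history plus a query/response
+# yields 4 sentiment analyses and intensities (hence target_len=4 below).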
@SKIPPED_TESTS.register_module()
 class TestDialogSentimentIntensityMapper(DataJuicerTestCaseBase):
+    # before running this test, set below environment variables:
+    # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
+    # export OPENAI_API_KEY=your_key
 
+    def _run_op(self, op, samples, target_len):
+        dataset = Dataset.from_list(samples)
+        dataset = dataset.map(op.process, batch_size=2)
+        analysis_list = nested_access(dataset[0], MetaKeys.sentiment_analysis)
+        intensity_list = nested_access(dataset[0], MetaKeys.sentiment_intensity)
 
         for analysis, intensity in zip(analysis_list, intensity_list):
             logger.info(f'分析:{analysis}')
             logger.info(f'情绪:{intensity}')
+
+        self.assertEqual(len(analysis_list), target_len)
+        self.assertEqual(len(intensity_list), target_len)
+
+    def test_default(self):
+
+        samples = [{
+            'history': [
+                (
+                    '李莲花有口皆碑',
+                    '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+                ),
+                (
+                    '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+                    '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+                ),
+                (
+                    '你自己说的呀,我现在说了,你又不高兴了。',
+                    'or of of of of or or and or of of of of of of of,,, '
+                ),
+                (
+                    '你在说什么我听不懂。',
+                    '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+                )
+            ]
+        }]
+
         op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct')
-        self._run_op(op)
+        self._run_op(op, samples, 4)
 
-    def max_round_test(self):
-        op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct')
-        self._run_op(op)
+    def test_max_round(self):
+
+        samples = [{
+            'history': [
+                (
+                    '李莲花有口皆碑',
+                    '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+                ),
+                (
+                    '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+                    '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+                ),
+                (
+                    '你自己说的呀,我现在说了,你又不高兴了。',
+                    'or of of of of or or and or of of of of of of of,,, '
+                ),
+                (
+                    '你在说什么我听不懂。',
+                    '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+                )
+            ]
+        }]
+
+        op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct',
+                                            max_round=1)
+        self._run_op(op, samples, 4)
+
+    def test_max_round_zero(self):
+
+        samples = [{
+            'history': [
+                (
+                    '李莲花有口皆碑',
+                    '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+                ),
+                (
+                    '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+                    '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+                ),
+                (
+                    '你自己说的呀,我现在说了,你又不高兴了。',
+                    'or of of of of or or and or of of of of of of of,,, '
+                ),
+                (
+                    '你在说什么我听不懂。',
+                    '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+                )
+            ]
+        }]
+
+        op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct',
+                                            max_round=0)
+        self._run_op(op, samples, 4)
+
+    def test_query(self):
+
+        samples = [{
+            'history': [
+                (
+                    '李莲花有口皆碑',
+                    '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
+                ),
+                (
+                    '是的,你确实是一个普通大夫,没什么值得夸耀的。',
+                    '「委屈」你这话说的,我也是尽心尽力治病救人了。'
+                ),
+                (
+                    '你自己说的呀,我现在说了,你又不高兴了。',
+                    'or of of of of or or and or of of of of of of of,,, '
+                )
+            ],
+            'query': '你在说什么我听不懂。',
+            'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
+        }]
+
+        op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct',
+                                            max_round=1)
+        self._run_op(op, samples, 4)
 
 
 if __name__ == '__main__':
From b46d105997a9eb3bb593a7b0aca51095344aac71 Mon Sep 17 00:00:00 2001
From: Haibin <1400012807@pku.edu.cn>
Date: Thu, 12 Dec 2024 19:15:12 +0800
Subject: [PATCH 064/118] change to dj_install
---
 README.md                                     | 2 +-
 README_ZH.md                                  | 2 +-
 setup.py                                      | 1 +
 tools/{install_by_recipe.py => dj_install.py} | 0
 4 files changed, 3 insertions(+), 2 deletions(-)
 rename tools/{install_by_recipe.py => dj_install.py} (100%)

diff --git 
a/README.md b/README.md index 3ede912b0..9f9b3bac1 100644 --- a/README.md +++ b/README.md @@ -206,7 +206,7 @@ we provide two alternative, lighter options: - Manual Minimal Dependency Installation: To manually install minimal dependencies tailored to a specific execution configuration, run the following command: ```shell - python tools/install_by_recipe.py --config path_to_your_data-juicer_config_file + python tools/dj_install.py --config path_to_your_data-juicer_config_file ``` ### Using pip diff --git a/README_ZH.md b/README_ZH.md index a0439cee2..172373db8 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -186,7 +186,7 @@ pip install -v -e .[tools] # 安装部分工具库的依赖 * 手动最小依赖安装:可以通过如下指令手动安装适合特定执行配置的最小依赖: ```shell - python tools/install_by_recipe.py --config path_to_your_data-juicer_config_file + python tools/dj_install.py --config path_to_your_data-juicer_config_file ``` ### 使用 pip 安装 diff --git a/setup.py b/setup.py index 3df3d0170..d0ec5b546 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,7 @@ def get_install_requirements(require_f_paths, env_dir='environments'): 'console_scripts': [ 'dj-process = data_juicer.tools.process_data:main', 'dj-analyze = data_juicer.tools.analyze_data:main', + 'dj-install = data_juicer.tools.dj_install:main', ] }, install_requires=min_requires, diff --git a/tools/install_by_recipe.py b/tools/dj_install.py similarity index 100% rename from tools/install_by_recipe.py rename to tools/dj_install.py From a0da444478fdbb0959892c1315d23c9052478e3c Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 12 Dec 2024 19:21:32 +0800 Subject: [PATCH 065/118] change to dj_install --- README.md | 4 ++++ README_ZH.md | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/README.md b/README.md index 9f9b3bac1..518e54713 100644 --- a/README.md +++ b/README.md @@ -206,7 +206,11 @@ we provide two alternative, lighter options: - Manual Minimal Dependency Installation: To manually install minimal dependencies tailored to a specific execution configuration, run the following command: ```shell + # only for installation from source python tools/dj_install.py --config path_to_your_data-juicer_config_file + + # use command line tool + dj-install --config path_to_your_data-juicer_config_file ``` ### Using pip diff --git a/README_ZH.md b/README_ZH.md index 172373db8..366fcb004 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -186,7 +186,11 @@ pip install -v -e .[tools] # 安装部分工具库的依赖 * 手动最小依赖安装:可以通过如下指令手动安装适合特定执行配置的最小依赖: ```shell + # 适用于从源码安装 python tools/dj_install.py --config path_to_your_data-juicer_config_file + + # 使用命令行工具 + dj-install --config path_to_your_data-juicer_config_file ``` ### 使用 pip 安装 From 02f8dda38d103140e3df82df520fc5b94b0f7af2 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 12 Dec 2024 19:29:55 +0800 Subject: [PATCH 066/118] developer doc done --- docs/DeveloperGuide.md | 8 +++++--- docs/DeveloperGuide_ZH.md | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md index e736b5ade..734f1201a 100644 --- a/docs/DeveloperGuide.md +++ b/docs/DeveloperGuide.md @@ -209,7 +209,9 @@ __all__ = [ ] ``` -4. Now you can use this new OP with custom arguments in your own config files! +4. When an operator has package dependencies listed in `environments/science_requires.txt`, you need to add the corresponding dependency packages to the `OPS_TO_PKG` dictionary in `data_juicer/utils/auto_install_mapping.py` to support dependency installation at the operator level. + +5. 
Now you can use this new OP with custom arguments in your own config files! ```yaml # other configs @@ -222,7 +224,7 @@ process: max_len: 1000 ``` -5. (Strongly Recommend) It's better to add corresponding tests for your own OPs. For `TextLengthFilter` above, you would like to add `test_text_length_filter.py` into `tests/ops/filter/` directory as below. +6. (Strongly Recommend) It's better to add corresponding tests for your own OPs. For `TextLengthFilter` above, you would like to add `test_text_length_filter.py` into `tests/ops/filter/` directory as below. ```python import unittest @@ -244,7 +246,7 @@ if __name__ == '__main__': unittest.main() ``` -6. (Strongly Recommend) In order to facilitate the use of other users, we also need to update this new OP information to +7. (Strongly Recommend) In order to facilitate the use of other users, we also need to update this new OP information to the corresponding documents, including the following docs: 1. `configs/config_all.yaml`: this complete config file contains a list of all OPs and their arguments, serving as an important document for users to refer to all available OPs. Therefore, after adding the new OP, we need to add it to the process diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md index e9d746d7c..fcc76aafe 100644 --- a/docs/DeveloperGuide_ZH.md +++ b/docs/DeveloperGuide_ZH.md @@ -202,7 +202,9 @@ __all__ = [ ] ``` -4. 全部完成!现在您可以在自己的配置文件中使用新添加的算子: +4. 算子有`environments/science_requires.txt`中列举的包依赖时,需要在`data_juicer/utils/auto_install_mapping.py`里的`OPS_TO_PKG`中添加对应的依赖包,以支持算子粒度的依赖安装。 + +5. 全部完成!现在您可以在自己的配置文件中使用新添加的算子: ```yaml # other configs @@ -215,7 +217,7 @@ process: max_len: 1000 ``` -5. (强烈推荐)最好为新添加的算子进行单元测试。对于上面的 `TextLengthFilter` 算子,建议在 `tests/ops/filter/` 中实现如 `test_text_length_filter.py` 的测试文件: +6. (强烈推荐)最好为新添加的算子进行单元测试。对于上面的 `TextLengthFilter` 算子,建议在 `tests/ops/filter/` 中实现如 `test_text_length_filter.py` 的测试文件: ```python import unittest @@ -238,7 +240,7 @@ if __name__ == '__main__': unittest.main() ``` -6. (强烈推荐)为了方便其他用户使用,我们还需要将新增的算子信息更新到相应的文档中,具体包括如下文档: +7. (强烈推荐)为了方便其他用户使用,我们还需要将新增的算子信息更新到相应的文档中,具体包括如下文档: 1. 
`configs/config_all.yaml`:该全集配置文件保存了所有算子及参数的一个列表,作为用户参考可用算子的一个重要文档。因此,在新增算子后,需要将其添加到该文档process列表里(按算子类型分组并按字母序排序): ```yaml From 3b04908d25429500741b52329d6451cc5bc52578 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Fri, 13 Dec 2024 14:27:47 +0800 Subject: [PATCH 067/118] query sent_int mapper --- data_juicer/ops/mapper/__init__.py | 29 ++--- .../dialog_sentiment_intensity_mapper.py | 16 +-- .../query_sentiment_intensity_mapper.py | 95 ++++++++++++++++ data_juicer/utils/auto_install_mapping.py | 2 +- data_juicer/utils/common_utils.py | 61 +++++++++-- data_juicer/utils/constant.py | 10 +- .../test_dialog_sentiment_intensity_mapper.py | 4 +- .../test_query_sentiment_intensity_mapper.py | 103 ++++++++++++++++++ 8 files changed, 283 insertions(+), 37 deletions(-) create mode 100644 data_juicer/ops/mapper/query_sentiment_intensity_mapper.py create mode 100644 tests/ops/mapper/test_query_sentiment_intensity_mapper.py diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index a994a4ba4..39455857c 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -34,6 +34,7 @@ from .punctuation_normalization_mapper import PunctuationNormalizationMapper from .python_file_mapper import PythonFileMapper from .python_lambda_mapper import PythonLambdaMapper +from .query_sentiment_intensity_mapper import QuerySentimentLabelMapper from .relation_identity_mapper import RelationIdentityMapper from .remove_bibliography_mapper import RemoveBibliographyMapper from .remove_comments_mapper import RemoveCommentsMapper @@ -81,18 +82,18 @@ 'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', 'PairPreferenceMapper', 'PunctuationNormalizationMapper', - 'PythonFileMapper', 'PythonLambdaMapper', 'RelationIdentityMapper', - 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', - 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', - 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper', - 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper', - 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper', - 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', - 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper', - 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', - 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper', - 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper', - 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper', - 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper', - 'WhitespaceNormalizationMapper' + 'PythonFileMapper', 'PythonLambdaMapper', 'QuerySentimentLabelMapper', + 'RelationIdentityMapper', 'RemoveBibliographyMapper', + 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', + 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', + 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', + 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', + 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper', + 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', + 'VideoCaptioningFromVideoMapper', 'VideoFFmpegWrappedMapper', + 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', + 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper', + 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', + 'VideoSplitBySceneMapper', 
'VideoTaggingFromAudioMapper', + 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper' ] diff --git a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py index 7979e153a..e511c104b 100644 --- a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py +++ b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py @@ -17,7 +17,7 @@ class DialogSentimentIntensityMapper(Mapper): """ Mapper to predict user's sentiment intensity (from -5 to 5 in default - prompt) in dialog which is stored in the history_key. + prompt) in dialog (history + query + response). """ DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户的情绪变化。\n' @@ -61,8 +61,6 @@ class DialogSentimentIntensityMapper(Mapper): def __init__(self, api_model: str = 'gpt-4o', max_round: NonNegativeInt = 10, - intensity_key: str = MetaKeys.sentiment_intensity, - analysis_key: str = MetaKeys.sentiment_analysis, *, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, @@ -83,10 +81,6 @@ def __init__(self, :param api_model: API model name. :param max_round: The max num of round in the dialog to build the prompt. - :param intensity_key: The output (nested) key of the sentiment - intensity. Defaults to '__dj__meta.sentiment.intensity'. - :param analysis_key: The output (nested) key of the sentiment - analysis. Defaults to '__dj__meta.sentiment.analysis'. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. @@ -113,8 +107,6 @@ def __init__(self, super().__init__(**kwargs) self.max_round = max_round - self.intensity_key = intensity_key - self.analysis_key = analysis_key self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE @@ -204,7 +196,9 @@ def process_single(self, sample, rank=None): history.append(self.intensity_template.format(intensity=intensity)) history.append(self.response_template.format(response=qa[1])) - sample = nested_set(sample, self.analysis_key, analysis_list) - sample = nested_set(sample, self.intensity_key, intensities) + sample = nested_set(sample, MetaKeys.dialog_sentiment_analysis, + analysis_list) + sample = nested_set(sample, MetaKeys.dialog_sentiment_intensity, + intensities) return sample diff --git a/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py new file mode 100644 index 000000000..e1393a3a5 --- /dev/null +++ b/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py @@ -0,0 +1,95 @@ +from typing import Dict + +from data_juicer.utils.common_utils import batch_nested_set +from data_juicer.utils.constant import MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'query_sentiment_intensity_mapper' + + +@OPERATORS.register_module(OP_NAME) +class QuerySentimentLabelMapper(Mapper): + """ + Mapper to predict user's sentiment intensity label (-1 for 'negative', + 0 for 'neutral' and 1 for 'positive') in query. 
+    """
+
+    _accelerator = 'cuda'
+    _batched_op = True
+
+    DEFAULT_LABEL_TO_INTENSITY = {
+        'negative': -1,
+        'neutral': 0,
+        'positive': 1,
+    }
+
+    def __init__(
+        self,
+        hf_model:
+        str = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis',  # noqa: E501 E131
+        zh_to_en_hf_model: str = 'Helsinki-NLP/opus-mt-zh-en',
+        model_params: Dict = {},
+        zh_to_en_model_params: Dict = {},
+        *,
+        label_to_intensity: Dict = None,
+        **kwargs):
+        """
+        Initialization method.
+
+        :param hf_model: Huggingface model ID to predict sentiment intensity.
+        :param zh_to_en_hf_model: Translation model from Chinese to English.
+            If not None, translate the query from Chinese to English.
+        :param model_params: model param for hf_model.
+        :param zh_to_en_model_params: model param for zh_to_en_hf_model.
+        :param kwargs: Extra keyword arguments.
+        :param label_to_intensity: Map the output labels to the intensities
+            instead of the default mapper.
+        """
+        super().__init__(**kwargs)
+
+        self.model_key = prepare_model(model_type='huggingface',
+                                       pretrained_model_name_or_path=hf_model,
+                                       return_pipe=True,
+                                       pipe_task='text-classification',
+                                       **model_params)
+
+        if zh_to_en_hf_model is not None:
+            self.zh_to_en_model_key = prepare_model(
+                model_type='huggingface',
+                pretrained_model_name_or_path=zh_to_en_hf_model,
+                return_pipe=True,
+                pipe_task='translation',
+                **zh_to_en_model_params)
+        else:
+            self.zh_to_en_model_key = None
+
+        if label_to_intensity is not None:
+            self.label_to_intensity = label_to_intensity
+        else:
+            self.label_to_intensity = self.DEFAULT_LABEL_TO_INTENSITY
+
+    def process_batched(self, samples, rank=None):
+        queries = samples[self.query_key]
+
+        if self.zh_to_en_model_key is not None:
+            translater, _ = get_model(self.zh_to_en_model_key, rank,
+                                      self.use_cuda())
+            results = translater(queries)
+            queries = [item['translation_text'] for item in results]
+
+        classifier, _ = get_model(self.model_key, rank, self.use_cuda())
+        results = classifier(queries)
+        intensities = [
+            self.label_to_intensity[r['label']]
+            if r['label'] in self.label_to_intensity else r['label']
+            for r in results
+        ]
+        scores = [r['score'] for r in results]
+
+        batch_nested_set(samples, MetaKeys.query_sentiment_intensity,
+                         intensities)
+        batch_nested_set(samples, MetaKeys.query_sentiment_score, scores)
+
+        return samples
diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py
index b5a86d16b..e4e319c17 100644
--- a/data_juicer/utils/auto_install_mapping.py
+++ b/data_juicer/utils/auto_install_mapping.py
@@ -15,7 +15,7 @@
     'video_aesthetics_filter':
     ['simple-aesthetics-predictor', 'torch', 'transformers'],
     'document_simhash_deduplicator': ['simhash-pybind'],
-    'nlpcda': ['nlpcda'],
+    'nlpcda_zh_mapper': ['nlpcda'],
     'image_aesthetics_filter':
    ['simple-aesthetics-predictor', 'torch', 'transformers'],
     'video_nsfw_filter': ['torch', 'transformers'],
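For reference, the label-to-intensity step in `process_batched` above boils down to the following standalone sketch. The `results` list is a hand-written stand-in for the output shape of a Hugging Face `text-classification` pipeline, not real model predictions:

```python
# Stand-in classifier output; a real 'text-classification' pipeline returns
# a list of {'label': ..., 'score': ...} dicts in this same shape.
label_to_intensity = {'negative': -1, 'neutral': 0, 'positive': 1}
results = [
    {'label': 'positive', 'score': 0.98},
    {'label': 'neutral', 'score': 0.71},
    {'label': 'surprise', 'score': 0.55},  # hypothetical unmapped label
]

# Labels missing from the map pass through unchanged, mirroring the
# conditional expression in process_batched.
intensities = [
    label_to_intensity[r['label']]
    if r['label'] in label_to_intensity else r['label']
    for r in results
]
scores = [r['score'] for r in results]

assert intensities == [1, 0, 'surprise']
assert scores == [0.98, 0.71, 0.55]
```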
""" keys = path.split('.') cur = data - for key in keys[:-1]: - if key not in cur: - cur[key] = {} - cur = cur[key] - cur[keys[-1]] = val + try: + for key in keys[:-1]: + if key not in cur: + cur[key] = {} + cur = cur[key] + if keys[-1] in cur: + logger.warning(f'Overwrite value in {path}!') + cur[keys[-1]] = val + except Exception: + logger.warning(f'Unvalid dot-separated path: {path}!') + return data return data +def batch_nested_set(batch_data: dict, path: str, vals): + """ + Set the vals to the batched nested data in the dot-separated + path. + + :param batch_data: A batched dictionary with nested format. + :param path: A dot-separated string representing the path to set. + :return: The nested data after the val set. + """ + keys = path.split('.') + + # not nested, set the vals. + if len(keys) == 1: + if keys[0] in batch_data: + logger.warning(f'Overwrite value in {path}!') + batch_data[keys[0]] = vals + return batch_data + + # nested, transfer to list(dict()) format. + if keys[0] not in batch_data: + batch_data[keys[0]] = [{} for val in vals] + + if not isinstance(batch_data[keys[0]], + list) or len(batch_data[keys[0]]) != len(vals): + logger.warning('Batch size does not match between data and vals!') + return batch_data + + try: + for head, val in zip(batch_data[keys[0]], vals): + cur = head + for key in keys[1:-1]: + if key not in cur: + cur[key] = {} + cur = cur[key] + if keys[-1] in cur: + logger.warning(f'Overwrite value in {path}!') + cur[keys[-1]] = val + except Exception: + logger.warning(f'Unvalid dot-separated path: {path}!') + return batch_data + return batch_data + + def is_string_list(var): """ return if the var is list of string. diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 219ba68c3..e97e5f5a6 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -12,8 +12,14 @@ class MetaKeys(object): - sentiment_intensity = DEFAULT_PREFIX + 'meta.sentiment.intensity' - sentiment_analysis = DEFAULT_PREFIX + 'meta.sentiment.analysis' + dialog_sentiment_intensity = DEFAULT_PREFIX + \ + 'meta.sentiment.dialog_intensity' + dialog_sentiment_analysis = DEFAULT_PREFIX + \ + 'meta.sentiment.dialog_analysis' + query_sentiment_intensity = DEFAULT_PREFIX + \ + 'meta.sentiment.query_intensity' + query_sentiment_score = DEFAULT_PREFIX + \ + 'meta.sentiment.query_intensity_score' class Fields(object): diff --git a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py index 1c4cb2eff..417f340eb 100644 --- a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py +++ b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py @@ -21,8 +21,8 @@ class TestDialogSentimentIntensityMapper(DataJuicerTestCaseBase): def _run_op(self, op, samples, target_len): dataset = Dataset.from_list(samples) dataset = dataset.map(op.process, batch_size=2) - analysis_list = nested_access(dataset[0], MetaKeys.sentiment_analysis) - intensity_list = nested_access(dataset[0], MetaKeys.sentiment_intensity) + analysis_list = nested_access(dataset[0], MetaKeys.dialog_sentiment_analysis) + intensity_list = nested_access(dataset[0], MetaKeys.dialog_sentiment_intensity) for analysis, intensity in zip(analysis_list, intensity_list): logger.info(f'分析:{analysis}') diff --git a/tests/ops/mapper/test_query_sentiment_intensity_mapper.py b/tests/ops/mapper/test_query_sentiment_intensity_mapper.py new file mode 100644 index 000000000..59f13294a --- /dev/null +++ 
b/tests/ops/mapper/test_query_sentiment_intensity_mapper.py @@ -0,0 +1,103 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.query_sentiment_intensity_mapper import QuerySentimentLabelMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import MetaKeys +from data_juicer.utils.common_utils import nested_access + +class TestQuerySentimentLabelMapper(DataJuicerTestCaseBase): + + hf_model = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' + zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' + + def _run_op(self, op, samples, intensity_key, targets): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for sample, target in zip(dataset, targets): + intensity = nested_access(sample, intensity_key) + self.assertEqual(intensity, target) + + def test_default(self): + + samples = [{ + 'query': '太棒了!' + },{ + 'query': '嗯嗯' + },{ + 'query': '太讨厌了!' + }, + ] + targets = [1, 0, -1] + + op = QuerySentimentLabelMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = self.zh_to_en_hf_model, + ) + self._run_op(op, samples, MetaKeys.query_sentiment_intensity, targets) + + def test_no_zh_to_en(self): + + samples = [{ + 'query': '太棒了!' + },{ + 'query': 'That is great!' + } + ] + targets = [0, 1] + + op = QuerySentimentLabelMapper( + hf_model = self.hf_model, + ) + self._run_op(op, samples, MetaKeys.query_sentiment_intensity, targets) + + def test_reset_map1(self): + + samples = [{ + 'query': '太棒了!' + },{ + 'query': '嗯嗯' + },{ + 'query': '太讨厌了!' + }, + ] + targets = [2, 0, -2] + + reset_key = + op = QuerySentimentLabelMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = self.zh_to_en_hf_model, + label_to_intensity = { + 'negative': -2, + 'neutral': 0, + 'positive': 2, + } + ) + self._run_op(op, samples, MetaKeys.query_sentiment_intensity, targets) + + def test_reset_map2(self): + + samples = [{ + 'query': '太棒了!' + },{ + 'query': '嗯嗯' + },{ + 'query': '太讨厌了!' 
+ }, + ] + targets = ['positive', 'neutral', 'negative'] + + op = QuerySentimentLabelMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = self.zh_to_en_hf_model, + label_to_intensity = {} + ) + self._run_op(op, samples, MetaKeys.query_sentiment_intensity, targets) + +if __name__ == '__main__': + unittest.main() From 6b4d52504f6a1fa172cb7f2af65747ec5d8d51f3 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Fri, 13 Dec 2024 15:35:51 +0800 Subject: [PATCH 068/118] query sentiment test done --- .../ops/mapper/dialog_sentiment_intensity_mapper.py | 4 ++-- .../ops/mapper/query_sentiment_intensity_mapper.py | 4 ++-- .../ops/mapper/test_dialog_sentiment_intensity_mapper.py | 2 +- tests/ops/mapper/test_query_sentiment_intensity_mapper.py | 8 ++++---- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py index e511c104b..46e7da02e 100644 --- a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py +++ b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py @@ -162,8 +162,8 @@ def process_single(self, sample, rank=None): history = [] dialog = sample[self.history_key] - if sample[self.query_key]: - if sample[self.response_key]: + if self.query_key in sample and sample[self.query_key]: + if self.response_key in sample and sample[self.response_key]: dialog.append( (sample[self.query_key], sample[self.response_key])) else: diff --git a/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py index e1393a3a5..08b229b87 100644 --- a/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py +++ b/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Optional from data_juicer.utils.common_utils import batch_nested_set from data_juicer.utils.constant import MetaKeys @@ -29,7 +29,7 @@ def __init__( self, hf_model: str = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis', # noqa: E501 E131 - zh_to_en_hf_model: str = 'Helsinki-NLP/opus-mt-zh-en', + zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', model_params: Dict = {}, zh_to_en_model_params: Dict = {}, *, diff --git a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py index 417f340eb..a9379749f 100644 --- a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py +++ b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py @@ -127,7 +127,7 @@ def test_query(self): '你自己说的呀,我现在说了,你又不高兴了。', 'or of of of of or or and or of of of of of of of,,, ' ) - ] + ], 'query': '你在说什么我听不懂。', 'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' }] diff --git a/tests/ops/mapper/test_query_sentiment_intensity_mapper.py b/tests/ops/mapper/test_query_sentiment_intensity_mapper.py index 59f13294a..0850ea73e 100644 --- a/tests/ops/mapper/test_query_sentiment_intensity_mapper.py +++ b/tests/ops/mapper/test_query_sentiment_intensity_mapper.py @@ -30,7 +30,7 @@ def test_default(self): },{ 'query': '嗯嗯' },{ - 'query': '太讨厌了!' + 'query': '没有希望。' }, ] targets = [1, 0, -1] @@ -53,6 +53,7 @@ def test_no_zh_to_en(self): op = QuerySentimentLabelMapper( hf_model = self.hf_model, + zh_to_en_hf_model = None, ) self._run_op(op, samples, MetaKeys.query_sentiment_intensity, targets) @@ -63,12 +64,11 @@ def test_reset_map1(self): },{ 'query': '嗯嗯' },{ - 'query': '太讨厌了!' 
+            'query': '没有希望。'
         },
         ]
         targets = [2, 0, -2]
 
-        reset_key = 
         op = QuerySentimentLabelMapper(
             hf_model = self.hf_model,
             zh_to_en_hf_model = self.zh_to_en_hf_model,
@@ -87,7 +87,7 @@ def test_reset_map2(self):
         },{
             'query': '嗯嗯'
         },{
-            'query': '太讨厌了!'
+            'query': '没有希望。'
         },
         ]
         targets = ['positive', 'neutral', 'negative']

From 58288f7ac87b5df980407289b600f8ba82fe8520 Mon Sep 17 00:00:00 2001
From: Haibin <1400012807@pku.edu.cn>
Date: Fri, 13 Dec 2024 16:54:13 +0800
Subject: [PATCH 069/118] change meta pass

---
 .../dialog_sentiment_intensity_mapper.py      | 15 ++++---
 .../query_sentiment_intensity_mapper.py       | 21 ++++---
 data_juicer/utils/common_utils.py             | 43 -------------------
 data_juicer/utils/constant.py                 | 20 ++++-----
 .../test_dialog_sentiment_intensity_mapper.py |  6 +--
 .../test_query_sentiment_intensity_mapper.py  |  4 +-
 6 files changed, 37 insertions(+), 72 deletions(-)

diff --git a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py
index 46e7da02e..02161d257 100644
--- a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py
+++ b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py
@@ -6,7 +6,7 @@
 
 from data_juicer.ops.base_op import OPERATORS, Mapper
 from data_juicer.utils.common_utils import nested_set
-from data_juicer.utils.constant import MetaKeys
+from data_juicer.utils.constant import Fields, MetaKeys
 from data_juicer.utils.model_utils import get_model, prepare_model
 
 OP_NAME = 'dialog_sentiment_intensity_mapper'
@@ -17,7 +17,10 @@ class DialogSentimentIntensityMapper(Mapper):
 
     """
     Mapper to predict user's sentiment intensity (from -5 to 5 in default
-    prompt) in dialog (history + query + response).
+    prompt) in dialog. Input from history_key, query_key and
+    response_key. Output lists of intensities and analysis for queries in
+    the dialog, which are stored in 'sentiment.dialog_intensity' and
+    'sentiment.dialog_analysis' in Data-Juicer meta field.
     """
 
     DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户的情绪变化。\n'
@@ -196,7 +199,9 @@ def process_single(self, sample, rank=None):
             history.append(self.intensity_template.format(intensity=intensity))
             history.append(self.response_template.format(response=qa[1]))
 
-        sample = nested_set(sample, MetaKeys.dialog_sentiment_analysis,
-                            analysis_list)
-        sample = nested_set(sample, MetaKeys.dialog_sentiment_intensity,
-                            intensities)
+        analysis_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_analysis}'
+        sample = nested_set(sample, analysis_key, analysis_list)
+        intensity_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_intensity}'
+        sample = nested_set(sample, intensity_key, intensities)
 
         return sample
diff --git a/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py
index 08b229b87..e95933880 100644
--- a/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py
+++ b/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py
@@ -1,7 +1,7 @@
 from typing import Dict, Optional
 
-from data_juicer.utils.common_utils import batch_nested_set
-from data_juicer.utils.constant import MetaKeys
+from data_juicer.utils.common_utils import nested_set
+from data_juicer.utils.constant import Fields, MetaKeys
 from data_juicer.utils.model_utils import get_model, prepare_model
 
 from ..base_op import OPERATORS, Mapper
@@ -13,7 +13,10 @@ class QuerySentimentLabelMapper(Mapper):
     """
     Mapper to predict user's sentiment intensity label (-1 for 'negative',
-    0 for 'neutral' and 1 for 'positive') in query.
+    0 for 'neutral' and 1 for 'positive') in query. Input from query_key.
+    Output intensity label and corresponding score for the query, which are
+    stored in 'sentiment.query_intensity' and
+    'sentiment.query_intensity_score' in Data-Juicer meta field.
     """
 
     _accelerator = 'cuda'
@@ -88,8 +91,14 @@ def process_batched(self, samples, rank=None):
         ]
         scores = [r['score'] for r in results]
 
-        batch_nested_set(samples, MetaKeys.query_sentiment_intensity,
-                         intensities)
-        batch_nested_set(samples, MetaKeys.query_sentiment_score, scores)
+        if Fields.meta not in samples:
+            samples[Fields.meta] = [{} for val in intensities]
+        for i in range(len(samples[Fields.meta])):
+            samples[Fields.meta][i] = nested_set(
+                samples[Fields.meta][i], MetaKeys.query_sentiment_intensity,
+                intensities[i])
+            samples[Fields.meta][i] = nested_set(
+                samples[Fields.meta][i], MetaKeys.query_sentiment_score,
+                scores[i])
 
         return samples
diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py
index b419f7688..8a13ae361 100644
--- a/data_juicer/utils/common_utils.py
+++ b/data_juicer/utils/common_utils.py
@@ -87,49 +87,6 @@ def nested_set(data: dict, path: str, val):
     return data
 
 
-def batch_nested_set(batch_data: dict, path: str, vals):
-    """
-    Set the vals to the batched nested data in the dot-separated
-    path.
-
-    :param batch_data: A batched dictionary with nested format.
-    :param path: A dot-separated string representing the path to set.
-    :return: The nested data after the val set.
-    """
-    keys = path.split('.')
-
-    # not nested, set the vals.
-    if len(keys) == 1:
-        if keys[0] in batch_data:
-            logger.warning(f'Overwrite value in {path}!')
-        batch_data[keys[0]] = vals
-        return batch_data
-
-    # nested, transfer to list(dict()) format.
-    if keys[0] not in batch_data:
-        batch_data[keys[0]] = [{} for val in vals]
-
-    if not isinstance(batch_data[keys[0]],
-                      list) or len(batch_data[keys[0]]) != len(vals):
-        logger.warning('Batch size does not match between data and vals!')
-        return batch_data
-
-    try:
-        for head, val in zip(batch_data[keys[0]], vals):
-            cur = head
-            for key in keys[1:-1]:
-                if key not in cur:
-                    cur[key] = {}
-                cur = cur[key]
-            if keys[-1] in cur:
-                logger.warning(f'Overwrite value in {path}!')
-            cur[keys[-1]] = val
-    except Exception:
-        logger.warning(f'Invalid dot-separated path: {path}!')
-        return batch_data
-    return batch_data
-
-
 def is_string_list(var):
     """
     return if the var is list of string.
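With `batch_nested_set` removed, each op now writes one nested dict per sample under the `Fields.meta` column through the single-sample `nested_set` helper. A rough standalone illustration of the resulting layout; the helper is re-implemented inline here without the overwrite and error logging, so the snippet runs without a Data-Juicer checkout:

```python
# Simplified re-implementation of data_juicer.utils.common_utils.nested_set,
# dropping the logging guards for brevity.
def nested_set(data: dict, path: str, val):
    keys = path.split('.')
    cur = data
    for key in keys[:-1]:
        cur = cur.setdefault(key, {})
    cur[keys[-1]] = val
    return data

# One per-sample meta dict, as built by the query mapper above.
meta = {}
nested_set(meta, 'sentiment.query_intensity', 1)
nested_set(meta, 'sentiment.query_intensity_score', 0.98)

assert meta == {
    'sentiment': {
        'query_intensity': 1,
        'query_intensity_score': 0.98,
    }
}
```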
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index e97e5f5a6..859bc7b67 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -10,18 +10,6 @@ DEFAULT_PREFIX = '__dj__' -class MetaKeys(object): - - dialog_sentiment_intensity = DEFAULT_PREFIX + \ - 'meta.sentiment.dialog_intensity' - dialog_sentiment_analysis = DEFAULT_PREFIX + \ - 'meta.sentiment.dialog_analysis' - query_sentiment_intensity = DEFAULT_PREFIX + \ - 'meta.sentiment.query_intensity' - query_sentiment_score = DEFAULT_PREFIX + \ - 'meta.sentiment.query_intensity_score' - - class Fields(object): stats = DEFAULT_PREFIX + 'stats__' meta = DEFAULT_PREFIX + 'meta__' @@ -80,6 +68,14 @@ class Fields(object): support_text = DEFAULT_PREFIX + 'support_text__' +class MetaKeys(object): + + dialog_sentiment_intensity = 'sentiment.dialog_intensity' + dialog_sentiment_analysis = 'sentiment.dialog_analysis' + query_sentiment_intensity = 'sentiment.query_intensity' + query_sentiment_score = 'sentiment.query_intensity_score' + + class StatsKeysMeta(type): """ a helper class to track the mapping from OP's name to its used stats_keys diff --git a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py index a9379749f..72ec64cd4 100644 --- a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py +++ b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py @@ -7,7 +7,7 @@ from data_juicer.ops.mapper.dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) -from data_juicer.utils.constant import MetaKeys +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.common_utils import nested_access # Skip tests for this OP. 
@@ -21,8 +21,8 @@ class TestDialogSentimentIntensityMapper(DataJuicerTestCaseBase):
     def _run_op(self, op, samples, target_len):
         dataset = Dataset.from_list(samples)
         dataset = dataset.map(op.process, batch_size=2)
-        analysis_list = nested_access(dataset[0], MetaKeys.dialog_sentiment_analysis)
-        intensity_list = nested_access(dataset[0], MetaKeys.dialog_sentiment_intensity)
+        analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_analysis)
+        intensity_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_intensity)
 
         for analysis, intensity in zip(analysis_list, intensity_list):
             logger.info(f'分析:{analysis}')
diff --git a/tests/ops/mapper/test_query_sentiment_intensity_mapper.py b/tests/ops/mapper/test_query_sentiment_intensity_mapper.py
index 0850ea73e..1f2423870 100644
--- a/tests/ops/mapper/test_query_sentiment_intensity_mapper.py
+++ b/tests/ops/mapper/test_query_sentiment_intensity_mapper.py
@@ -7,7 +7,7 @@
 from data_juicer.ops.mapper.query_sentiment_intensity_mapper import QuerySentimentLabelMapper
 from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
                                               DataJuicerTestCaseBase)
-from data_juicer.utils.constant import MetaKeys
+from data_juicer.utils.constant import Fields, MetaKeys
 from data_juicer.utils.common_utils import nested_access
 
 class TestQuerySentimentLabelMapper(DataJuicerTestCaseBase):
@@ -20,7 +20,7 @@ def _run_op(self, op, samples, intensity_key, targets):
         dataset = dataset.map(op.process, batch_size=2)
 
         for sample, target in zip(dataset, targets):
-            intensity = nested_access(sample, intensity_key)
+            intensity = nested_access(sample[Fields.meta], intensity_key)
             self.assertEqual(intensity, target)
 
     def test_default(self):

From b665c10fdb760e8e9b7994c1b1a7d1d37cb306d3 Mon Sep 17 00:00:00 2001
From: Haibin <1400012807@pku.edu.cn>
Date: Fri, 13 Dec 2024 17:42:13 +0800
Subject: [PATCH 070/118] doc done

---
 configs/config_all.yaml | 21 +++++++++++++++++++
 .../query_sentiment_intensity_mapper.py | 4 ++--
 data_juicer/utils/auto_install_mapping.py | 1 +
 docs/Operators.md | 4 +++-
 docs/Operators_ZH.md | 4 +++-
 5 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index ec03cfb6b..d75576ff6 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -77,6 +77,21 @@ process:
   - clean_ip_mapper: # remove ip addresses from text.
   - clean_links_mapper: # remove web links from text.
   - clean_copyright_mapper: # remove copyright comments.
+  - dialog_sentiment_intensity_mapper: # Mapper to predict user's sentiment intensity (from -5 to 5 in default prompt) in dialog.
+      api_model: 'gpt-4o' # API model name.
+      max_round: 10 # The max num of round in the dialog to build the prompt.
+      api_endpoint: null # URL endpoint for the API.
+      response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+      system_prompt: null # System prompt for the task.
+      query_template: null # Template for query part to build the input prompt.
+      response_template: null # Template for response part to build the input prompt.
+      analysis_template: null # Template for analysis part to build the input prompt.
+      intensity_template: null # Template for intensity part to build the input prompt.
+      analysis_pattern: null # Pattern to parse the return sentiment analysis.
+      intensity_pattern: null # Pattern to parse the return sentiment intensity.
+      try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+      model_params: {} # Parameters for initializing the API model.
+      sampling_params: {} # Extra parameters passed to the API call. e.g. {'temperature': 0.9, 'top_p': 0.95}
   - expand_macro_mapper: # expand macro definitions in Latex text.
   - extract_entity_attribute_mapper: # Extract attributes for given entities from the text.
       api_model: 'gpt-4o' # API model name.
@@ -277,6 +292,12 @@ process:
   - python_lambda_mapper: # executing Python lambda function on data samples.
       lambda_str: '' # A string representation of the lambda function to be executed on data samples. If empty, the identity function is used.
      batched: False # A boolean indicating whether to process input data in batches.
+  - query_sentiment_intensity_mapper: # Mapper to predict user's sentiment intensity label (-1 for 'negative', 0 for 'neutral' and 1 for 'positive') in query.
+      hf_model: 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' # Huggingface model ID to predict sentiment intensity.
+      zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en' # Translation model from Chinese to English. If not None, translate the query from Chinese to English.
+      model_params: {} # model param for hf_model.
+      zh_to_en_model_params: {} # model param for zh_to_en_hf_model.
+      label_to_intensity: null # Map the output labels to the intensities instead of the default mapper if not None.
   - relation_identity_mapper: # identify relation between two entities in the text.
       api_model: 'gpt-4o' # API model name.
       source_entity: '孙悟空' # The source entity of the relation to be identified.
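On the read side, the tests updated above fetch these values back through `nested_access`. A minimal sketch, assuming a local Data-Juicer checkout and a fabricated sample dict in the shape `query_sentiment_intensity_mapper` produces:

```python
from data_juicer.utils.common_utils import nested_access
from data_juicer.utils.constant import Fields, MetaKeys

# Fabricated sample, not real op output; Fields.meta expands to the
# '__dj__meta__' column and MetaKeys.query_sentiment_intensity to the
# dot path 'sentiment.query_intensity'.
sample = {
    'query': 'That is great!',
    Fields.meta: {
        'sentiment': {
            'query_intensity': 1,
            'query_intensity_score': 0.98,
        }
    },
}

intensity = nested_access(sample[Fields.meta],
                          MetaKeys.query_sentiment_intensity)
print(intensity)  # -> 1
```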
diff --git a/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py
index e95933880..2e23873b5 100644
--- a/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py
+++ b/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py
@@ -46,9 +46,9 @@ def __init__(
         If not None, translate the query from Chinese to English.
         :param model_params: model param for hf_model.
         :param zh_to_en_model_params: model param for zh_to_en_hf_model.
-        :param kwargs: Extra keyword arguments.
         :param label_to_intensity: Map the output labels to the intensities
-            instead of the default mapper.
+            instead of the default mapper if not None.
+        :param kwargs: Extra keyword arguments.
diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py
index e4e319c17..049ff463f 100644
--- a/data_juicer/utils/auto_install_mapping.py
+++ b/data_juicer/utils/auto_install_mapping.py
@@ -81,4 +81,5 @@
     'text_entity_dependency_filter': ['spacy-pkuseg'],
     'optimize_response_mapper': ['torch', 'transformers', 'vllm'],
     'dialog_sentiment_intensity_mapper': ['openai'],
+    'query_sentiment_intensity_mapper': ['transformers'],
 }
diff --git a/docs/Operators.md b/docs/Operators.md
index 5f193e451..734b65d0e 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. 
| Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 63 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 65 | Edits and transforms samples | | [ Filter ]( #filter ) | 44 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -68,6 +68,7 @@ All the specific operators are listed below, each featured with several capabili | clean_html_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes HTML tags and returns plain text of all the nodes | [code](../data_juicer/ops/mapper/clean_html_mapper.py) | [tests](../tests/ops/mapper/test_clean_html_mapper.py) | | clean_ip_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes IP addresses | [code](../data_juicer/ops/mapper/clean_ip_mapper.py) | [tests](../tests/ops/mapper/test_clean_ip_mapper.py) | | clean_links_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes links, such as those starting with http or ftp | [code](../data_juicer/ops/mapper/clean_links_mapper.py) | [tests](../tests/ops/mapper/test_clean_links_mapper.py) | +| dialog_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment intensity (from -5 to 5 in default prompt) in dialog. | [code](../data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py) | | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Expands macros usually defined at the top of TeX documents | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | | extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract attributes for given entities from the text. 
| [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | | extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities and relations in the text for knowledge graph. | [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) | @@ -93,6 +94,7 @@ All the specific operators are listed below, each featured with several capabili | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python function defined in a file | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | | python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python lambda function on data samples | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | +| query_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment intensity label (-1 for 'negative', 0 for 'neutral' and 1 for 'positive') in query. | [code](../data_juicer/ops/mapper/query_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_intensity_mapper.py) | | relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Identify relation between two entity in the text. 
| [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index b59bf6cc4..d716d9b36 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 63 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 65 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 44 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -67,6 +67,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | clean_html_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 HTML 标签并返回所有节点的纯文本 | [code](../data_juicer/ops/mapper/clean_html_mapper.py) | [tests](../tests/ops/mapper/test_clean_html_mapper.py) | | clean_ip_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 IP 地址 | [code](../data_juicer/ops/mapper/clean_ip_mapper.py) | [tests](../tests/ops/mapper/test_clean_ip_mapper.py) | | clean_links_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除链接,例如以 http 或 ftp 开头的 | [code](../data_juicer/ops/mapper/clean_links_mapper.py) | [tests](../tests/ops/mapper/test_clean_links_mapper.py) | +| dialog_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测对话中的情绪强度(默认从-5到5)。 | [code](../data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py) | | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) 
![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 扩展通常在 TeX 文档顶部定义的宏 | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | | extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 给定主体和属性名,从文本中抽取主体的属性 | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | | extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从文本中抽取知识图谱的实体和关系 | [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) | @@ -92,6 +93,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行文件中定义的 Python 函数处理样本 | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | | python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行 Python lambda 函数处理样本 | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | +| query_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的情感强度标签(-1 表示“负面”,0 表示“中立”,1 表示“正面”)。 | [code](../data_juicer/ops/mapper/query_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_intensity_mapper.py) | | relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 识别一段文本中两个实体之间的关系 | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) 
![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | From 8ba4156cfde2a1452bf2088a3f47a08f042ff4b3 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 16 Dec 2024 15:39:57 +0800 Subject: [PATCH 071/118] sentiment detection --- data_juicer/ops/mapper/__init__.py | 8 +- .../dialog_sentiment_detection_mapper.py | 195 ++++++++++++++++++ .../dialog_sentiment_intensity_mapper.py | 4 +- data_juicer/utils/constant.py | 2 +- .../test_dialog_sentiment_detection_mapper.py | 141 +++++++++++++ 5 files changed, 344 insertions(+), 6 deletions(-) create mode 100644 data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py create mode 100644 tests/ops/mapper/test_dialog_sentiment_detection_mapper.py diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 1c5e06e46..7c7857cb7 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -8,6 +8,7 @@ from .clean_html_mapper import CleanHtmlMapper from .clean_ip_mapper import CleanIpMapper from .clean_links_mapper import CleanLinksMapper +from .dialog_sentiment_detection_mapper import DialogSentimentDetectionMapper from .dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper from .expand_macro_mapper import ExpandMacroMapper from .extract_entity_attribute_mapper import ExtractEntityAttributeMapper @@ -73,9 +74,10 @@ 'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper', 'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper', 'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper', - 'DialogSentimentIntensityMapper', 'ExpandMacroMapper', - 'ExtractEntityAttributeMapper', 'ExtractEntityRelationMapper', - 'ExtractEventMapper', 'ExtractKeywordMapper', 'ExtractNicknameMapper', + 'DialogSentimentDetectionMapper', 'DialogSentimentIntensityMapper', + 'ExpandMacroMapper', 'ExtractEntityAttributeMapper', + 'ExtractEntityRelationMapper', 'ExtractEventMapper', + 'ExtractKeywordMapper', 'ExtractNicknameMapper', 'ExtractSupportTextMapper', 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper', 'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper', diff --git a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py new file mode 100644 index 000000000..e8fee8928 --- /dev/null +++ b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py @@ -0,0 +1,195 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import NonNegativeInt, PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 
'dialog_sentiment_detection_mapper'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class DialogSentimentDetectionMapper(Mapper):
+    """
+    Mapper to generate user's sentiment labels in dialog. Input from
+    history_key, query_key and response_key. Output lists of
+    labels and analysis for queries in the dialog, which are
+    stored in 'sentiment.dialog_labels' and
+    'sentiment.dialog_labels_analysis' in Data-Juicer meta field.
+    """
+
+    DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户所具有的情绪。\n'
+                             '要求:\n'
+                             '- 需要先进行分析,然后罗列用户所具有的情绪,下面是一个样例,请模仿样例格式输出'
+                             '。\n'
+                             '用户:最近工作压力好大,我觉得整个人都快被压垮了。\n'
+                             '情感分析:用户的言语中透露出明显的压力和疲惫感,可能还夹杂着一些无助和焦虑。\n'
+                             '情感:压力、疲惫、无助、焦虑\n'
+                             'LLM:听起来你真的承受了很多,面临这种情况确实不容易。有没有考虑过找一些放松的'
+                             '方式,比如听音乐或者散步来减轻压力呢?\n'
+                             '用户:试过了,但是好像没什么效果,每天的事情都堆积如山。\n'
+                             '情感分析:用户感到无力解决现状,有挫败感,并且对尝试放松的方式失去信心。\n'
+                             '情感:无力、挫败\n'
+                             'LLM:我理解你的感受,有时候压力积累到一定程度确实让人难以承受。或许你可以尝试'
+                             '规划一下时间,把任务分成小块来完成,这样可能会减少一些压力感。\n'
+                             '用户:这个主意不错,我会试着让自己更有条理一些,谢谢你的建议。\n'
+                             '情感分析:用户对建议表现出认同和感激,同时展现出试图积极面对问题的态度。\n'
+                             '情感:认同、感激、积极\n'
+                             'LLM:不用谢,我很高兴能帮到你。记得给自己一些时间去适应新的计划,有任何需要'
+                             '随时可以跟我说哦!\n')
+    DEFAULT_QUERY_TEMPLATE = '用户:{query}\n'
+    DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n'
+    DEFAULT_ANALYSIS_TEMPLATE = '情感分析:{analysis}\n'
+    DEFAULT_LABELS_TEMPLATE = '情感:{labels}\n'
+    DEFAULT_ANALYSIS_PATTERN = '情感分析:(.*?)\n'
+    DEFAULT_LABELS_PATTERN = '情感:(.*?)($|\n)'
+
+    def __init__(self,
+                 api_model: str = 'gpt-4o',
+                 max_round: NonNegativeInt = 10,
+                 *,
+                 api_endpoint: Optional[str] = None,
+                 response_path: Optional[str] = None,
+                 system_prompt: Optional[str] = None,
+                 query_template: Optional[str] = None,
+                 response_template: Optional[str] = None,
+                 analysis_template: Optional[str] = None,
+                 labels_template: Optional[str] = None,
+                 analysis_pattern: Optional[str] = None,
+                 labels_pattern: Optional[str] = None,
+                 try_num: PositiveInt = 3,
+                 model_params: Dict = {},
+                 sampling_params: Dict = {},
+                 **kwargs):
+        """
+        Initialization method.
+
+        :param api_model: API model name.
+        :param max_round: The max num of round in the dialog to build the
+            prompt.
+        :param api_endpoint: URL endpoint for the API.
+        :param response_path: Path to extract content from the API response.
+            Defaults to 'choices.0.message.content'.
+        :param system_prompt: System prompt for the task.
+        :param query_template: Template for query part to build the input
+            prompt.
+        :param response_template: Template for response part to build the
+            input prompt.
+        :param analysis_template: Template for analysis part to build the
+            input prompt.
+        :param labels_template: Template for labels part to build the
+            input prompt.
+        :param analysis_pattern: Pattern to parse the return sentiment
+            analysis.
+        :param labels_pattern: Pattern to parse the return sentiment
+            labels.
+        :param try_num: The number of retry attempts when there is an API
+            call error or output parsing error.
+        :param model_params: Parameters for initializing the API model.
+        :param sampling_params: Extra parameters passed to the API call.
+            e.g. {'temperature': 0.9, 'top_p': 0.95}
+        :param kwargs: Extra keyword arguments.
+ """ + super().__init__(**kwargs) + + self.max_round = max_round + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE + self.response_template = response_template or \ + self.DEFAULT_RESPONSE_TEMPLATE + self.analysis_template = analysis_template or \ + self.DEFAULT_ANALYSIS_TEMPLATE + self.labels_template = labels_template or \ + self.DEFAULT_INTENSITY_TEMPLATE + self.analysis_pattern = analysis_pattern or \ + self.DEFAULT_ANALYSIS_PATTERN + self.labels_pattern = labels_pattern or \ + self.DEFAULT_INTENSITY_PATTERN + + self.sampling_params = sampling_params + + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + + def build_input(self, history, query): + if self.max_round > 0: + input_prompt = ''.join(history[-self.max_round * 4:]) + else: + input_prompt = '' + input_prompt += self.query_template.format(query=query[0]) + + return input_prompt + + def parse_output(self, response): + analysis = '' + labels = '' + + match = re.search(self.analysis_pattern, response) + if match: + analysis = match.group(1) + + match = re.search(self.labels_pattern, response) + if match: + labels = match.group(1) + + return analysis, labels + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + analysis_list = [] + labels_list = [] + history = [] + + dialog = sample[self.history_key] + if self.query_key in sample and sample[self.query_key]: + if self.response_key in sample and sample[self.response_key]: + dialog.append( + (sample[self.query_key], sample[self.response_key])) + else: + dialog.append((sample[self.query_key], '')) + + for qa in dialog: + input_prompt = self.build_input(history, qa) + messages = [{ + 'role': 'system', + 'content': self.system_prompt, + }, { + 'role': 'user', + 'content': input_prompt, + }] + + for _ in range(self.try_num): + try: + response = client(messages, **self.sampling_params) + analysis, labels = self.parse_output(response) + if len(analysis) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + analysis_list.append(analysis) + labels_list.append(labels) + + history.append(self.query_template.format(query=qa[0])) + history.append(self.analysis_template.format(analysis=analysis)) + history.append(self.labels_template.format(labels=labels)) + history.append(self.response_template.format(response=qa[1])) + + analysis_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_analysis}' + sample = nested_set(sample, analysis_key, analysis_list) + labels_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_labels}' + sample = nested_set(sample, labels_key, labels_list) + + return sample diff --git a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py index 02161d257..db2157447 100644 --- a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py +++ b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py @@ -20,7 +20,7 @@ class DialogSentimentIntensityMapper(Mapper): prompt) in dialog. Input from history_key, query_key and response_key. Output lists of intensities and analysis for queries in the dialog, which is store in 'sentiment.dialog_intensity' and - 'sentiment.dialog_analysis' in Data-Juicer meta field. + 'sentiment.dialog_intensity_analysis' in Data-Juicer meta field. 
""" DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户的情绪变化。\n' @@ -172,7 +172,7 @@ def process_single(self, sample, rank=None): else: dialog.append((sample[self.query_key], '')) - for qa in sample[self.history_key]: + for qa in dialog: input_prompt = self.build_input(history, qa) messages = [{ 'role': 'system', diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index b0f637895..ef9083da6 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -72,7 +72,7 @@ class Fields(object): class MetaKeys(object): dialog_sentiment_intensity = 'sentiment.dialog_intensity' - dialog_sentiment_analysis = 'sentiment.dialog_analysis' + dialog_sentiment_analysis = 'sentiment.dialog_intensity_analysis' query_sentiment_intensity = 'sentiment.query_intensity' query_sentiment_score = 'sentiment.query_intensity_score' diff --git a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py new file mode 100644 index 000000000..5af49398c --- /dev/null +++ b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py @@ -0,0 +1,141 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.dialog_sentiment_detection_mapper import DialogSentimentDetectionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +# Skip tests for this OP. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class DialogSentimentDetectionMapper(DataJuicerTestCaseBase): + # before runing this test, set below environment variables: + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key + + def _run_op(self, op, samples, target_len): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_analysis) + labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_labels) + + for analysis, labels in zip(analysis_list, labels_list): + logger.info(f'分析:{analysis}') + logger.info(f'情绪:{labels}') + + self.assertEqual(len(analysis_list), target_len) + self.assertEqual(len(intensity_list), target_len) + + def test_default(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op, samples, 4) + + def test_max_round(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + def test_max_round_zero(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or 
of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=0) + self._run_op(op, samples, 4) + + def test_query(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ) + ], + 'query': '你在说什么我听不懂。', + 'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + }] + + op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + +if __name__ == '__main__': + unittest.main() From 48b17614021a29d6711058e0ca177fb12f2c6575 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 16 Dec 2024 16:25:07 +0800 Subject: [PATCH 072/118] diff label --- data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py | 4 ++-- data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py | 2 +- data_juicer/utils/constant.py | 4 +++- tests/ops/mapper/test_dialog_sentiment_detection_mapper.py | 2 +- tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py | 2 +- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py index e8fee8928..4d7197f0d 100644 --- a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py @@ -14,7 +14,7 @@ # TODO: LLM-based inference. @OPERATORS.register_module(OP_NAME) -class DialogSentimentDetectionMapper(Mapper): +class TestDialogSentimentDetectionMapper(Mapper): """ Mapper to generate user's sentiment labels in dialog. Input from history_key, query_key and response_key. 
Output lists of @@ -187,7 +187,7 @@ def process_single(self, sample, rank=None): history.append(self.labels_template.format(labels=labels)) history.append(self.response_template.format(response=qa[1])) - analysis_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_analysis}' + analysis_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_labels_analysis}' # noqa: E501 sample = nested_set(sample, analysis_key, analysis_list) labels_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_labels}' sample = nested_set(sample, labels_key, labels_list) diff --git a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py index db2157447..56372c6f3 100644 --- a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py +++ b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py @@ -199,7 +199,7 @@ def process_single(self, sample, rank=None): history.append(self.intensity_template.format(intensity=intensity)) history.append(self.response_template.format(response=qa[1])) - analysis_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_analysis}' + analysis_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_intensity_analysis}' # noqa: E501 sample = nested_set(sample, analysis_key, analysis_list) intensity_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_intensity}' sample = nested_set(sample, intensity_key, intensities) diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index ef9083da6..07c75f5a8 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -72,9 +72,11 @@ class Fields(object): class MetaKeys(object): dialog_sentiment_intensity = 'sentiment.dialog_intensity' - dialog_sentiment_analysis = 'sentiment.dialog_intensity_analysis' + dialog_sentiment_intensity_analysis = 'sentiment.dialog_intensity_analysis' query_sentiment_intensity = 'sentiment.query_intensity' query_sentiment_score = 'sentiment.query_intensity_score' + dialog_sentiment_labels = 'sentiment.dialog_labels' + dialog_sentiment_labels_analysis = 'sentiment.dialog_labels_analysis' class StatsKeysMeta(type): diff --git a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py index 5af49398c..680be8eb9 100644 --- a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py @@ -21,7 +21,7 @@ class DialogSentimentDetectionMapper(DataJuicerTestCaseBase): def _run_op(self, op, samples, target_len): dataset = Dataset.from_list(samples) dataset = dataset.map(op.process, batch_size=2) - analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_analysis) + analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_labels_analysis) labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_labels) for analysis, labels in zip(analysis_list, labels_list): diff --git a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py index 72ec64cd4..a8953c3e4 100644 --- a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py +++ b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py @@ -21,7 +21,7 @@ class TestDialogSentimentIntensityMapper(DataJuicerTestCaseBase): def _run_op(self, op, samples, target_len): dataset = Dataset.from_list(samples) dataset = dataset.map(op.process, batch_size=2) - analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_analysis) + 
analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_intensity_analysis) intensity_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_intensity) for analysis, intensity in zip(analysis_list, intensity_list): From 81607258fa39a6675f595d6b682057b041f2276e Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 16 Dec 2024 16:29:44 +0800 Subject: [PATCH 073/118] sentiment --- data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py | 2 +- tests/ops/mapper/test_dialog_sentiment_detection_mapper.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py index 4d7197f0d..ad2bd8fc8 100644 --- a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py @@ -14,7 +14,7 @@ # TODO: LLM-based inference. @OPERATORS.register_module(OP_NAME) -class TestDialogSentimentDetectionMapper(Mapper): +class DialogSentimentDetectionMapper(Mapper): """ Mapper to generate user's sentiment labels in dialog. Input from history_key, query_key and response_key. Output lists of diff --git a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py index 680be8eb9..0bb92f30e 100644 --- a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py @@ -13,7 +13,7 @@ # Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() -class DialogSentimentDetectionMapper(DataJuicerTestCaseBase): +class TestDialogSentimentDetectionMapper(DataJuicerTestCaseBase): # before runing this test, set below environment variables: # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 # export OPENAI_API_KEY=your_key From 01846d100d6d7c8ab0865d3643ccdad7ade58855 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Mon, 16 Dec 2024 16:48:16 +0800 Subject: [PATCH 074/118] test done --- data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py | 4 ++-- tests/ops/mapper/test_dialog_sentiment_detection_mapper.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py index ad2bd8fc8..e76a61e62 100644 --- a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py @@ -44,9 +44,9 @@ class DialogSentimentDetectionMapper(Mapper): '随时可以跟我说哦!\n') DEFAULT_QUERY_TEMPLATE = '用户:{query}\n' DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n' - DEFAULT_ANALYSIS_TEMPLATE = '情绪分析:{analysis}\n' + DEFAULT_ANALYSIS_TEMPLATE = '情感分析:{analysis}\n' DEFAULT_INTENSITY_TEMPLATE = '情感:{labels}\n' - DEFAULT_ANALYSIS_PATTERN = '情绪分析:(.*?)\n' + DEFAULT_ANALYSIS_PATTERN = '情感分析:(.*?)\n' DEFAULT_INTENSITY_PATTERN = '情感:(.*?)($|\n)' def __init__(self, diff --git a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py index 0bb92f30e..b19bf6359 100644 --- a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py @@ -29,7 +29,7 @@ def _run_op(self, op, samples, target_len): logger.info(f'情绪:{labels}') self.assertEqual(len(analysis_list), target_len) - self.assertEqual(len(intensity_list), 
target_len) + self.assertEqual(len(labels_list), target_len) def test_default(self): From a76d9751bbb63c90666726ae64741a79b34f3f62 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 11:10:38 +0800 Subject: [PATCH 075/118] dialog intent label --- configs/config_all.yaml | 2 +- data_juicer/ops/mapper/__init__.py | 13 +- .../mapper/dialog_intent_detection_mapper.py | 213 ++++++++++++++++++ .../dialog_sentiment_detection_mapper.py | 8 +- ...py => query_sentiment_detection_mapper.py} | 10 +- data_juicer/utils/auto_install_mapping.py | 2 +- data_juicer/utils/constant.py | 9 +- docs/Operators.md | 2 +- docs/Operators_ZH.md | 2 +- ...est_dialog_intent_detection_mapper copy.py | 141 ++++++++++++ ... test_query_sentiment_detection_mapper.py} | 20 +- 11 files changed, 391 insertions(+), 31 deletions(-) create mode 100644 data_juicer/ops/mapper/dialog_intent_detection_mapper.py rename data_juicer/ops/mapper/{query_sentiment_intensity_mapper.py => query_sentiment_detection_mapper.py} (94%) create mode 100644 tests/ops/mapper/test_dialog_intent_detection_mapper copy.py rename tests/ops/mapper/{test_query_sentiment_intensity_mapper.py => test_query_sentiment_detection_mapper.py} (79%) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index f67ff5674..7e450f6ff 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -292,7 +292,7 @@ process: - python_lambda_mapper: # executing Python lambda function on data samples. lambda_str: '' # A string representation of the lambda function to be executed on data samples. If empty, the identity function is used. batched: False # A boolean indicating whether to process input data in batches. - - query_sentiment_intensity_mapper: # Mapper to predict user's sentiment intensity label (-1 for 'negative', 0 for 'neutral' and 1 for 'positive') in query. + - query_sentiment_detection_mapper: # Mapper to predict user's sentiment intensity label (-1 for 'negative', 0 for 'neutral' and 1 for 'positive') in query. hf_model: 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' # Hugginface model ID to predict sentiment intensity. zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en' # Translation model from Chinese to English. If not None, translate the query from Chinese to English. model_params: {} # model param for hf_model. 
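+      # outputs 'sentiment.query_label' and 'sentiment.query_label_score' into the Data-Juicer meta field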
diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 7c7857cb7..931cd7f2a 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -8,6 +8,7 @@ from .clean_html_mapper import CleanHtmlMapper from .clean_ip_mapper import CleanIpMapper from .clean_links_mapper import CleanLinksMapper +from .dialog_intent_detection_mapper import DialogIntentDetectionMapper from .dialog_sentiment_detection_mapper import DialogSentimentDetectionMapper from .dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper from .expand_macro_mapper import ExpandMacroMapper @@ -35,7 +36,7 @@ from .punctuation_normalization_mapper import PunctuationNormalizationMapper from .python_file_mapper import PythonFileMapper from .python_lambda_mapper import PythonLambdaMapper -from .query_sentiment_intensity_mapper import QuerySentimentLabelMapper +from .query_sentiment_detection_mapper import QuerySentimentDetectionMapper from .relation_identity_mapper import RelationIdentityMapper from .remove_bibliography_mapper import RemoveBibliographyMapper from .remove_comments_mapper import RemoveCommentsMapper @@ -74,10 +75,10 @@ 'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper', 'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper', 'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper', - 'DialogSentimentDetectionMapper', 'DialogSentimentIntensityMapper', - 'ExpandMacroMapper', 'ExtractEntityAttributeMapper', - 'ExtractEntityRelationMapper', 'ExtractEventMapper', - 'ExtractKeywordMapper', 'ExtractNicknameMapper', + 'DialogIntentDetectionMapper', 'DialogSentimentDetectionMapper', + 'DialogSentimentIntensityMapper', 'ExpandMacroMapper', + 'ExtractEntityAttributeMapper', 'ExtractEntityRelationMapper', + 'ExtractEventMapper', 'ExtractKeywordMapper', 'ExtractNicknameMapper', 'ExtractSupportTextMapper', 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper', 'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper', @@ -85,7 +86,7 @@ 'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', 'PairPreferenceMapper', 'PunctuationNormalizationMapper', - 'PythonFileMapper', 'PythonLambdaMapper', 'QuerySentimentLabelMapper', + 'PythonFileMapper', 'PythonLambdaMapper', 'QuerySentimentDetectionMapper', 'RelationIdentityMapper', 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', diff --git a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py new file mode 100644 index 000000000..912fa8c3c --- /dev/null +++ b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py @@ -0,0 +1,213 @@ +import re +from typing import Dict, List, Optional + +from loguru import logger +from pydantic import NonNegativeInt, PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'dialog_intent_detection_mapper' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class DialogIntentDetectionMapper(Mapper): + """ + Mapper to generate user's intent labels in dialog. Input from + history_key, query_key and response_key. 
Output lists of
+    intent labels and analyses for the queries in the dialog, which are
+    stored in 'intent.dialog_labels' and
+    'intent.dialog_labels_analysis' in Data-Juicer meta field.
+    """
+
+    DEFAULT_SYSTEM_PROMPT = (
+        '请判断用户和LLM多轮对话中用户的意图。\n'
+        '要求:\n'
+        '- 需要先进行分析,然后列出用户所具有的意图,下面是一个样例,请模仿样例格式输出'
+        '。\n'
+        '备选意图类别:[信息查找, 请求建议, 其他]\n'
+        '用户:你好,我最近对人工智能很感兴趣,能给我讲讲什么是机器学习吗?\n'
+        '意图分析:用户在请求信息,希望了解有关机器学习的基础知识。\n'
+        '意图类别:信息查找\n'
+        'LLM:你好!当然可以。机器学习是一种人工智能方法,允许计算机通过数据自动改进和学习。\n'
+        '用户:听起来很有趣,有没有推荐的入门书籍或资料?\n'
+        '意图分析:用户在请求建议,希望获取关于机器学习的入门资源。\n'
+        '意图类别:请求建议\n'
+        'LLM:有很多不错的入门书籍和资源。一本常被推荐的书是《Python机器学习实践》(Python'
+        ' Machine Learning),它涵盖了基础知识和一些实际案例。此外,您还可以参考Coursera'
+        '或edX上的在线课程,这些课程提供了系统的学习路径。\n'
+        '用户:谢谢你的建议!我还想知道,学习机器学习需要什么样的数学基础?\n'
+        '意图分析:用户在寻求信息,希望了解学习机器学习所需的前提条件,特别是在数学方面。\n'
+        '意图类别:信息查找\n'
+        'LLM:学习机器学习通常需要一定的数学基础,特别是线性代数、概率论和统计学。这些数学领'
+        '域帮助理解算法的工作原理和数据模式分析。如果您对这些主题不太熟悉,建议先从相关基础'
+        '书籍或在线资源开始学习。\n'
+        '用户:明白了,我会先补习这些基础知识。再次感谢你的帮助!\n'
+        '意图分析:用户表达感谢,并表示计划付诸行动来补充所需的基础知识。\n'
+        '意图类别:其他')
+    DEFAULT_QUERY_TEMPLATE = '用户:{query}\n'
+    DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n'
+    DEFAULT_CANDIDATES_TEMPLATE = '备选意图类别:[{candidate_str}]'
+    DEFAULT_ANALYSIS_TEMPLATE = '意图分析:{analysis}\n'
+    DEFAULT_LABELS_TEMPLATE = '意图类别:{labels}\n'
+    DEFAULT_ANALYSIS_PATTERN = '意图分析:(.*?)\n'
+    DEFAULT_LABELS_PATTERN = '意图类别:(.*?)($|\n)'
+
+    def __init__(self,
+                 api_model: str = 'gpt-4o',
+                 intent_candidates: Optional[List[str]] = None,
+                 max_round: NonNegativeInt = 10,
+                 *,
+                 api_endpoint: Optional[str] = None,
+                 response_path: Optional[str] = None,
+                 system_prompt: Optional[str] = None,
+                 query_template: Optional[str] = None,
+                 response_template: Optional[str] = None,
+                 candidate_template: Optional[str] = None,
+                 analysis_template: Optional[str] = None,
+                 labels_template: Optional[str] = None,
+                 analysis_pattern: Optional[str] = None,
+                 labels_pattern: Optional[str] = None,
+                 try_num: PositiveInt = 3,
+                 model_params: Dict = {},
+                 sampling_params: Dict = {},
+                 **kwargs):
+        """
+        Initialization method.
+
+        :param api_model: API model name.
+        :param max_round: The max num of round in the dialog to build the
+            prompt.
+        :param api_endpoint: URL endpoint for the API.
+        :param response_path: Path to extract content from the API response.
+            Defaults to 'choices.0.message.content'.
+        :param system_prompt: System prompt for the task.
+        :param query_template: Template for query part to build the input
+            prompt.
+        :param response_template: Template for response part to build the
+            input prompt.
+        :param candidate_template: Template for intent candidates to
+            build the input prompt.
+        :param analysis_template: Template for analysis part to build the
+            input prompt.
+        :param labels_template: Template for labels to build the
+            input prompt.
+        :param analysis_pattern: Pattern to parse the return intent
+            analysis.
+        :param labels_pattern: Pattern to parse the return intent
+            labels.
+        :param try_num: The number of retry attempts when there is an API
+            call error or output parsing error.
+        :param model_params: Parameters for initializing the API model.
+        :param sampling_params: Extra parameters passed to the API call.
+            e.g. {'temperature': 0.9, 'top_p': 0.95}
+        :param kwargs: Extra keyword arguments.
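+
+        A minimal usage sketch (model name and candidate labels are
+        illustrative; an OpenAI-compatible endpoint is assumed, configured
+        via the environment variables noted in the accompanying tests):
+
+            op = DialogIntentDetectionMapper(
+                api_model='gpt-4o',
+                intent_candidates=['信息查找', '请求建议', '其他'])
+            sample = op.process_single({'history': [('有推荐的书吗?', '有的。')]})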
+ """ + super().__init__(**kwargs) + + self.intent_candidates = intent_candidates + self.max_round = max_round + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE + self.response_template = response_template or \ + self.DEFAULT_RESPONSE_TEMPLATE + self.candidate_template = candidate_template or \ + self.DEFAULT_CANDIDATES_TEMPLATE + self.analysis_template = analysis_template or \ + self.DEFAULT_ANALYSIS_TEMPLATE + self.labels_template = labels_template or \ + self.DEFAULT_LABELS_TEMPLATE + self.analysis_pattern = analysis_pattern or \ + self.DEFAULT_ANALYSIS_PATTERN + self.labels_pattern = labels_pattern or \ + self.DEFAULT_LABELS_PATTERN + + self.sampling_params = sampling_params + + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + + def build_input(self, history, query): + + if self.intent_candidates: + input_prompt = self.candidate_template.format( + candidate_str=','.join(self.intent_candidates)) + + if self.max_round > 0: + input_prompt += ''.join(history[-self.max_round * 4:]) + + input_prompt += self.query_template.format(query=query[0]) + + return input_prompt + + def parse_output(self, response): + analysis = '' + labels = '' + + match = re.search(self.analysis_pattern, response) + if match: + analysis = match.group(1) + + match = re.search(self.labels_pattern, response) + if match: + labels = match.group(1) + + return analysis, labels + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + analysis_list = [] + labels_list = [] + history = [] + + dialog = sample[self.history_key] + if self.query_key in sample and sample[self.query_key]: + if self.response_key in sample and sample[self.response_key]: + dialog.append( + (sample[self.query_key], sample[self.response_key])) + else: + dialog.append((sample[self.query_key], '')) + + for qa in dialog: + input_prompt = self.build_input(history, qa) + messages = [{ + 'role': 'system', + 'content': self.system_prompt, + }, { + 'role': 'user', + 'content': input_prompt, + }] + + for _ in range(self.try_num): + try: + response = client(messages, **self.sampling_params) + analysis, labels = self.parse_output(response) + if len(analysis) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + analysis_list.append(analysis) + labels_list.append(labels) + + history.append(self.query_template.format(query=qa[0])) + history.append(self.analysis_template.format(analysis=analysis)) + history.append(self.labels_template.format(labels=labels)) + history.append(self.response_template.format(response=qa[1])) + + analysis_key = f'{Fields.meta}.{MetaKeys.dialog_intent_labels_analysis}' # noqa: E501 + sample = nested_set(sample, analysis_key, analysis_list) + labels_key = f'{Fields.meta}.{MetaKeys.dialog_intent_labels}' + sample = nested_set(sample, labels_key, labels_list) + + return sample diff --git a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py index e76a61e62..237c217d5 100644 --- a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py @@ -45,9 +45,9 @@ class DialogSentimentDetectionMapper(Mapper): DEFAULT_QUERY_TEMPLATE = '用户:{query}\n' DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n' DEFAULT_ANALYSIS_TEMPLATE = '情感分析:{analysis}\n' - 
DEFAULT_INTENSITY_TEMPLATE = '情感:{labels}\n' + DEFAULT_LABELS_TEMPLATE = '情感:{labels}\n' DEFAULT_ANALYSIS_PATTERN = '情感分析:(.*?)\n' - DEFAULT_INTENSITY_PATTERN = '情感:(.*?)($|\n)' + DEFAULT_LABELS_PATTERN = '情感:(.*?)($|\n)' def __init__(self, api_model: str = 'gpt-4o', @@ -106,11 +106,11 @@ def __init__(self, self.analysis_template = analysis_template or \ self.DEFAULT_ANALYSIS_TEMPLATE self.labels_template = labels_template or \ - self.DEFAULT_INTENSITY_TEMPLATE + self.DEFAULT_LABELS_TEMPLATE self.analysis_pattern = analysis_pattern or \ self.DEFAULT_ANALYSIS_PATTERN self.labels_pattern = labels_pattern or \ - self.DEFAULT_INTENSITY_PATTERN + self.DEFAULT_LABELS_PATTERN self.sampling_params = sampling_params diff --git a/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py similarity index 94% rename from data_juicer/ops/mapper/query_sentiment_intensity_mapper.py rename to data_juicer/ops/mapper/query_sentiment_detection_mapper.py index 2e23873b5..bd0d70bad 100644 --- a/data_juicer/ops/mapper/query_sentiment_intensity_mapper.py +++ b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py @@ -6,17 +6,17 @@ from ..base_op import OPERATORS, Mapper -OP_NAME = 'query_sentiment_intensity_mapper' +OP_NAME = 'query_sentiment_detection_mapper' @OPERATORS.register_module(OP_NAME) -class QuerySentimentLabelMapper(Mapper): +class QuerySentimentDetectionMapper(Mapper): """ Mapper to predict user's sentiment intensity label (-1 for 'negative', 0 for 'neutral' and 1 for 'positive') in query. Input from query_key. Output intensity label and corresponding score for the query, which is - store in 'sentiment.query_intensity' and - 'sentiment.query_intensity_score' in Data-Juicer meta field. + store in 'sentiment.query_label' and + 'sentiment.query_label_score' in Data-Juicer meta field. 
""" _accelerator = 'cuda' @@ -95,7 +95,7 @@ def process_batched(self, samples, rank=None): samples[Fields.meta] = [{} for val in intensities] for i in range(len(samples[Fields.meta])): samples[Fields.meta][i] = nested_set( - samples[Fields.meta][i], MetaKeys.query_sentiment_intensity, + samples[Fields.meta][i], MetaKeys.query_sentiment_label, intensities[i]) samples[Fields.meta][i] = nested_set( samples[Fields.meta][i], MetaKeys.query_sentiment_score, diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py index 049ff463f..148c30089 100644 --- a/data_juicer/utils/auto_install_mapping.py +++ b/data_juicer/utils/auto_install_mapping.py @@ -81,5 +81,5 @@ 'text_entity_dependency_filter': ['spacy-pkuseg'], 'optimize_response_mapper': ['torch', 'transformers', 'vllm'], 'dialog_sentiment_intensity_mapper': ['openai'], - 'query_sentiment_intensity_mapper': ['transformers'], + 'query_sentiment_detection_mapper': ['transformers'], } diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 07c75f5a8..5cc9c33c7 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -73,11 +73,16 @@ class MetaKeys(object): dialog_sentiment_intensity = 'sentiment.dialog_intensity' dialog_sentiment_intensity_analysis = 'sentiment.dialog_intensity_analysis' - query_sentiment_intensity = 'sentiment.query_intensity' - query_sentiment_score = 'sentiment.query_intensity_score' + query_sentiment_label = 'sentiment.query_label' + query_sentiment_score = 'sentiment.query_label_score' dialog_sentiment_labels = 'sentiment.dialog_labels' dialog_sentiment_labels_analysis = 'sentiment.dialog_labels_analysis' + dialog_intent_labels = 'intent.dialog_labels' + dialog_intent_labels_analysis = 'intent.dialog_labels_analysis' + query_intent_label = 'intent.query_label' + query_intent_score = 'intent.query_label_score' + class StatsKeysMeta(type): """ diff --git a/docs/Operators.md b/docs/Operators.md index 295bd049b..164d9ebe6 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -94,7 +94,7 @@ All the specific operators are listed below, each featured with several capabili | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python function defined in a file | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | | python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python lambda function on data samples | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) 
| -| query_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment intensity label (-1 for 'negative', 0 for 'neutral' and 1 for 'positive') in query. | [code](../data_juicer/ops/mapper/query_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_intensity_mapper.py) | +| query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment intensity label (-1 for 'negative', 0 for 'neutral' and 1 for 'positive') in query. | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | | relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Identify relation between two entity in the text. | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 74bccff28..8edb6c25f 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -93,7 +93,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行文件中定义的 Python 函数处理样本 | [code](../data_juicer/ops/mapper/python_file_mapper.py) | 
[tests](../tests/ops/mapper/test_python_file_mapper.py) | | python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行 Python lambda 函数处理样本 | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | -| query_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的情感强度标签(-1 表示“负面”,0 表示“中立”,1 表示“正面”)。 | [code](../data_juicer/ops/mapper/query_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_intensity_mapper.py) | +| query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的情感强度标签(-1 表示“负面”,0 表示“中立”,1 表示“正面”)。 | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | | relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 识别一段文本中两个实体之间的关系 | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | diff --git a/tests/ops/mapper/test_dialog_intent_detection_mapper copy.py b/tests/ops/mapper/test_dialog_intent_detection_mapper copy.py new file mode 100644 index 000000000..82b2f98ba --- /dev/null +++ b/tests/ops/mapper/test_dialog_intent_detection_mapper copy.py @@ -0,0 +1,141 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.dialog_intent_detection_mapper import DialogIntentDetectionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +# Skip tests for this OP. +# These tests have been tested locally. 
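+# To run them locally, remove the skip registration below and configure an
+# OpenAI-compatible endpoint via the environment variables noted in the
+# class body.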
+@SKIPPED_TESTS.register_module() +class TestDialogIntentDetectionMapper(DataJuicerTestCaseBase): + # before runing this test, set below environment variables: + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key + + def _run_op(self, op, samples, target_len): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_intent_labels_analysis) + labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_intent_labels) + + for analysis, labels in zip(analysis_list, labels_list): + logger.info(f'分析:{analysis}') + logger.info(f'情绪:{labels}') + + self.assertEqual(len(analysis_list), target_len) + self.assertEqual(len(labels_list), target_len) + + def test_default(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op, samples, 4) + + def test_max_round(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + def test_max_round_zero(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=0) + self._run_op(op, samples, 4) + + def test_query(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ) + ], + 'query': '你在说什么我听不懂。', + 'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + }] + + op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_query_sentiment_intensity_mapper.py b/tests/ops/mapper/test_query_sentiment_detection_mapper.py similarity index 79% rename from tests/ops/mapper/test_query_sentiment_intensity_mapper.py rename to tests/ops/mapper/test_query_sentiment_detection_mapper.py index 1f2423870..983a69cf0 100644 --- a/tests/ops/mapper/test_query_sentiment_intensity_mapper.py +++ b/tests/ops/mapper/test_query_sentiment_detection_mapper.py @@ -4,13 +4,13 @@ from loguru import logger from data_juicer.core.data import NestedDataset as Dataset -from data_juicer.ops.mapper.query_sentiment_intensity_mapper import QuerySentimentLabelMapper +from data_juicer.ops.mapper.query_sentiment_detection_mapper import QuerySentimentDetectionMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields, 
MetaKeys from data_juicer.utils.common_utils import nested_access -class TestQuerySentimentLabelMapper(DataJuicerTestCaseBase): +class TestQuerySentimentDetectionMapper(DataJuicerTestCaseBase): hf_model = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' @@ -35,11 +35,11 @@ def test_default(self): ] targets = [1, 0, -1] - op = QuerySentimentLabelMapper( + op = QuerySentimentDetectionMapper( hf_model = self.hf_model, zh_to_en_hf_model = self.zh_to_en_hf_model, ) - self._run_op(op, samples, MetaKeys.query_sentiment_intensity, targets) + self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) def test_no_zh_to_en(self): @@ -51,11 +51,11 @@ def test_no_zh_to_en(self): ] targets = [0, 1] - op = QuerySentimentLabelMapper( + op = QuerySentimentDetectionMapper( hf_model = self.hf_model, zh_to_en_hf_model = None, ) - self._run_op(op, samples, MetaKeys.query_sentiment_intensity, targets) + self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) def test_reset_map1(self): @@ -69,7 +69,7 @@ def test_reset_map1(self): ] targets = [2, 0, -2] - op = QuerySentimentLabelMapper( + op = QuerySentimentDetectionMapper( hf_model = self.hf_model, zh_to_en_hf_model = self.zh_to_en_hf_model, label_to_intensity = { @@ -78,7 +78,7 @@ def test_reset_map1(self): 'positive': 2, } ) - self._run_op(op, samples, MetaKeys.query_sentiment_intensity, targets) + self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) def test_reset_map2(self): @@ -92,12 +92,12 @@ def test_reset_map2(self): ] targets = ['positive', 'neutral', 'negative'] - op = QuerySentimentLabelMapper( + op = QuerySentimentDetectionMapper( hf_model = self.hf_model, zh_to_en_hf_model = self.zh_to_en_hf_model, label_to_intensity = {} ) - self._run_op(op, samples, MetaKeys.query_sentiment_intensity, targets) + self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) if __name__ == '__main__': unittest.main() From 2fb9fe47ebf90059eef6acafd230dc69d76e5547 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 11:13:53 +0800 Subject: [PATCH 076/118] fix typo --- data_juicer/ops/mapper/dialog_intent_detection_mapper.py | 2 ++ ...on_mapper copy.py => test_dialog_intent_detection_mapper.py} | 0 2 files changed, 2 insertions(+) rename tests/ops/mapper/{test_dialog_intent_detection_mapper copy.py => test_dialog_intent_detection_mapper.py} (100%) diff --git a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py index 912fa8c3c..5b283ad1f 100644 --- a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py @@ -141,6 +141,8 @@ def build_input(self, history, query): if self.intent_candidates: input_prompt = self.candidate_template.format( candidate_str=','.join(self.intent_candidates)) + else: + input_prompt = '' if self.max_round > 0: input_prompt += ''.join(history[-self.max_round * 4:]) diff --git a/tests/ops/mapper/test_dialog_intent_detection_mapper copy.py b/tests/ops/mapper/test_dialog_intent_detection_mapper.py similarity index 100% rename from tests/ops/mapper/test_dialog_intent_detection_mapper copy.py rename to tests/ops/mapper/test_dialog_intent_detection_mapper.py From 324467fa1af9e7a342ace780457ac50b4ee63054 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 11:15:20 +0800 Subject: [PATCH 077/118] prompt adjust --- 
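Note: with the candidates line commented out of DEFAULT_SYSTEM_PROMPT, the
candidate labels now reach the model only through build_input. A minimal
sketch of the assembled user prompt under the default templates (model name
and query string are illustrative; an OpenAI-compatible endpoint is assumed):

    op = DialogIntentDetectionMapper(
        api_model='gpt-4o',
        intent_candidates=['信息查找', '请求建议', '其他'])
    prompt = op.build_input([], ('有推荐的入门书籍吗?', ''))
    # prompt == '备选意图类别:[信息查找,请求建议,其他]' + '用户:有推荐的入门书籍吗?\n'
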
data_juicer/ops/mapper/dialog_intent_detection_mapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py index 5b283ad1f..293926766 100644 --- a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py @@ -28,7 +28,7 @@ class DialogIntentDetectionMapper(Mapper): '要求:\n' '- 需要先进行分析,然后列出用户所具有的意图,下面是一个样例,请模仿样例格式输出' '。\n' - '备选意图类别:[信息查找, 请求建议, 其他]\n' + # '备选意图类别:[信息查找, 请求建议, 其他]\n' '用户:你好,我最近对人工智能很感兴趣,能给我讲讲什么是机器学习吗?\n' '意图分析:用户在请求信息,希望了解有关机器学习的基础知识。\n' '意图类别:信息查找\n' From 4a3ad39ba374b19503727b390edc552b3c354b99 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 11:27:12 +0800 Subject: [PATCH 078/118] add more test --- .../mapper/dialog_intent_detection_mapper.py | 2 ++ .../test_dialog_intent_detection_mapper.py | 29 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py index 293926766..b32fd9290 100644 --- a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py @@ -79,6 +79,8 @@ def __init__(self, Initialization method. :param api_model: API model name. + :param intent_candidates: The output intent candidates. Use the + intent labels of the open domain if it is None. :param max_round: The max num of round in the dialog to build the prompt. :param api_endpoint: URL endpoint for the API. diff --git a/tests/ops/mapper/test_dialog_intent_detection_mapper.py b/tests/ops/mapper/test_dialog_intent_detection_mapper.py index 82b2f98ba..e7db2cea6 100644 --- a/tests/ops/mapper/test_dialog_intent_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_intent_detection_mapper.py @@ -136,6 +136,35 @@ def test_query(self): max_round=1) self._run_op(op, samples, 4) + def test_intent_candidates(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogIntentDetectionMapper( + api_model='qwen2.5-72b-instruct', + intent_candidates=['评价', '讽刺', '表达困惑'] + ) + self._run_op(op, samples, 4) + if __name__ == '__main__': unittest.main() From 937b3f1cf5a9b9b3294920245bbaff4549f3d55a Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 14:17:54 +0800 Subject: [PATCH 079/118] query intent detection --- data_juicer/ops/mapper/__init__.py | 29 +++--- .../mapper/dialog_intent_detection_mapper.py | 1 - .../mapper/query_intent_detection_mapper.py | 98 +++++++++++++++++++ .../test_query_intent_detection_mapper.py | 61 ++++++++++++ 4 files changed, 174 insertions(+), 15 deletions(-) create mode 100644 data_juicer/ops/mapper/query_intent_detection_mapper.py create mode 100644 tests/ops/mapper/test_query_intent_detection_mapper.py diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 931cd7f2a..af710bd83 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -36,6 +36,7 @@ from .punctuation_normalization_mapper import PunctuationNormalizationMapper from .python_file_mapper import PythonFileMapper from .python_lambda_mapper import 
PythonLambdaMapper +from .query_intent_detection_mapper import QueryIntentDetectionMapper from .query_sentiment_detection_mapper import QuerySentimentDetectionMapper from .relation_identity_mapper import RelationIdentityMapper from .remove_bibliography_mapper import RemoveBibliographyMapper @@ -87,18 +88,18 @@ 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', 'PairPreferenceMapper', 'PunctuationNormalizationMapper', 'PythonFileMapper', 'PythonLambdaMapper', 'QuerySentimentDetectionMapper', - 'RelationIdentityMapper', 'RemoveBibliographyMapper', - 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', - 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', - 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', - 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', - 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper', - 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', - 'VideoCaptioningFromVideoMapper', 'VideoExtractFramesMapper', - 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', - 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper', - 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper', - 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper', - 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper', - 'WhitespaceNormalizationMapper' + 'QueryIntentDetectionMapper', 'RelationIdentityMapper', + 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', + 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', + 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper', + 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper', + 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper', + 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', + 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper', + 'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper', + 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', + 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper', + 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', + 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper', + 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper' ] diff --git a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py index b32fd9290..759e291b4 100644 --- a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py @@ -28,7 +28,6 @@ class DialogIntentDetectionMapper(Mapper): '要求:\n' '- 需要先进行分析,然后列出用户所具有的意图,下面是一个样例,请模仿样例格式输出' '。\n' - # '备选意图类别:[信息查找, 请求建议, 其他]\n' '用户:你好,我最近对人工智能很感兴趣,能给我讲讲什么是机器学习吗?\n' '意图分析:用户在请求信息,希望了解有关机器学习的基础知识。\n' '意图类别:信息查找\n' diff --git a/data_juicer/ops/mapper/query_intent_detection_mapper.py b/data_juicer/ops/mapper/query_intent_detection_mapper.py new file mode 100644 index 000000000..66290532d --- /dev/null +++ b/data_juicer/ops/mapper/query_intent_detection_mapper.py @@ -0,0 +1,98 @@ +from typing import Dict, Optional + +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'query_intent_detection_mapper' + + +@OPERATORS.register_module(OP_NAME) +class QueryIntentDetectionMapper(Mapper): + """ + Mapper to predict user's Intent label in query. 
Input from query_key. + Output intensity label and corresponding score for the query, which is + store in 'intent.query_label' and 'intent.query_label_score' in + Data-Juicer meta field. + """ + + _accelerator = 'cuda' + _batched_op = True + + DEFAULT_LABEL_TO_INTENSITY = {} + + def __init__( + self, + hf_model: str = 'Falconsai/intent_classification', + zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', + model_params: Dict = {}, + zh_to_en_model_params: Dict = {}, + *, + label_to_intensity: Dict = None, + **kwargs): + """ + Initialization method. + + :param hf_model: Hugginface model ID to predict sentiment intensity. + :param zh_to_en_hf_model: Translation model from Chinese to English. + If not None, translate the query from Chinese to English. + :param model_params: model param for hf_model. + :param zh_to_en_model_params: model param for zh_to_hf_model. + :param label_to_intensity: Map the output labels to the intensities + instead of the default mapper if not None. + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_model, + return_pipe=True, + pipe_task='text-classification', + **model_params) + + if zh_to_en_hf_model is not None: + self.zh_to_en_model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=zh_to_en_hf_model, + return_pipe=True, + pipe_task='translation', + **zh_to_en_model_params) + else: + self.zh_to_en_model_key = None + + if label_to_intensity is not None: + self.label_to_intensity = label_to_intensity + else: + self.label_to_intensity = self.DEFAULT_LABEL_TO_INTENSITY + + def process_batched(self, samples, rank=None): + queries = samples[self.query_key] + + if self.zh_to_en_model_key is not None: + translater, _ = get_model(self.zh_to_en_model_key, rank, + self.use_cuda()) + results = translater(queries) + queries = [item['translation_text'] for item in results] + + classifier, _ = get_model(self.model_key, rank, self.use_cuda()) + results = classifier(queries) + intensities = [ + self.label_to_intensity[r['label']] + if r['label'] in self.label_to_intensity else r['label'] + for r in results + ] + scores = [r['score'] for r in results] + + if Fields.meta not in samples: + samples[Fields.meta] = [{} for val in intensities] + for i in range(len(samples[Fields.meta])): + samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], + MetaKeys.query_intent_label, + intensities[i]) + samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], + MetaKeys.query_intent_score, + scores[i]) + + return samples diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py new file mode 100644 index 000000000..6f8494dd6 --- /dev/null +++ b/tests/ops/mapper/test_query_intent_detection_mapper.py @@ -0,0 +1,61 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.query_intent_detection_mapper import QueryIntentDetectionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase): + + hf_model = 'Falconsai/intent_classification' + zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' + + def _run_op(self, op, samples, intensity_key, targets): + dataset = 
Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for sample, target in zip(dataset, targets): + intensity = nested_access(sample[Fields.meta], intensity_key) + self.assertEqual(intensity, target) + + def test_default(self): + + samples = [{ + 'query': '我要一个汉堡。' + },{ + 'query': '你最近过得怎么样?' + },{ + 'query': '它是正方形的。' + } + ] + targets = [1, 0, -1] + + op = QueryIntentDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = self.zh_to_en_hf_model, + ) + self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) + + def test_no_zh_to_en(self): + + samples = [{ + 'query': '它是正方形的。' + },{ + 'query': 'It is square.' + } + ] + targets = [0, 1] + + op = QueryIntentDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = None, + ) + self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) + +if __name__ == '__main__': + unittest.main() From d4ca87b08b66c5d8a9c3e207d1d3fdcf1b491bd5 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 14:30:16 +0800 Subject: [PATCH 080/118] for test --- tests/ops/mapper/test_query_intent_detection_mapper.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py index 6f8494dd6..5fac5ffc9 100644 --- a/tests/ops/mapper/test_query_intent_detection_mapper.py +++ b/tests/ops/mapper/test_query_intent_detection_mapper.py @@ -12,8 +12,8 @@ class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase): - hf_model = 'Falconsai/intent_classification' - zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' + hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Falconsai/intent_classification' + zh_to_en_hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Helsinki-NLP/opus-mt-zh-en' def _run_op(self, op, samples, intensity_key, targets): dataset = Dataset.from_list(samples) @@ -21,7 +21,8 @@ def _run_op(self, op, samples, intensity_key, targets): for sample, target in zip(dataset, targets): intensity = nested_access(sample[Fields.meta], intensity_key) - self.assertEqual(intensity, target) + print(intensity) + # self.assertEqual(intensity, target) def test_default(self): From 8109c713525f670305d4ddea4e48fe30f631fd26 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 14:43:19 +0800 Subject: [PATCH 081/118] for test --- configs/config_all.yaml | 2 +- .../query_sentiment_detection_mapper.py | 27 +++------ docs/Operators.md | 2 +- docs/Operators_ZH.md | 2 +- .../test_query_sentiment_detection_mapper.py | 55 +++---------------- 5 files changed, 17 insertions(+), 71 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 7e450f6ff..457baf27f 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -292,7 +292,7 @@ process: - python_lambda_mapper: # executing Python lambda function on data samples. lambda_str: '' # A string representation of the lambda function to be executed on data samples. If empty, the identity function is used. batched: False # A boolean indicating whether to process input data in batches. - - query_sentiment_detection_mapper: # Mapper to predict user's sentiment intensity label (-1 for 'negative', 0 for 'neutral' and 1 for 'positive') in query. + - query_sentiment_detection_mapper: # Mapper to predict user's sentiment intensity label ('negative', 'neutral' and 'positive') in query. 
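       # outputs 'sentiment.query_label' and 'sentiment.query_label_score' into the Data-Juicer meta field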
hf_model: 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis'  # Huggingface model ID to predict sentiment intensity.
       zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en'  # Translation model from Chinese to English. If not None, translate the query from Chinese to English.
       model_params: {}  # model param for hf_model.
       zh_to_en_model_params: {}  # model param for zh_to_en_hf_model.
       label_to_intensity: null  # Map the output labels to the intensities instead of the default mapping if not None.
diff --git a/data_juicer/ops/mapper/query_sentiment_detection_mapper.py b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py
index bd0d70bad..c8b22a9c6 100644
--- a/data_juicer/ops/mapper/query_sentiment_detection_mapper.py
+++ b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py
@@ -12,9 +12,9 @@
 @OPERATORS.register_module(OP_NAME)
 class QuerySentimentDetectionMapper(Mapper):
     """
-    Mapper to predict user's sentiment intensity label (-1 for 'negative',
-    0 for 'neutral' and 1 for 'positive') in query. Input from query_key.
-    Output intensity label and corresponding score for the query, which are
+    Mapper to predict user's sentiment label ('negative', 'neutral' and
+    'positive') in query. Input from query_key.
+    Output label and corresponding score for the query, which are
     stored in 'sentiment.query_label' and 'sentiment.query_label_score' in
     Data-Juicer meta field.
     """
@@ -35,19 +35,15 @@ def __init__(
         zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en',
         model_params: Dict = {},
         zh_to_en_model_params: Dict = {},
-        *,
-        label_to_intensity: Dict = None,
         **kwargs):
         """
         Initialization method.
 
-        :param hf_model: Huggingface model ID to predict sentiment intensity.
+        :param hf_model: Huggingface model ID to predict sentiment label.
         :param zh_to_en_hf_model: Translation model from Chinese to English.
             If not None, translate the query from Chinese to English.
         :param model_params: model param for hf_model.
         :param zh_to_en_model_params: model param for zh_to_en_hf_model.
-        :param label_to_intensity: Map the output labels to the intensities
-            instead of the default mapping if not None.
         :param kwargs: Extra keyword arguments.
""" super().__init__(**kwargs) @@ -68,11 +64,6 @@ def __init__( else: self.zh_to_en_model_key = None - if label_to_intensity is not None: - self.label_to_intensity = label_to_intensity - else: - self.label_to_intensity = self.DEFAULT_LABEL_TO_INTENSITY - def process_batched(self, samples, rank=None): queries = samples[self.query_key] @@ -84,19 +75,15 @@ def process_batched(self, samples, rank=None): classifier, _ = get_model(self.model_key, rank, self.use_cuda()) results = classifier(queries) - intensities = [ - self.label_to_intensity[r['label']] - if r['label'] in self.label_to_intensity else r['label'] - for r in results - ] + labels = [r['label'] for r in results] scores = [r['score'] for r in results] if Fields.meta not in samples: - samples[Fields.meta] = [{} for val in intensities] + samples[Fields.meta] = [{} for val in labels] for i in range(len(samples[Fields.meta])): samples[Fields.meta][i] = nested_set( samples[Fields.meta][i], MetaKeys.query_sentiment_label, - intensities[i]) + labels[i]) samples[Fields.meta][i] = nested_set( samples[Fields.meta][i], MetaKeys.query_sentiment_score, scores[i]) diff --git a/docs/Operators.md b/docs/Operators.md index 164d9ebe6..ccc58578b 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -94,7 +94,7 @@ All the specific operators are listed below, each featured with several capabili | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python function defined in a file | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | | python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python lambda function on data samples | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | -| query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment intensity label (-1 for 'negative', 0 for 'neutral' and 1 for 'positive') in query. 
| [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | +| query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment intensity label ('negative', 'neutral' and 'positive') in query. | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | | relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Identify relation between two entity in the text. | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 8edb6c25f..2b3ff3115 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -93,7 +93,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行文件中定义的 Python 函数处理样本 | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | | python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行 Python lambda 函数处理样本 | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | 
[tests](../tests/ops/mapper/test_python_lambda_mapper.py) | -| query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的情感强度标签(-1 表示“负面”,0 表示“中立”,1 表示“正面”)。 | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | +| query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的情感强度标签('negative'、'neutral'和'positive')。 | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | | relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 识别一段文本中两个实体之间的关系 | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | diff --git a/tests/ops/mapper/test_query_sentiment_detection_mapper.py b/tests/ops/mapper/test_query_sentiment_detection_mapper.py index 983a69cf0..dcc29e25c 100644 --- a/tests/ops/mapper/test_query_sentiment_detection_mapper.py +++ b/tests/ops/mapper/test_query_sentiment_detection_mapper.py @@ -12,16 +12,16 @@ class TestQuerySentimentDetectionMapper(DataJuicerTestCaseBase): - hf_model = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' - zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' + hf_model = '/mnt/workspace/shared/checkpoints/huggingface/mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' + zh_to_en_hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Helsinki-NLP/opus-mt-zh-en' - def _run_op(self, op, samples, intensity_key, targets): + def _run_op(self, op, samples, label_key, targets): dataset = Dataset.from_list(samples) dataset = dataset.map(op.process, batch_size=2) for sample, target in zip(dataset, targets): - intensity = nested_access(sample[Fields.meta], intensity_key) - self.assertEqual(intensity, target) + label = nested_access(sample[Fields.meta], label_key) + self.assertEqual(label, target) def test_default(self): @@ -33,7 +33,7 @@ def test_default(self): 
'query': '没有希望。' }, ] - targets = [1, 0, -1] + targets = ['positive', 'neutral', 'negative'] op = QuerySentimentDetectionMapper( hf_model = self.hf_model, @@ -49,55 +49,14 @@ def test_no_zh_to_en(self): 'query': 'That is great!' } ] - targets = [0, 1] + targets = ['neutral', 'positive'] op = QuerySentimentDetectionMapper( hf_model = self.hf_model, zh_to_en_hf_model = None, ) self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) - - def test_reset_map1(self): - - samples = [{ - 'query': '太棒了!' - },{ - 'query': '嗯嗯' - },{ - 'query': '没有希望。' - }, - ] - targets = [2, 0, -2] - - op = QuerySentimentDetectionMapper( - hf_model = self.hf_model, - zh_to_en_hf_model = self.zh_to_en_hf_model, - label_to_intensity = { - 'negative': -2, - 'neutral': 0, - 'positive': 2, - } - ) - self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) - def test_reset_map2(self): - - samples = [{ - 'query': '太棒了!' - },{ - 'query': '嗯嗯' - },{ - 'query': '没有希望。' - }, - ] - targets = ['positive', 'neutral', 'negative'] - - op = QuerySentimentDetectionMapper( - hf_model = self.hf_model, - zh_to_en_hf_model = self.zh_to_en_hf_model, - label_to_intensity = {} - ) - self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) if __name__ == '__main__': unittest.main() From c749dcd3885f7cc6b52e8c16d69672617f8c1cc8 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 15:06:00 +0800 Subject: [PATCH 082/118] change model --- .../ops/mapper/query_intent_detection_mapper.py | 3 ++- .../ops/mapper/test_query_intent_detection_mapper.py | 12 ++++++------ .../mapper/test_query_sentiment_detection_mapper.py | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/data_juicer/ops/mapper/query_intent_detection_mapper.py b/data_juicer/ops/mapper/query_intent_detection_mapper.py index 66290532d..badc7f49c 100644 --- a/data_juicer/ops/mapper/query_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/query_intent_detection_mapper.py @@ -25,7 +25,8 @@ class QueryIntentDetectionMapper(Mapper): def __init__( self, - hf_model: str = 'Falconsai/intent_classification', + hf_model: + str = 'bespin-global/klue-roberta-small-3i4k-intent-classification', # noqa: E501 E131 zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', model_params: Dict = {}, zh_to_en_model_params: Dict = {}, diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py index 5fac5ffc9..592d3d0c0 100644 --- a/tests/ops/mapper/test_query_intent_detection_mapper.py +++ b/tests/ops/mapper/test_query_intent_detection_mapper.py @@ -12,7 +12,7 @@ class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase): - hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Falconsai/intent_classification' + hf_model = '/mnt/workspace/shared/checkpoints/huggingface/bespin-global/klue-roberta-small-3i4k-intent-classification' zh_to_en_hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Helsinki-NLP/opus-mt-zh-en' def _run_op(self, op, samples, intensity_key, targets): @@ -27,11 +27,11 @@ def _run_op(self, op, samples, intensity_key, targets): def test_default(self): samples = [{ - 'query': '我要一个汉堡。' + 'query': '这样好吗?' },{ - 'query': '你最近过得怎么样?' + 'query': '把那只笔递给我。' },{ - 'query': '它是正方形的。' + 'query': '难道不是这样的吗?' } ] targets = [1, 0, -1] @@ -45,9 +45,9 @@ def test_default(self): def test_no_zh_to_en(self): samples = [{ - 'query': '它是正方形的。' + 'query': '这样好吗?' },{ - 'query': 'It is square.' + 'query': 'Is this okay?' 
} ] targets = [0, 1] diff --git a/tests/ops/mapper/test_query_sentiment_detection_mapper.py b/tests/ops/mapper/test_query_sentiment_detection_mapper.py index dcc29e25c..62ed0f380 100644 --- a/tests/ops/mapper/test_query_sentiment_detection_mapper.py +++ b/tests/ops/mapper/test_query_sentiment_detection_mapper.py @@ -12,8 +12,8 @@ class TestQuerySentimentDetectionMapper(DataJuicerTestCaseBase): - hf_model = '/mnt/workspace/shared/checkpoints/huggingface/mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' - zh_to_en_hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Helsinki-NLP/opus-mt-zh-en' + hf_model = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' + zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' def _run_op(self, op, samples, label_key, targets): dataset = Dataset.from_list(samples) From c7df0bca1f61f378f501a110553c1471a965e71a Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 15:11:15 +0800 Subject: [PATCH 083/118] fix typo --- tests/ops/mapper/test_query_intent_detection_mapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py index 592d3d0c0..f1b11e976 100644 --- a/tests/ops/mapper/test_query_intent_detection_mapper.py +++ b/tests/ops/mapper/test_query_intent_detection_mapper.py @@ -40,7 +40,7 @@ def test_default(self): hf_model = self.hf_model, zh_to_en_hf_model = self.zh_to_en_hf_model, ) - self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) + self._run_op(op, samples, MetaKeys.query_intent_label, targets) def test_no_zh_to_en(self): @@ -56,7 +56,7 @@ def test_no_zh_to_en(self): hf_model = self.hf_model, zh_to_en_hf_model = None, ) - self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) + self._run_op(op, samples, MetaKeys.query_intent_label, targets) if __name__ == '__main__': unittest.main() From c7662cbf129415dcafa877bd84a8c06a1a3e7bdb Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 15:14:17 +0800 Subject: [PATCH 084/118] fix typo --- tests/ops/mapper/test_query_intent_detection_mapper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py index f1b11e976..8033c4e2d 100644 --- a/tests/ops/mapper/test_query_intent_detection_mapper.py +++ b/tests/ops/mapper/test_query_intent_detection_mapper.py @@ -15,14 +15,14 @@ class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase): hf_model = '/mnt/workspace/shared/checkpoints/huggingface/bespin-global/klue-roberta-small-3i4k-intent-classification' zh_to_en_hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Helsinki-NLP/opus-mt-zh-en' - def _run_op(self, op, samples, intensity_key, targets): + def _run_op(self, op, samples, label_key, targets): dataset = Dataset.from_list(samples) dataset = dataset.map(op.process, batch_size=2) for sample, target in zip(dataset, targets): - intensity = nested_access(sample[Fields.meta], intensity_key) - print(intensity) - # self.assertEqual(intensity, target) + label = nested_access(sample[Fields.meta], label_key) + print(label) + # self.assertEqual(label, target) def test_default(self): From 6f44ec0ad90baf5c85b6afcbe55dadf04b22bb54 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 15:22:05 +0800 Subject: [PATCH 085/118] for test --- 
data_juicer/ops/mapper/query_intent_detection_mapper.py | 1 + tests/ops/mapper/test_query_intent_detection_mapper.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/data_juicer/ops/mapper/query_intent_detection_mapper.py b/data_juicer/ops/mapper/query_intent_detection_mapper.py index badc7f49c..4c24ebd21 100644 --- a/data_juicer/ops/mapper/query_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/query_intent_detection_mapper.py @@ -76,6 +76,7 @@ def process_batched(self, samples, rank=None): self.use_cuda()) results = translater(queries) queries = [item['translation_text'] for item in results] + print(queries) classifier, _ = get_model(self.model_key, rank, self.use_cuda()) results = classifier(queries) diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py index 8033c4e2d..2273c1a0b 100644 --- a/tests/ops/mapper/test_query_intent_detection_mapper.py +++ b/tests/ops/mapper/test_query_intent_detection_mapper.py @@ -29,9 +29,9 @@ def test_default(self): samples = [{ 'query': '这样好吗?' },{ - 'query': '把那只笔递给我。' + 'query': '站住!' },{ - 'query': '难道不是这样的吗?' + 'query': '今天阳光灿烂。' } ] targets = [1, 0, -1] From 9b6652ddadca3e65ba45d919cb2333dcc847ff48 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 15:25:04 +0800 Subject: [PATCH 086/118] for test --- data_juicer/ops/mapper/query_intent_detection_mapper.py | 1 - tests/ops/mapper/test_query_intent_detection_mapper.py | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/data_juicer/ops/mapper/query_intent_detection_mapper.py b/data_juicer/ops/mapper/query_intent_detection_mapper.py index 4c24ebd21..badc7f49c 100644 --- a/data_juicer/ops/mapper/query_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/query_intent_detection_mapper.py @@ -76,7 +76,6 @@ def process_batched(self, samples, rank=None): self.use_cuda()) results = translater(queries) queries = [item['translation_text'] for item in results] - print(queries) classifier, _ = get_model(self.model_key, rank, self.use_cuda()) results = classifier(queries) diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py index 2273c1a0b..3e15755c4 100644 --- a/tests/ops/mapper/test_query_intent_detection_mapper.py +++ b/tests/ops/mapper/test_query_intent_detection_mapper.py @@ -21,8 +21,7 @@ def _run_op(self, op, samples, label_key, targets): for sample, target in zip(dataset, targets): label = nested_access(sample[Fields.meta], label_key) - print(label) - # self.assertEqual(label, target) + self.assertEqual(label, target) def test_default(self): @@ -34,7 +33,7 @@ def test_default(self): 'query': '今天阳光灿烂。' } ] - targets = [1, 0, -1] + targets = ['question', 'command', 'statement'] op = QueryIntentDetectionMapper( hf_model = self.hf_model, @@ -50,7 +49,7 @@ def test_no_zh_to_en(self): 'query': 'Is this okay?' 
}
         ]
-        targets = [0, 1]
+        targets = ['question', 'rhetorical question']
 
         op = QueryIntentDetectionMapper(
             hf_model = self.hf_model,

From fa306dc90df6c3c4cdbd133c3ddc755f7a3cb3a1 Mon Sep 17 00:00:00 2001
From: Haibin <1400012807@pku.edu.cn>
Date: Tue, 17 Dec 2024 16:32:08 +0800
Subject: [PATCH 087/118] doc done

---
 configs/config_all.yaml                       | 44 +++++++++++++++++--
 .../mapper/query_intent_detection_mapper.py   | 21 ++-------
 .../query_sentiment_detection_mapper.py       |  6 ---
 data_juicer/utils/auto_install_mapping.py     |  3 ++
 docs/Operators.md                             |  7 ++-
 docs/Operators_ZH.md                          |  5 ++-
 .../test_query_intent_detection_mapper.py     |  4 +-
 7 files changed, 57 insertions(+), 33 deletions(-)

diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 457baf27f..74600663d 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -77,11 +77,43 @@ process:
   - clean_ip_mapper:  # remove ip addresses from text.
   - clean_links_mapper:  # remove web links from text.
   - clean_copyright_mapper:  # remove copyright comments.
+  - dialog_intent_detection_mapper:  # Mapper to generate user's intent labels in dialog.
+      api_model: 'gpt-4o'  # API model name.
+      intent_candidates: null  # The output intent candidates. Use the intent labels of the open domain if it is None.
+      max_round: 10  # The max number of rounds in the dialog to build the prompt.
+      api_endpoint: null  # URL endpoint for the API.
+      response_path: null  # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+      system_prompt: null  # System prompt for the task.
+      query_template: null  # Template for query part to build the input prompt.
+      response_template: null  # Template for response part to build the input prompt.
+      candidate_template: null  # Template for intent candidates to build the input prompt.
+      analysis_template: null  # Template for analysis part to build the input prompt.
+      labels_template: null  # Template for labels to build the input prompt.
+      analysis_pattern: null  # Pattern to parse the returned intent analysis.
+      labels_pattern: null  # Pattern to parse the returned intent labels.
+      try_num: 3  # The number of retry attempts when there is an API call error or output parsing error.
+      model_params: {}  # Parameters for initializing the API model.
+      sampling_params: {}  # Extra parameters passed to the API call. e.g. {'temperature': 0.9, 'top_p': 0.95}
+  - dialog_sentiment_detection_mapper:  # Mapper to generate user's sentiment labels in dialog.
+      api_model: 'gpt-4o'  # API model name.
+      max_round: 10  # The max number of rounds in the dialog to build the prompt.
+      api_endpoint: null  # URL endpoint for the API.
+      response_path: null  # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+      system_prompt: null  # System prompt for the task.
+      query_template: null  # Template for query part to build the input prompt.
+      response_template: null  # Template for response part to build the input prompt.
+      analysis_template: null  # Template for analysis part to build the input prompt.
+      labels_template: null  # Template for labels part to build the input prompt.
+      analysis_pattern: null  # Pattern to parse the returned sentiment analysis.
+      labels_pattern: null  # Pattern to parse the returned sentiment labels.
+      try_num: 3  # The number of retry attempts when there is an API call error or output parsing error.
+      model_params: {}  # Parameters for initializing the API model.
+      sampling_params: {}  # Extra parameters passed to the API call. 
e.g. {'temperature': 0.9, 'top_p': 0.95}
   - dialog_sentiment_intensity_mapper:  # Mapper to predict user's sentiment intensity (from -5 to 5 in default prompt) in dialog.
       api_model: 'gpt-4o'  # API model name.
       max_round: 10  # The max number of rounds in the dialog to build the prompt.
       api_endpoint: null  # URL endpoint for the API.
-      esponse_path: null  # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+      response_path: null  # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
       system_prompt: null  # System prompt for the task.
       query_template: null  # Template for query part to build the input prompt.
       response_template: null  # Template for response part to build the input prompt.
@@ -292,12 +324,16 @@ process:
   - python_lambda_mapper:  # executing Python lambda function on data samples.
       lambda_str: ''  # A string representation of the lambda function to be executed on data samples. If empty, the identity function is used.
       batched: False  # A boolean indicating whether to process input data in batches.
-  - query_sentiment_detection_mapper:  # Mapper to predict user's sentiment intensity label ('negative', 'neutral' and 'positive') in query.
-      hf_model: 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis'  # Huggingface model ID to predict sentiment intensity.
+  - query_intent_detection_mapper:  # Mapper to predict user's intent label in query.
+      hf_model: 'bespin-global/klue-roberta-small-3i4k-intent-classification'  # Huggingface model ID to predict intent label.
+      zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en'  # Translation model from Chinese to English. If not None, translate the query from Chinese to English.
+      model_params: {}  # model param for hf_model.
+      zh_to_en_model_params: {}  # model param for zh_to_en_hf_model.
+  - query_sentiment_detection_mapper:  # Mapper to predict user's sentiment label ('negative', 'neutral' and 'positive') in query.
+      hf_model: 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis'  # Huggingface model ID to predict sentiment label.
       zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en'  # Translation model from Chinese to English. If not None, translate the query from Chinese to English.
       model_params: {}  # model param for hf_model.
       zh_to_en_model_params: {}  # model param for zh_to_en_hf_model.
-      label_to_intensity: null  # Map the output labels to the intensities instead of the default mapping if not None.
   - relation_identity_mapper:  # identify the relation between two entities in the text.
       api_model: 'gpt-4o'  # API model name.
       source_entity: '孙悟空'  # The source entity of the relation to be identified.
diff --git a/data_juicer/ops/mapper/query_intent_detection_mapper.py b/data_juicer/ops/mapper/query_intent_detection_mapper.py
index badc7f49c..e4d44aa1d 100644
--- a/data_juicer/ops/mapper/query_intent_detection_mapper.py
+++ b/data_juicer/ops/mapper/query_intent_detection_mapper.py
@@ -13,7 +13,7 @@ class QueryIntentDetectionMapper(Mapper):
     """
     Mapper to predict user's intent label in query. Input from query_key.
-    Output intensity label and corresponding score for the query, which are
+    Output intent label and corresponding score for the query, which are
     stored in 'intent.query_label' and 'intent.query_label_score' in
     Data-Juicer meta field.
""" @@ -21,8 +21,6 @@ class QueryIntentDetectionMapper(Mapper): _accelerator = 'cuda' _batched_op = True - DEFAULT_LABEL_TO_INTENSITY = {} - def __init__( self, hf_model: @@ -30,19 +28,15 @@ def __init__( zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', model_params: Dict = {}, zh_to_en_model_params: Dict = {}, - *, - label_to_intensity: Dict = None, **kwargs): """ Initialization method. - :param hf_model: Hugginface model ID to predict sentiment intensity. + :param hf_model: Hugginface model ID to predict intent label. :param zh_to_en_hf_model: Translation model from Chinese to English. If not None, translate the query from Chinese to English. :param model_params: model param for hf_model. :param zh_to_en_model_params: model param for zh_to_hf_model. - :param label_to_intensity: Map the output labels to the intensities - instead of the default mapper if not None. :param kwargs: Extra keyword arguments. """ super().__init__(**kwargs) @@ -63,11 +57,6 @@ def __init__( else: self.zh_to_en_model_key = None - if label_to_intensity is not None: - self.label_to_intensity = label_to_intensity - else: - self.label_to_intensity = self.DEFAULT_LABEL_TO_INTENSITY - def process_batched(self, samples, rank=None): queries = samples[self.query_key] @@ -79,11 +68,7 @@ def process_batched(self, samples, rank=None): classifier, _ = get_model(self.model_key, rank, self.use_cuda()) results = classifier(queries) - intensities = [ - self.label_to_intensity[r['label']] - if r['label'] in self.label_to_intensity else r['label'] - for r in results - ] + intensities = [r['label'] for r in results] scores = [r['score'] for r in results] if Fields.meta not in samples: diff --git a/data_juicer/ops/mapper/query_sentiment_detection_mapper.py b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py index c8b22a9c6..ef49a344a 100644 --- a/data_juicer/ops/mapper/query_sentiment_detection_mapper.py +++ b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py @@ -22,12 +22,6 @@ class QuerySentimentDetectionMapper(Mapper): _accelerator = 'cuda' _batched_op = True - DEFAULT_LABEL_TO_INTENSITY = { - 'negative': -1, - 'neutral': 0, - 'positive': 1, - } - def __init__( self, hf_model: diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py index 148c30089..6fea1fef8 100644 --- a/data_juicer/utils/auto_install_mapping.py +++ b/data_juicer/utils/auto_install_mapping.py @@ -80,6 +80,9 @@ 'video_tagging_from_frames_mapper': ['ram', 'torch'], 'text_entity_dependency_filter': ['spacy-pkuseg'], 'optimize_response_mapper': ['torch', 'transformers', 'vllm'], + 'dialog_intent_detection_mapper': ['openai'], + 'dialog_sentiment_detection_mapper': ['openai'], 'dialog_sentiment_intensity_mapper': ['openai'], + 'query_intent_detection_mapper': ['transformers'], 'query_sentiment_detection_mapper': ['transformers'], } diff --git a/docs/Operators.md b/docs/Operators.md index ccc58578b..2ad6251f3 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. 
| Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 65 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 68 | Edits and transforms samples | | [ Filter ]( #filter ) | 44 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -68,6 +68,8 @@ All the specific operators are listed below, each featured with several capabili | clean_html_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes HTML tags and returns plain text of all the nodes | [code](../data_juicer/ops/mapper/clean_html_mapper.py) | [tests](../tests/ops/mapper/test_clean_html_mapper.py) | | clean_ip_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes IP addresses | [code](../data_juicer/ops/mapper/clean_ip_mapper.py) | [tests](../tests/ops/mapper/test_clean_ip_mapper.py) | | clean_links_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes links, such as those starting with http or ftp | [code](../data_juicer/ops/mapper/clean_links_mapper.py) | [tests](../tests/ops/mapper/test_clean_links_mapper.py) | +| dialog_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to generate user's intent labels in dialog. | [code](../data_juicer/ops/mapper/dialog_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_intent_detection_mapper.py) | +| dialog_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to generate user's sentiment labels in dialog. | [code](../data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_detection_mapper.py) | | dialog_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment intensity (from -5 to 5 in default prompt) in dialog. 
| [code](../data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py) | | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Expands macros usually defined at the top of TeX documents | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | | extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract attributes for given entities from the text. | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | @@ -94,7 +96,8 @@ All the specific operators are listed below, each featured with several capabili | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python function defined in a file | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | | python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python lambda function on data samples | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | -| query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment intensity label ('negative', 'neutral' and 'positive') in query. | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | +| query_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's Intent label in query. 
| [code](../data_juicer/ops/mapper/query_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_intent_detection_mapper.py) | +| query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment label ('negative', 'neutral' and 'positive') in query. | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | | relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Identify relation between two entity in the text. | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 2b3ff3115..a68ad56e8 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 65 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 68 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 44 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -67,6 +67,8 @@ Data-Juicer 中的算子分为以下 5 种类型。 | clean_html_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 HTML 标签并返回所有节点的纯文本 | [code](../data_juicer/ops/mapper/clean_html_mapper.py) | [tests](../tests/ops/mapper/test_clean_html_mapper.py) | | clean_ip_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 IP 地址 | [code](../data_juicer/ops/mapper/clean_ip_mapper.py) | [tests](../tests/ops/mapper/test_clean_ip_mapper.py) | | clean_links_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) 
![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除链接,例如以 http 或 ftp 开头的 | [code](../data_juicer/ops/mapper/clean_links_mapper.py) | [tests](../tests/ops/mapper/test_clean_links_mapper.py) | +| dialog_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取对话中的用户意图标签。 | [code](../data_juicer/ops/mapper/dialog_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_intent_detection_mapper.py) | +| dialog_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取对话中用户的情感标签 | [code](../data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_detection_mapper.py) | | dialog_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测对话中的情绪强度(默认从-5到5)。 | [code](../data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py) | | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 扩展通常在 TeX 文档顶部定义的宏 | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | | extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 给定主体和属性名,从文本中抽取主体的属性 | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | @@ -93,6 +95,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行文件中定义的 Python 函数处理样本 | [code](../data_juicer/ops/mapper/python_file_mapper.py) | 
[tests](../tests/ops/mapper/test_python_file_mapper.py) | | python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行 Python lambda 函数处理样本 | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | +| query_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的意图标签。 | [code](../data_juicer/ops/mapper/query_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_intent_detection_mapper.py) | | query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的情感强度标签('negative'、'neutral'和'positive')。 | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | | relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 识别一段文本中两个实体之间的关系 | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py index 3e15755c4..92d0346a4 100644 --- a/tests/ops/mapper/test_query_intent_detection_mapper.py +++ b/tests/ops/mapper/test_query_intent_detection_mapper.py @@ -12,8 +12,8 @@ class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase): - hf_model = '/mnt/workspace/shared/checkpoints/huggingface/bespin-global/klue-roberta-small-3i4k-intent-classification' - zh_to_en_hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Helsinki-NLP/opus-mt-zh-en' + hf_model = 'bespin-global/klue-roberta-small-3i4k-intent-classification' + zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' def _run_op(self, op, samples, label_key, targets): dataset = Dataset.from_list(samples) From 767b2f0dbe0ce65685cdedf03c5442b96c6fdd5f Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 18 Dec 2024 15:38:20 +0800 Subject: [PATCH 088/118] dialog topic detection --- data_juicer/ops/mapper/__init__.py | 8 +- .../mapper/dialog_intent_detection_mapper.py | 2 +- .../dialog_sentiment_detection_mapper.py | 2 +- .../mapper/dialog_topic_detection_mapper.py | 208 ++++++++++++++++++ .../mapper/query_intent_detection_mapper.py | 6 +- 
.../test_dialog_intent_detection_mapper.py | 2 +- ...test_dialog_topic_detection_mapper copy.py | 170 ++++++++++++++ 7 files changed, 389 insertions(+), 9 deletions(-) create mode 100644 data_juicer/ops/mapper/dialog_topic_detection_mapper.py create mode 100644 tests/ops/mapper/test_dialog_topic_detection_mapper copy.py diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index af710bd83..a6efbd737 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -11,6 +11,7 @@ from .dialog_intent_detection_mapper import DialogIntentDetectionMapper from .dialog_sentiment_detection_mapper import DialogSentimentDetectionMapper from .dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper +from .dialog_topic_detection_mapper import DialogTopicDetectionMapper from .expand_macro_mapper import ExpandMacroMapper from .extract_entity_attribute_mapper import ExtractEntityAttributeMapper from .extract_entity_relation_mapper import ExtractEntityRelationMapper @@ -77,9 +78,10 @@ 'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper', 'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper', 'DialogIntentDetectionMapper', 'DialogSentimentDetectionMapper', - 'DialogSentimentIntensityMapper', 'ExpandMacroMapper', - 'ExtractEntityAttributeMapper', 'ExtractEntityRelationMapper', - 'ExtractEventMapper', 'ExtractKeywordMapper', 'ExtractNicknameMapper', + 'DialogSentimentIntensityMapper', 'DialogTopicDetectionMapper', + 'ExpandMacroMapper', 'ExtractEntityAttributeMapper', + 'ExtractEntityRelationMapper', 'ExtractEventMapper', + 'ExtractKeywordMapper', 'ExtractNicknameMapper', 'ExtractSupportTextMapper', 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper', 'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper', diff --git a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py index 759e291b4..4e40c86f4 100644 --- a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py @@ -18,7 +18,7 @@ class DialogIntentDetectionMapper(Mapper): """ Mapper to generate user's intent labels in dialog. Input from history_key, query_key and response_key. Output lists of - intensities and analysis for queries in the dialog, which is + labels and analysis for queries in the dialog, which is store in 'intent.dialog_labels' and 'intent.dialog_labels_analysis' in Data-Juicer meta field. """ diff --git a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py index 237c217d5..5417e5974 100644 --- a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py @@ -18,7 +18,7 @@ class DialogSentimentDetectionMapper(Mapper): """ Mapper to generate user's sentiment labels in dialog. Input from history_key, query_key and response_key. Output lists of - intensities and analysis for queries in the dialog, which is + labels and analysis for queries in the dialog, which is store in 'sentiment.dialog_labels' and 'sentiment.dialog_labels_analysis' in Data-Juicer meta field. 
""" diff --git a/data_juicer/ops/mapper/dialog_topic_detection_mapper.py b/data_juicer/ops/mapper/dialog_topic_detection_mapper.py new file mode 100644 index 000000000..c1f5fa125 --- /dev/null +++ b/data_juicer/ops/mapper/dialog_topic_detection_mapper.py @@ -0,0 +1,208 @@ +import re +from typing import Dict, List, Optional + +from loguru import logger +from pydantic import NonNegativeInt, PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'dialog_topic_detection_mapper' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class DialogTopicDetectionMapper(Mapper): + """ + Mapper to generate user's topic labels in dialog. Input from + history_key, query_key and response_key. Output lists of + labels and analysis for queries in the dialog, which is + store in 'sentiment.dialog_labels' and + 'sentiment.dialog_labels_analysis' in Data-Juicer meta field. + """ + + DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户所讨论的话题。\n' + '要求:\n' + '- 针对用户的每个query,需要先进行分析,然后列出用户正在讨论的话题,下面是' + '一个样例,请模仿样例格式输出。\n' + '用户:你好,今天我们来聊聊秦始皇吧。\n' + '话题分析:用户提到秦始皇,这是中国历史上第一位皇帝。\n' + '话题类别:历史\n' + 'LLM:当然可以,秦始皇是中国历史上第一个统一全国的皇帝,他在公元前221年建' + '立了秦朝,并采取了一系列重要的改革措施,如统一文字、度量衡和货币等。\n' + '用户:秦始皇修建的长城和现在的长城有什么区别?\n' + '话题分析:用户提到秦始皇修建的长城,并将其与现代长城进行比较,涉及建筑历史' + '和地理位置。\n' + '话题类别:历史' + 'LLM:秦始皇时期修建的长城主要是为了抵御北方游牧民族的入侵,它的规模和修建' + '技术相对较为简陋。现代人所看到的长城大部分是明朝时期修建和扩建的,明长城不' + '仅规模更大、结构更坚固,而且保存得比较完好。\n' + '用户:有意思,那么长城的具体位置在哪些省份呢?\n' + '话题分析:用户询问长城的具体位置,涉及到地理知识。\n' + '话题类别:地理\n' + 'LLM:长城横跨中国北方多个省份,主要包括河北、山西、内蒙古、宁夏、陕西、甘' + '肃和北京等。每一段长城都建在关键的战略位置,以便最大限度地发挥其防御作用' + '。\n') + DEFAULT_QUERY_TEMPLATE = '用户:{query}\n' + DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n' + DEFAULT_ANALYSIS_TEMPLATE = '话题分析:{analysis}\n' + DEFAULT_LABELS_TEMPLATE = '话题类别:{labels}\n' + DEFAULT_ANALYSIS_PATTERN = '话题分析:(.*?)\n' + DEFAULT_LABELS_PATTERN = '话题类别:(.*?)($|\n)' + + def __init__(self, + api_model: str = 'gpt-4o', + topic_candidates: Optional[List[str]] = None, + max_round: NonNegativeInt = 10, + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + query_template: Optional[str] = None, + response_template: Optional[str] = None, + analysis_template: Optional[str] = None, + labels_template: Optional[str] = None, + analysis_pattern: Optional[str] = None, + labels_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param api_model: API model name. + :param topic_candidates: The output intent candidates. Use the + intent labels of the open domain if it is None. + :param max_round: The max num of round in the dialog to build the + prompt. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: System prompt for the task. + :param query_template: Template for query part to build the input + prompt. + :param response_template: Template for response part to build the + input prompt. + :param analysis_template: Template for analysis part to build the + input prompt. + :param labels_template: Template for labels part to build the + input prompt. + :param analysis_pattern: Pattern to parse the return sentiment + analysis. 
+        :param labels_pattern: Pattern to parse the returned topic
+            labels.
+        :param try_num: The number of retry attempts when there is an API
+            call error or output parsing error.
+        :param model_params: Parameters for initializing the API model.
+        :param sampling_params: Extra parameters passed to the API call.
+            e.g. {'temperature': 0.9, 'top_p': 0.95}
+        :param kwargs: Extra keyword arguments.
+        """
+        super().__init__(**kwargs)
+
+        self.topic_candidates = topic_candidates
+        self.max_round = max_round
+
+        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+        self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE
+        self.response_template = response_template or \
+            self.DEFAULT_RESPONSE_TEMPLATE
+        self.analysis_template = analysis_template or \
+            self.DEFAULT_ANALYSIS_TEMPLATE
+        self.labels_template = labels_template or \
+            self.DEFAULT_LABELS_TEMPLATE
+        self.analysis_pattern = analysis_pattern or \
+            self.DEFAULT_ANALYSIS_PATTERN
+        self.labels_pattern = labels_pattern or \
+            self.DEFAULT_LABELS_PATTERN
+
+        self.sampling_params = sampling_params
+
+        self.model_key = prepare_model(model_type='api',
+                                       model=api_model,
+                                       endpoint=api_endpoint,
+                                       response_path=response_path,
+                                       **model_params)
+
+        self.try_num = try_num
+
+    def build_input(self, history, query):
+
+        if self.topic_candidates:
+            input_prompt = '备选话题类别:[{}]\n'.format(
+                ','.join(self.topic_candidates))
+        else:
+            input_prompt = ''
+
+        if self.max_round > 0:
+            input_prompt += ''.join(history[-self.max_round * 4:])
+
+        input_prompt += self.query_template.format(query=query[0])
+
+        return input_prompt
+
+    def parse_output(self, response):
+        analysis = ''
+        labels = ''
+
+        match = re.search(self.analysis_pattern, response)
+        if match:
+            analysis = match.group(1)
+
+        match = re.search(self.labels_pattern, response)
+        if match:
+            labels = match.group(1)
+
+        return analysis, labels
+
+    def process_single(self, sample, rank=None):
+        client = get_model(self.model_key, rank=rank)
+
+        analysis_list = []
+        labels_list = []
+        history = []
+
+        dialog = sample[self.history_key]
+        if self.query_key in sample and sample[self.query_key]:
+            if self.response_key in sample and sample[self.response_key]:
+                dialog.append(
+                    (sample[self.query_key], sample[self.response_key]))
+            else:
+                dialog.append((sample[self.query_key], ''))
+
+        for qa in dialog:
+            input_prompt = self.build_input(history, qa)
+            messages = [{
+                'role': 'system',
+                'content': self.system_prompt,
+            }, {
+                'role': 'user',
+                'content': input_prompt,
+            }]
+
+            for _ in range(self.try_num):
+                try:
+                    response = client(messages, **self.sampling_params)
+                    analysis, labels = self.parse_output(response)
+                    if len(analysis) > 0:
+                        break
+                except Exception as e:
+                    logger.warning(f'Exception: {e}')
+
+            analysis_list.append(analysis)
+            labels_list.append(labels)
+
+            history.append(self.query_template.format(query=qa[0]))
+            history.append(self.analysis_template.format(analysis=analysis))
+            history.append(self.labels_template.format(labels=labels))
+            history.append(self.response_template.format(response=qa[1]))
+
+        analysis_key = f'{Fields.meta}.{MetaKeys.dialog_topic_labels_analysis}'  # noqa: E501
+        sample = nested_set(sample, analysis_key, analysis_list)
+        labels_key = f'{Fields.meta}.{MetaKeys.dialog_topic_labels}'
+        sample = nested_set(sample, labels_key, labels_list)
+
+        return sample
diff --git a/data_juicer/ops/mapper/query_intent_detection_mapper.py b/data_juicer/ops/mapper/query_intent_detection_mapper.py
index e4d44aa1d..b6e69f610 100644
---
a/data_juicer/ops/mapper/query_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/query_intent_detection_mapper.py @@ -68,15 +68,15 @@ def process_batched(self, samples, rank=None): classifier, _ = get_model(self.model_key, rank, self.use_cuda()) results = classifier(queries) - intensities = [r['label'] for r in results] + labels = [r['label'] for r in results] scores = [r['score'] for r in results] if Fields.meta not in samples: - samples[Fields.meta] = [{} for val in intensities] + samples[Fields.meta] = [{} for val in labels] for i in range(len(samples[Fields.meta])): samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], MetaKeys.query_intent_label, - intensities[i]) + labels[i]) samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], MetaKeys.query_intent_score, scores[i]) diff --git a/tests/ops/mapper/test_dialog_intent_detection_mapper.py b/tests/ops/mapper/test_dialog_intent_detection_mapper.py index e7db2cea6..bc3a18752 100644 --- a/tests/ops/mapper/test_dialog_intent_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_intent_detection_mapper.py @@ -26,7 +26,7 @@ def _run_op(self, op, samples, target_len): for analysis, labels in zip(analysis_list, labels_list): logger.info(f'分析:{analysis}') - logger.info(f'情绪:{labels}') + logger.info(f'意图:{labels}') self.assertEqual(len(analysis_list), target_len) self.assertEqual(len(labels_list), target_len) diff --git a/tests/ops/mapper/test_dialog_topic_detection_mapper copy.py b/tests/ops/mapper/test_dialog_topic_detection_mapper copy.py new file mode 100644 index 000000000..22153945b --- /dev/null +++ b/tests/ops/mapper/test_dialog_topic_detection_mapper copy.py @@ -0,0 +1,170 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.dialog_intent_detection_mapper import DialogTopicDetectionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +# Skip tests for this OP. +# These tests have been tested locally. 
+@SKIPPED_TESTS.register_module() +class TestDialogTopicDetectionMapper(DataJuicerTestCaseBase): + # before runing this test, set below environment variables: + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key + + def _run_op(self, op, samples, target_len): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_intent_labels_analysis) + labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_intent_labels) + + for analysis, labels in zip(analysis_list, labels_list): + logger.info(f'分析:{analysis}') + logger.info(f'话题:{labels}') + + self.assertEqual(len(analysis_list), target_len) + self.assertEqual(len(labels_list), target_len) + + def test_default(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op, samples, 4) + + def test_max_round(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + def test_max_round_zero(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=0) + self._run_op(op, samples, 4) + + def test_query(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ) + ], + 'query': '你在说什么我听不懂。', + 'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + }] + + op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + def test_topic_candidates(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogTopicDetectionMapper( + api_model='qwen2.5-72b-instruct', + topic_candidates=['评价', '讽刺', '表达困惑'] + ) + self._run_op(op, samples, 4) + + +if __name__ == '__main__': + unittest.main() From c088cb1aad91dd2c96c42cce8850e3d3ac1ce62b Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 18 Dec 2024 15:41:25 +0800 Subject: [PATCH 089/118] dialog topic detection --- ...ion_mapper copy.py => test_dialog_topic_detection_mapper.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/ops/mapper/{test_dialog_topic_detection_mapper copy.py => 
test_dialog_topic_detection_mapper.py} (98%) diff --git a/tests/ops/mapper/test_dialog_topic_detection_mapper copy.py b/tests/ops/mapper/test_dialog_topic_detection_mapper.py similarity index 98% rename from tests/ops/mapper/test_dialog_topic_detection_mapper copy.py rename to tests/ops/mapper/test_dialog_topic_detection_mapper.py index 22153945b..fccdc4390 100644 --- a/tests/ops/mapper/test_dialog_topic_detection_mapper copy.py +++ b/tests/ops/mapper/test_dialog_topic_detection_mapper.py @@ -4,7 +4,7 @@ from loguru import logger from data_juicer.core.data import NestedDataset as Dataset -from data_juicer.ops.mapper.dialog_intent_detection_mapper import DialogTopicDetectionMapper +from data_juicer.ops.mapper.dialog_topic_detection_mapper import DialogTopicDetectionMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields, MetaKeys From 12351db2d02ca2e0f87718ca212532b237ca7d94 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 18 Dec 2024 15:46:01 +0800 Subject: [PATCH 090/118] dialog topic detection --- data_juicer/utils/constant.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 5cc9c33c7..7b6c4e081 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -83,6 +83,9 @@ class MetaKeys(object): query_intent_label = 'intent.query_label' query_intent_score = 'intent.query_label_score' + dialog_topic_labels = 'topic.dialog_labels' + dialog_topic_labels_analysis = 'topic.dialog_labels_analysis' + class StatsKeysMeta(type): """ From 4b4e946dd5a699a9ef4907348d551a01e0c2a3de Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 18 Dec 2024 15:49:49 +0800 Subject: [PATCH 091/118] dialog topic detection --- tests/ops/mapper/test_dialog_topic_detection_mapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ops/mapper/test_dialog_topic_detection_mapper.py b/tests/ops/mapper/test_dialog_topic_detection_mapper.py index fccdc4390..b0c765382 100644 --- a/tests/ops/mapper/test_dialog_topic_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_topic_detection_mapper.py @@ -21,8 +21,8 @@ class TestDialogTopicDetectionMapper(DataJuicerTestCaseBase): def _run_op(self, op, samples, target_len): dataset = Dataset.from_list(samples) dataset = dataset.map(op.process, batch_size=2) - analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_intent_labels_analysis) - labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_intent_labels) + analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_topic_labels_analysis) + labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_topic_labels) for analysis, labels in zip(analysis_list, labels_list): logger.info(f'分析:{analysis}') From 4506a8eaf2ccac81951371cfe9521eceae35f0df Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 18 Dec 2024 15:54:53 +0800 Subject: [PATCH 092/118] dialog topic detection --- .../ops/mapper/dialog_topic_detection_mapper.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/data_juicer/ops/mapper/dialog_topic_detection_mapper.py b/data_juicer/ops/mapper/dialog_topic_detection_mapper.py index c1f5fa125..60db0f981 100644 --- a/data_juicer/ops/mapper/dialog_topic_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_topic_detection_mapper.py @@ -1,5 +1,5 @@ import re -from typing import Dict, 
List, Optional +from typing import Dict, Optional from loguru import logger from pydantic import NonNegativeInt, PositiveInt @@ -54,7 +54,6 @@ class DialogTopicDetectionMapper(Mapper): def __init__(self, api_model: str = 'gpt-4o', - topic_candidates: Optional[List[str]] = None, max_round: NonNegativeInt = 10, *, api_endpoint: Optional[str] = None, @@ -74,8 +73,6 @@ def __init__(self, Initialization method. :param api_model: API model name. - :param topic_candidates: The output intent candidates. Use the - intent labels of the open domain if it is None. :param max_round: The max num of round in the dialog to build the prompt. :param api_endpoint: URL endpoint for the API. @@ -103,7 +100,6 @@ def __init__(self, """ super().__init__(**kwargs) - self.topic_candidates = topic_candidates self.max_round = max_round self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT @@ -131,14 +127,10 @@ def __init__(self, def build_input(self, history, query): - if self.topic_candidates: - input_prompt = self.candidate_template.format( - candidate_str=','.join(self.topic_candidates)) - else: - input_prompt = '' - if self.max_round > 0: input_prompt = ''.join(history[-self.max_round * 4:]) + else: + input_prompt = '' input_prompt += self.query_template.format(query=query[0]) From d21db85d7a1ae62fb9f35e9f1e30997ba1deb226 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 18 Dec 2024 15:59:25 +0800 Subject: [PATCH 093/118] dialog topic detection --- .../test_dialog_topic_detection_mapper.py | 29 ------------------- 1 file changed, 29 deletions(-) diff --git a/tests/ops/mapper/test_dialog_topic_detection_mapper.py b/tests/ops/mapper/test_dialog_topic_detection_mapper.py index b0c765382..887e96bad 100644 --- a/tests/ops/mapper/test_dialog_topic_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_topic_detection_mapper.py @@ -136,35 +136,6 @@ def test_query(self): max_round=1) self._run_op(op, samples, 4) - def test_topic_candidates(self): - - samples = [{ - 'history': [ - ( - '李莲花有口皆碑', - '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' - ), - ( - '是的,你确实是一个普通大夫,没什么值得夸耀的。', - '「委屈」你这话说的,我也是尽心尽力治病救人了。' - ), - ( - '你自己说的呀,我现在说了,你又不高兴了。', - 'or of of of of or or and or of of of of of of of,,, ' - ), - ( - '你在说什么我听不懂。', - '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' - ) - ] - }] - - op = DialogTopicDetectionMapper( - api_model='qwen2.5-72b-instruct', - topic_candidates=['评价', '讽刺', '表达困惑'] - ) - self._run_op(op, samples, 4) - if __name__ == '__main__': unittest.main() From 6f394ee8bafed3ba1efc859a1994db93f1799621 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 18 Dec 2024 16:59:15 +0800 Subject: [PATCH 094/118] query topic detection --- data_juicer/ops/mapper/__init__.py | 30 +++---- .../mapper/query_topic_detection_mapper.py | 84 +++++++++++++++++++ .../test_query_topic_detection_mapper.py | 62 ++++++++++++++ 3 files changed, 162 insertions(+), 14 deletions(-) create mode 100644 data_juicer/ops/mapper/query_topic_detection_mapper.py create mode 100644 tests/ops/mapper/test_query_topic_detection_mapper.py diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index a6efbd737..8ffe7cc8e 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -39,6 +39,7 @@ from .python_lambda_mapper import PythonLambdaMapper from .query_intent_detection_mapper import QueryIntentDetectionMapper from .query_sentiment_detection_mapper import QuerySentimentDetectionMapper +from .query_topic_detection_mapper import QueryTopicDetectionMapper from 
.relation_identity_mapper import RelationIdentityMapper from .remove_bibliography_mapper import RemoveBibliographyMapper from .remove_comments_mapper import RemoveCommentsMapper @@ -90,18 +91,19 @@ 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', 'PairPreferenceMapper', 'PunctuationNormalizationMapper', 'PythonFileMapper', 'PythonLambdaMapper', 'QuerySentimentDetectionMapper', - 'QueryIntentDetectionMapper', 'RelationIdentityMapper', - 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', - 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', - 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper', - 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper', - 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper', - 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', - 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper', - 'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper', - 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', - 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper', - 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', - 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper', - 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper' + 'QueryIntentDetectionMapper', 'QueryTopicDetectionMapper', + 'RelationIdentityMapper', 'RemoveBibliographyMapper', + 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', + 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', + 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', + 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', + 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper', + 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', + 'VideoCaptioningFromVideoMapper', 'VideoExtractFramesMapper', + 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', + 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper', + 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper', + 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper', + 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper', + 'WhitespaceNormalizationMapper' ] diff --git a/data_juicer/ops/mapper/query_topic_detection_mapper.py b/data_juicer/ops/mapper/query_topic_detection_mapper.py new file mode 100644 index 000000000..577f00bad --- /dev/null +++ b/data_juicer/ops/mapper/query_topic_detection_mapper.py @@ -0,0 +1,84 @@ +from typing import Dict, Optional + +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'query_topic_detection_mapper' + + +@OPERATORS.register_module(OP_NAME) +class QueryTopicDetectionMapper(Mapper): + """ + Mapper to predict user's topic label in query. Input from query_key. + Output topic label and corresponding score for the query, which is + store in 'topic.query_label' and 'topic.query_label_score' in + Data-Juicer meta field. + """ + + _accelerator = 'cuda' + _batched_op = True + + def __init__( + self, + hf_model: + str = 'dstefa/roberta-base_topic_classification_nyt_news', # noqa: E501 E131 + zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', + model_params: Dict = {}, + zh_to_en_model_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param hf_model: Hugginface model ID to predict topic label. 
+ :param zh_to_en_hf_model: Translation model from Chinese to English. + If not None, translate the query from Chinese to English. + :param model_params: model param for hf_model. + :param zh_to_en_model_params: model param for zh_to_hf_model. + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_model, + return_pipe=True, + pipe_task='text-classification', + **model_params) + + if zh_to_en_hf_model is not None: + self.zh_to_en_model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=zh_to_en_hf_model, + return_pipe=True, + pipe_task='translation', + **zh_to_en_model_params) + else: + self.zh_to_en_model_key = None + + def process_batched(self, samples, rank=None): + queries = samples[self.query_key] + + if self.zh_to_en_model_key is not None: + translater, _ = get_model(self.zh_to_en_model_key, rank, + self.use_cuda()) + results = translater(queries) + queries = [item['translation_text'] for item in results] + + classifier, _ = get_model(self.model_key, rank, self.use_cuda()) + results = classifier(queries) + labels = [r['label'] for r in results] + scores = [r['score'] for r in results] + + if Fields.meta not in samples: + samples[Fields.meta] = [{} for val in labels] + for i in range(len(samples[Fields.meta])): + samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], + MetaKeys.query_topic_label, + labels[i]) + samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], + MetaKeys.query_topic_score, + scores[i]) + + return samples diff --git a/tests/ops/mapper/test_query_topic_detection_mapper.py b/tests/ops/mapper/test_query_topic_detection_mapper.py new file mode 100644 index 000000000..18064ca63 --- /dev/null +++ b/tests/ops/mapper/test_query_topic_detection_mapper.py @@ -0,0 +1,62 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.query_topic_detection_mapper import QueryTopicDetectionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +class TestQueryTopicDetectionMapper(DataJuicerTestCaseBase): + + hf_model = '/mnt/workspace/shared/checkpoints/huggingface/dstefa/roberta-base_topic_classification_nyt_news' + zh_to_en_hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Helsinki-NLP/opus-mt-zh-en' + + def _run_op(self, op, samples, label_key, targets): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for sample, target in zip(dataset, targets): + label = nested_access(sample[Fields.meta], label_key) + print(label, target) + # self.assertEqual(label, target) + + def test_default(self): + + samples = [{ + 'query': '今天火箭和快船的比赛谁赢了。' + },{ + 'query': '这件衣服好看吗?' + },{ + 'query': '你最近身体怎么样。' + } + ] + targets = ['Sports', 'Lifestyle and Fashion', 'Health and Wellness'] + + op = QueryTopicDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = self.zh_to_en_hf_model, + ) + self._run_op(op, samples, MetaKeys.query_topic_label, targets) + + def test_no_zh_to_en(self): + + samples = [{ + 'query': '这样好吗?' + },{ + 'query': 'Is this okay?' 
+ } + ] + targets = ['question', 'rhetorical question'] + + op = QueryTopicDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = None, + ) + self._run_op(op, samples, MetaKeys.query_topic_label, targets) + +if __name__ == '__main__': + unittest.main() From abee815f947c8d8d1ad436602138a3016b56e354 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 18 Dec 2024 17:03:12 +0800 Subject: [PATCH 095/118] query topic detection --- data_juicer/utils/constant.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 7b6c4e081..de302ab5a 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -85,6 +85,8 @@ class MetaKeys(object): dialog_topic_labels = 'topic.dialog_labels' dialog_topic_labels_analysis = 'topic.dialog_labels_analysis' + query_topic_label = 'topic.query_label' + query_topic_score = 'topic.query_label_score' class StatsKeysMeta(type): From 04947416652ab558d58d6dbe031dbcf333947d0c Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 18 Dec 2024 17:29:21 +0800 Subject: [PATCH 096/118] query topic detection --- tests/ops/mapper/test_query_topic_detection_mapper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/ops/mapper/test_query_topic_detection_mapper.py b/tests/ops/mapper/test_query_topic_detection_mapper.py index 18064ca63..f189851e7 100644 --- a/tests/ops/mapper/test_query_topic_detection_mapper.py +++ b/tests/ops/mapper/test_query_topic_detection_mapper.py @@ -29,12 +29,12 @@ def test_default(self): samples = [{ 'query': '今天火箭和快船的比赛谁赢了。' },{ - 'query': '这件衣服好看吗?' + 'query': '毕加索的这幅画怎么样?' },{ 'query': '你最近身体怎么样。' } ] - targets = ['Sports', 'Lifestyle and Fashion', 'Health and Wellness'] + targets = ['Sports', 'Arts, Culture, and Entertainment', 'Health and Wellness'] op = QueryTopicDetectionMapper( hf_model = self.hf_model, @@ -50,7 +50,7 @@ def test_no_zh_to_en(self): 'query': 'Is this okay?' } ] - targets = ['question', 'rhetorical question'] + targets = ['Lifestyle and Fashion', 'Health and Wellness'] op = QueryTopicDetectionMapper( hf_model = self.hf_model, From 38523a143ca5ac3f9bceb5967a05e32ba26c4415 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 18 Dec 2024 17:31:10 +0800 Subject: [PATCH 097/118] query topic detection --- tests/ops/mapper/test_query_topic_detection_mapper.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/ops/mapper/test_query_topic_detection_mapper.py b/tests/ops/mapper/test_query_topic_detection_mapper.py index f189851e7..fcf5f2c74 100644 --- a/tests/ops/mapper/test_query_topic_detection_mapper.py +++ b/tests/ops/mapper/test_query_topic_detection_mapper.py @@ -21,15 +21,12 @@ def _run_op(self, op, samples, label_key, targets): for sample, target in zip(dataset, targets): label = nested_access(sample[Fields.meta], label_key) - print(label, target) - # self.assertEqual(label, target) + self.assertEqual(label, target) def test_default(self): samples = [{ 'query': '今天火箭和快船的比赛谁赢了。' - },{ - 'query': '毕加索的这幅画怎么样?' 
},{ 'query': '你最近身体怎么样。' } From b03a33a351f5b9608c71b3b441a443a3af23d243 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 18 Dec 2024 17:32:10 +0800 Subject: [PATCH 098/118] query topic detection --- tests/ops/mapper/test_query_topic_detection_mapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ops/mapper/test_query_topic_detection_mapper.py b/tests/ops/mapper/test_query_topic_detection_mapper.py index fcf5f2c74..9cbfc3069 100644 --- a/tests/ops/mapper/test_query_topic_detection_mapper.py +++ b/tests/ops/mapper/test_query_topic_detection_mapper.py @@ -31,7 +31,7 @@ def test_default(self): 'query': '你最近身体怎么样。' } ] - targets = ['Sports', 'Arts, Culture, and Entertainment', 'Health and Wellness'] + targets = ['Sports', 'Health and Wellness'] op = QueryTopicDetectionMapper( hf_model = self.hf_model, From ad226b1c057e59da0355800a15459ee6f2ddd078 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Wed, 18 Dec 2024 18:59:43 +0800 Subject: [PATCH 099/118] doc done --- configs/config_all.yaml | 20 +++++++++++++++++++ data_juicer/utils/auto_install_mapping.py | 2 ++ docs/Operators.md | 6 ++++-- docs/Operators_ZH.md | 4 +++- .../test_query_topic_detection_mapper.py | 4 ++-- 5 files changed, 31 insertions(+), 5 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 74600663d..462478c23 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -124,6 +124,21 @@ process: try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. model_params: {} # Parameters for initializing the API model. sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - dialog_topic_detection_mapper: # Mapper to generate user's topic labels in dialog. + api_model: 'gpt-4o' # API model name. + max_round: 10 # The max num of round in the dialog to build the prompt. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # System prompt for the task. + query_template: null # Template for query part to build the input prompt. + response_template: null # Template for response part to build the input prompt. + analysis_template: null # Template for analysis part to build the input prompt. + labels_template: null # Template for labels part to build the input prompt. + analysis_pattern: null # Pattern to parse the return topic analysis. + labels_pattern: null # Pattern to parse the return topic labels. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - expand_macro_mapper: # expand macro definitions in Latex text. - extract_entity_attribute_mapper: # Extract attributes for given entities from the text. api_model: 'gpt-4o' # API model name. @@ -334,6 +349,11 @@ process: zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en' # Translation model from Chinese to English. If not None, translate the query from Chinese to English. model_params: {} # model param for hf_model. zh_to_en_model_params: {} # model param for zh_to_hf_model. + - query_topic_detection_mapper: # Mapper to predict user's topic label in query. 
+      hf_model: 'dstefa/roberta-base_topic_classification_nyt_news'  # Huggingface model ID to predict topic label.
+      zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en'  # Translation model from Chinese to English. If not None, translate the query from Chinese to English.
+      model_params: {}  # model param for hf_model.
+      zh_to_en_model_params: {}  # model param for zh_to_en_hf_model.
   - relation_identity_mapper:  # identify relation between two entity in the text.
       api_model: 'gpt-4o'  # API model name.
       source_entity: '孙悟空'  # The source entity of the relation to be dentified.
diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py
index 6fea1fef8..f691c4c89 100644
--- a/data_juicer/utils/auto_install_mapping.py
+++ b/data_juicer/utils/auto_install_mapping.py
@@ -83,6 +83,8 @@
     'dialog_intent_detection_mapper': ['openai'],
     'dialog_sentiment_detection_mapper': ['openai'],
     'dialog_sentiment_intensity_mapper': ['openai'],
+    'dialog_topic_detection_mapper': ['openai'],
     'query_intent_detection_mapper': ['transformers'],
     'query_sentiment_detection_mapper': ['transformers'],
+    'query_topic_detection_mapper': ['transformers'],
 }
diff --git a/docs/Operators.md b/docs/Operators.md
index 2ad6251f3..bd099fb07 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
 | Type | Number | Description |
 |-----------------------------------|:------:|-------------------------------------------------|
 | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data |
-| [ Mapper ]( #mapper ) | 68 | Edits and transforms samples |
+| [ Mapper ]( #mapper ) | 70 | Edits and transforms samples |
 | [ Filter ]( #filter ) | 44 | Filters out low-quality samples |
 | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples |
 | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking |
@@ -71,6 +71,7 @@ All the specific operators are listed below, each featured with several capabili
 | dialog_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to generate user's intent labels in dialog. | [code](../data_juicer/ops/mapper/dialog_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_intent_detection_mapper.py) |
 | dialog_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to generate user's sentiment labels in dialog. | [code](../data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_detection_mapper.py) |
 | dialog_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment intensity (from -5 to 5 in default prompt) in dialog.
| [code](../data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py) | +| dialog_topic_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to generate user's topic labels in dialog. | [code](../data_juicer/ops/mapper/dialog_topic_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_topic_detection_mapper.py) | | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Expands macros usually defined at the top of TeX documents | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | | extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract attributes for given entities from the text. | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | | extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities and relations in the text for knowledge graph. 
| [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) | @@ -96,8 +97,9 @@ All the specific operators are listed below, each featured with several capabili | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python function defined in a file | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | | python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python lambda function on data samples | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | -| query_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's Intent label in query. | [code](../data_juicer/ops/mapper/query_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_intent_detection_mapper.py) | +| query_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's intent label in query. | [code](../data_juicer/ops/mapper/query_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_intent_detection_mapper.py) | | query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment label ('negative', 'neutral' and 'positive') in query. | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | +| query_topic_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's topic label in query. 
| [code](../data_juicer/ops/mapper/query_topic_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_topic_detection_mapper.py) | | relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Identify relation between two entity in the text. | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index a68ad56e8..749459ad7 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 68 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 70 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 44 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -70,6 +70,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | dialog_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取对话中的用户意图标签。 | [code](../data_juicer/ops/mapper/dialog_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_intent_detection_mapper.py) | | dialog_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取对话中用户的情感标签 | [code](../data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_detection_mapper.py) | | dialog_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测对话中的情绪强度(默认从-5到5)。 | [code](../data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py) | +| dialog_topic_detection_mapper | 
![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取对话中的用户的话题标签。 | [code](../data_juicer/ops/mapper/dialog_topic_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_topic_detection_mapper.py) | | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 扩展通常在 TeX 文档顶部定义的宏 | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | | extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 给定主体和属性名,从文本中抽取主体的属性 | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | | extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从文本中抽取知识图谱的实体和关系 | [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) | @@ -97,6 +98,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行 Python lambda 函数处理样本 | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | | query_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的意图标签。 | [code](../data_juicer/ops/mapper/query_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_intent_detection_mapper.py) | | query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的情感强度标签('negative'、'neutral'和'positive')。 | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | +| query_topic_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的话题标签。 | [code](../data_juicer/ops/mapper/query_topic_detection_mapper.py) | 
[tests](../tests/ops/mapper/test_query_topic_detection_mapper.py) |
 | relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 识别一段文本中两个实体之间的关系 | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) |
 | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) |
 | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) |
diff --git a/tests/ops/mapper/test_query_topic_detection_mapper.py b/tests/ops/mapper/test_query_topic_detection_mapper.py
index 9cbfc3069..6304290c7 100644
--- a/tests/ops/mapper/test_query_topic_detection_mapper.py
+++ b/tests/ops/mapper/test_query_topic_detection_mapper.py
@@ -12,8 +12,8 @@
 class TestQueryTopicDetectionMapper(DataJuicerTestCaseBase):
 
-    hf_model = '/mnt/workspace/shared/checkpoints/huggingface/dstefa/roberta-base_topic_classification_nyt_news'
-    zh_to_en_hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Helsinki-NLP/opus-mt-zh-en'
+    hf_model = 'dstefa/roberta-base_topic_classification_nyt_news'
+    zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en'
 
     def _run_op(self, op, samples, label_key, targets):
         dataset = Dataset.from_list(samples)
         dataset = dataset.map(op.process, batch_size=2)

From b03a33a351fa89b18cef6d41349f6a3892c8b353 Mon Sep 17 00:00:00 2001
From: Haibin <1400012807@pku.edu.cn>
Date: Thu, 19 Dec 2024 15:46:32 +0800
Subject: [PATCH 100/118] meta tags aggregator

---
 data_juicer/ops/aggregator/__init__.py        |   3 +-
 .../aggregator/entity_attribute_aggregator.py |   4 -
 .../ops/aggregator/meta_tags_aggregator.py    | 215 ++++++++++++++++++
 .../most_relavant_entities_aggregator.py      |   4 -
 .../ops/aggregator/nested_aggregator.py       |   4 -
 .../mapper/dialog_intent_detection_mapper.py  |   4 +-
 .../dialog_sentiment_detection_mapper.py      |   4 +-
 .../dialog_sentiment_intensity_mapper.py      |   4 +-
 .../mapper/dialog_topic_detection_mapper.py   |   4 +-
 .../mapper/query_intent_detection_mapper.py   |   2 +-
 .../query_sentiment_detection_mapper.py       |   4 +-
 .../mapper/query_topic_detection_mapper.py    |   2 +-
 data_juicer/utils/constant.py                 |  32 +--
 .../Aggregator/test_meta_tags_aggregator.py   |  58 +++++
 14 files changed, 303 insertions(+), 41 deletions(-)
 create mode 100644 data_juicer/ops/aggregator/meta_tags_aggregator.py
 create mode 100644 tests/ops/Aggregator/test_meta_tags_aggregator.py

diff --git a/data_juicer/ops/aggregator/__init__.py b/data_juicer/ops/aggregator/__init__.py
index 4afe2974a..8aa87cbbd 100644
--- a/data_juicer/ops/aggregator/__init__.py
+++ b/data_juicer/ops/aggregator/__init__.py
@@ -1,8 +1,9 @@
 from .entity_attribute_aggregator import EntityAttributeAggregator
+from .meta_tags_aggregator import MetaTagsAggregator
 from .most_relavant_entities_aggregator import MostRelavantEntitiesAggregator
 from .nested_aggregator import NestedAggregator
 
 __all__ = [
-    'NestedAggregator', 'EntityAttributeAggregator',
+    'NestedAggregator', 'MetaTagsAggregator', 'EntityAttributeAggregator',
     'MostRelavantEntitiesAggregator'
 ]
diff --git a/data_juicer/ops/aggregator/entity_attribute_aggregator.py b/data_juicer/ops/aggregator/entity_attribute_aggregator.py
index 96fbbb63f..16ec5fd07 100644
--- a/data_juicer/ops/aggregator/entity_attribute_aggregator.py
+++ b/data_juicer/ops/aggregator/entity_attribute_aggregator.py
@@ -8,14 +8,10 @@
 from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
                                             is_string_list, nested_access,
                                             nested_set)
-from data_juicer.utils.lazy_loader import LazyLoader
 from data_juicer.utils.model_utils import get_model, prepare_model
 
 from .nested_aggregator import NestedAggregator
 
-torch = LazyLoader('torch', 'torch')
-vllm = LazyLoader('vllm', 'vllm')
-
 OP_NAME = 'entity_attribute_aggregator'
diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py
new file mode 100644
index 000000000..d66e5243f
--- /dev/null
+++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py
@@ -0,0 +1,215 @@
+import re
+from typing import Dict, List, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Aggregator
+from data_juicer.utils.common_utils import is_string_list
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+OP_NAME = 'meta_tags_aggregator'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class MetaTagsAggregator(Aggregator):
+    """
+    Merge similar meta tags into one tag.
+    """
+
+    DEFAULT_SYSTEM_PROMPT = ('给定一些标签以及这些标签出现的频次,合并意思相近的标签。\n'
+                             '要求:\n'
+                             '- 任务分为两种情况,一种是给定合并后的标签,需要将合并前的标签映射到'
+                             '这些标签。如果给定的合并后的标签中有类似“其他”这种标签,将无法归类的'
+                             '标签合并到“其他”。以下是这种情况的一个样例:\n'
+                             '合并后的标签应限定在[科技, 健康, 其他]中。\n'
+                             '| 合并前标签 | 频次 |\n'
+                             '| ------ | ------ |\n'
+                             '| 医疗 | 20 |\n'
+                             '| 信息技术 | 16 |\n'
+                             '| 学习 | 19 |\n'
+                             '| 气候变化 | 22 |\n'
+                             '| 人工智能 | 11 |\n'
+                             '| 养生 | 17 |\n'
+                             '| 科学创新 | 10 |\n'
+                             '\n'
+                             '分析:“信息技术”、“人工智能”、“科学创新”都属于“科技”类别,“医疗'
+                             '”和“养生”跟“健康”有关联,“学习”、“气候变化”和“科技”还有“健康”关'
+                             '联不强,应该被归为“其他”。\n'
+                             '标签合并:\n'
+                             '医疗 -> 健康\n'
+                             '信息技术 -> 科技\n'
+                             '学习 -> 其他\n'
+                             '气候变化 -> 其他\n'
+                             '人工智能 -> 科技\n'
+                             '养生 -> 健康\n'
+                             '科学创新 -> 科技\n'
+                             '- 另外一种情况没有事先给定合并后的标签,需要生成合理的标签类别,如果有'
+                             '一些比较特别且频次比较低的标签,可以统一归到“其他”这个标签中:\n'
+                             '| 合并前标签 | 频次 |\n'
+                             '| ------ | ------ |\n'
+                             '| 医疗 | 20 |\n'
+                             '| 信息技术 | 16 |\n'
+                             '| 学习 | 2 |\n'
+                             '| 气候变化 | 1 |\n'
+                             '| 人工智能 | 11 |\n'
+                             '| 养生 | 17 |\n'
+                             '| 科学创新 | 10 |\n'
+                             '\n'
+                             '分析:“信息技术”、“人工智能”、“科学创新”这三个标签比较相近,归为'
+                             '同一类,都属于“科技”类别,“医疗”和“养生”都跟“健康”有关系,可以归'
+                             '类为“健康”,“学习”和“气候变化”跟其他标签关联度不强,且频次较低,'
+                             '统一归类为“其他”。\n'
+                             '标签合并:\n'
+                             '医疗 -> 健康\n'
+                             '信息技术 -> 科技\n'
+                             '学习 -> 其他\n'
+                             '气候变化 -> 其他\n'
+                             '人工智能 -> 科技\n'
+                             '养生 -> 健康\n'
+                             '科学创新 -> 科技\n')
+
+    DEFAULT_INPUT_TEMPLATE = ('| 合并前标签 | 频次 |\n'
+                              '| ------ | ------ |\n'
+                              '{tag_strs}')
+    DEFAULT_TAG_TEMPLATE = '| {tag} | {cnt} |'
+
+    DEFAULT_OUTPUT_PATTERN = r'\n\s*([^\n]*?)\s*->\s*([^\n]*?)\s*(?=\Z|\n)'
+
+    def __init__(self,
+                 api_model: str = 'gpt-4o',
+                 meta_tag_key: str = MetaKeys.dialog_sentiment_labels,
+                 target_tags: Optional[List[str]] = None,
+                 *,
+                 api_endpoint: Optional[str] = None,
+                 response_path: Optional[str] = None,
+                 system_prompt: Optional[str] = None,
+                 input_template: Optional[str] = None,
+                 tag_template: Optional[str] = None,
+                 output_pattern: Optional[str] = None,
+                 try_num: PositiveInt = 3,
+                 model_params: Dict = {},
+                 sampling_params: Dict = {},
+                 **kwargs):
+        """
+        Initialization method.
+        :param api_model: API model name.
+        :param meta_tag_key: The key of the meta tag to be mapped.
+        :param target_tags: The target tags to map the original tags to.
+        :param api_endpoint: URL endpoint for the API.
+        :param response_path: Path to extract content from the API response.
+            Defaults to 'choices.0.message.content'.
+        :param system_prompt: The system prompt.
+        :param input_template: The input template.
+        :param tag_template: The tag template for each tag and its
+            frequency.
+        :param output_pattern: The output pattern.
+        :param try_num: The number of retry attempts when there is an API
+            call error or output parsing error.
+        :param model_params: Parameters for initializing the API model.
+        :param sampling_params: Extra parameters passed to the API call.
+            e.g {'temperature': 0.9, 'top_p': 0.95}
+        :param kwargs: Extra keyword arguments.
+        """
+        super().__init__(**kwargs)
+
+        self.meta_tag_key = meta_tag_key
+        self.target_tags = target_tags
+
+        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+        self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+        self.tag_template = tag_template or self.DEFAULT_TAG_TEMPLATE
+        self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
+
+        self.sampling_params = sampling_params
+        self.model_key = prepare_model(model_type='api',
+                                       model=api_model,
+                                       endpoint=api_endpoint,
+                                       response_path=response_path,
+                                       return_processor=True,
+                                       **model_params)
+
+        self.try_num = try_num
+
+    def parse_output(self, response):
+        pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
+        matches = pattern.findall(response)
+        tag_map = {tag1: tag2 for tag1, tag2 in matches}
+        return tag_map
+
+    def meta_map(self, meta_cnts, rank=None):
+
+        model, _ = get_model(self.model_key, rank, self.use_cuda())
+        tag_strs = [
+            self.tag_template.format(tag=k, cnt=meta_cnts[k])
+            for k in meta_cnts
+        ]
+
+        input_prompt = self.input_template.format(tag_strs='\n'.join(tag_strs))
+        if self.target_tags:
+            input_prompt = ('合并后的标签应限定在[' +
+                            ', '.join(self.target_tags) + ']中。\n' +
+                            input_prompt)
+
+        messages = [{
+            'role': 'system',
+            'content': self.system_prompt
+        }, {
+            'role': 'user',
+            'content': input_prompt
+        }]
+        tag_map = {}
+        for i in range(self.try_num):
+            try:
+                response = model(messages, **self.sampling_params)
+                tag_map = self.parse_output(response)
+                if len(tag_map) > 0:
+                    break
+            except Exception as e:
+                logger.warning(f'Exception: {e}')
+
+        return tag_map
+
+    def process_single(self, sample=None, rank=None):
+
+        if Fields.meta not in sample:
+            logger.warning('Not any meta in the sample!')
+            return sample
+
+        metas = sample[Fields.meta]
+        # if not batched sample
+        if not isinstance(metas, list):
+            logger.warning('Not a batched sample!')
+            return sample
+
+        meta_cnts = {}
+
+        def update_dict(key):
+            if key in meta_cnts:
+                meta_cnts[key] += 1
+            else:
+                meta_cnts[key] = 1
+
+        for meta in metas:
+            tag = meta[self.meta_tag_key]
+            if isinstance(tag, str):
+                update_dict(tag)
+            elif is_string_list(tag):
+                for t in tag:
+                    update_dict(t)
+            else:
+                logger.warning('Meta tag must be string or list of string!')
+                return sample
+
+        tag_map = self.meta_map(meta_cnts, rank=rank)
+        for i in range(len(metas)):
+            tag = metas[i][self.meta_tag_key]
+            if isinstance(tag, str):
+                metas[i][self.meta_tag_key] = tag_map.get(tag, tag)
+            elif is_string_list(tag):
+                metas[i][self.meta_tag_key] = [
+                    tag_map.get(t, t) for t in tag]
+
+        return sample
diff --git a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py
index 69e1a209c..7ca49f505 100644
--- a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py
+++ b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py
@@ -7,14 +7,10 @@
 from data_juicer.ops.base_op import OPERATORS, Aggregator
 from data_juicer.utils.common_utils import (is_string_list, nested_access,
                                             nested_set)
-from data_juicer.utils.lazy_loader import LazyLoader
 from data_juicer.utils.model_utils import get_model, prepare_model
 
 from ..common import split_text_by_punctuation
 
-torch = LazyLoader('torch', 'torch')
-vllm = LazyLoader('vllm', 'vllm')
-
 OP_NAME = 'most_relavant_entities_aggregator'
diff --git a/data_juicer/ops/aggregator/nested_aggregator.py b/data_juicer/ops/aggregator/nested_aggregator.py
index 124eb1470..ab25e057d 100644
--- a/data_juicer/ops/aggregator/nested_aggregator.py
+++ b/data_juicer/ops/aggregator/nested_aggregator.py
@@ -6,12 +6,8 @@
 from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
                                             is_string_list, nested_access)
-from data_juicer.utils.lazy_loader import LazyLoader
 from data_juicer.utils.model_utils import get_model, prepare_model
 
-torch = LazyLoader('torch', 'torch')
-vllm = LazyLoader('vllm', 'vllm')
-
 OP_NAME = 'nested_aggregator'
diff --git a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py
index 4e40c86f4..7c8cba9ed 100644
--- a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py
+++ b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py
@@ -19,8 +19,8 @@ class DialogIntentDetectionMapper(Mapper):
     Mapper to generate user's intent labels in dialog. Input from
     history_key, query_key and response_key. Output lists of
     labels and analysis for queries in the dialog, which is
-    store in 'intent.dialog_labels' and
-    'intent.dialog_labels_analysis' in Data-Juicer meta field.
+    store in 'dialog_intent_labels' and
+    'dialog_intent_labels_analysis' in Data-Juicer meta field.
     """
 
     DEFAULT_SYSTEM_PROMPT = (
diff --git a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py
index 5417e5974..33bccc5ce 100644
--- a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py
+++ b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py
@@ -19,8 +19,8 @@ class DialogSentimentDetectionMapper(Mapper):
     Mapper to generate user's sentiment labels in dialog. Input from
     history_key, query_key and response_key. Output lists of
     labels and analysis for queries in the dialog, which is
-    store in 'sentiment.dialog_labels' and
-    'sentiment.dialog_labels_analysis' in Data-Juicer meta field.
+    store in 'dialog_sentiment_labels' and
+    'dialog_sentiment_labels_analysis' in Data-Juicer meta field.
     """
 
     DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户所具有的情绪。\n'
diff --git a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py
index 56372c6f3..198314ee3 100644
--- a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py
+++ b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py
@@ -19,8 +19,8 @@ class DialogSentimentIntensityMapper(Mapper):
     Mapper to predict user's sentiment intensity (from -5 to 5 in default
     prompt) in dialog. Input from history_key, query_key and response_key.
     Output lists of intensities and analysis for queries in
-    the dialog, which is store in 'sentiment.dialog_intensity' and
-    'sentiment.dialog_intensity_analysis' in Data-Juicer meta field.
+    the dialog, which is store in 'dialog_sentiment_intensity' and
+    'dialog_sentiment_intensity_analysis' in Data-Juicer meta field.
     """
 
     DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户的情绪变化。\n'
diff --git a/data_juicer/ops/mapper/dialog_topic_detection_mapper.py b/data_juicer/ops/mapper/dialog_topic_detection_mapper.py
index 60db0f981..7e8ee0b54 100644
--- a/data_juicer/ops/mapper/dialog_topic_detection_mapper.py
+++ b/data_juicer/ops/mapper/dialog_topic_detection_mapper.py
@@ -19,8 +19,8 @@ class DialogTopicDetectionMapper(Mapper):
     Mapper to generate user's topic labels in dialog. Input from
     history_key, query_key and response_key. Output lists of
     labels and analysis for queries in the dialog, which are
-    stored in 'topic.dialog_labels' and
-    'topic.dialog_labels_analysis' in Data-Juicer meta field.
+    stored in 'dialog_topic_labels' and
+    'dialog_topic_labels_analysis' in Data-Juicer meta field.
""" DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户所讨论的话题。\n' diff --git a/data_juicer/ops/mapper/query_intent_detection_mapper.py b/data_juicer/ops/mapper/query_intent_detection_mapper.py index b6e69f610..b0d240e2d 100644 --- a/data_juicer/ops/mapper/query_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/query_intent_detection_mapper.py @@ -14,7 +14,7 @@ class QueryIntentDetectionMapper(Mapper): """ Mapper to predict user's Intent label in query. Input from query_key. Output intent label and corresponding score for the query, which is - store in 'intent.query_label' and 'intent.query_label_score' in + store in 'query_intent_label' and 'query_intent_label_score' in Data-Juicer meta field. """ diff --git a/data_juicer/ops/mapper/query_sentiment_detection_mapper.py b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py index ef49a344a..634bdeab3 100644 --- a/data_juicer/ops/mapper/query_sentiment_detection_mapper.py +++ b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py @@ -15,8 +15,8 @@ class QuerySentimentDetectionMapper(Mapper): Mapper to predict user's sentiment label ('negative', 'neutral' and 'positive') in query. Input from query_key. Output label and corresponding score for the query, which is - store in 'sentiment.query_label' and - 'sentiment.query_label_score' in Data-Juicer meta field. + store in 'query_sentiment_label' and + 'query_sentiment_label_score' in Data-Juicer meta field. """ _accelerator = 'cuda' diff --git a/data_juicer/ops/mapper/query_topic_detection_mapper.py b/data_juicer/ops/mapper/query_topic_detection_mapper.py index 577f00bad..8e5687ee3 100644 --- a/data_juicer/ops/mapper/query_topic_detection_mapper.py +++ b/data_juicer/ops/mapper/query_topic_detection_mapper.py @@ -14,7 +14,7 @@ class QueryTopicDetectionMapper(Mapper): """ Mapper to predict user's topic label in query. Input from query_key. Output topic label and corresponding score for the query, which is - store in 'topic.query_label' and 'topic.query_label_score' in + store in 'query_topic_label' and 'query_topic_label_score' in Data-Juicer meta field. 
""" diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index de302ab5a..ba693f63e 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -71,22 +71,22 @@ class Fields(object): class MetaKeys(object): - dialog_sentiment_intensity = 'sentiment.dialog_intensity' - dialog_sentiment_intensity_analysis = 'sentiment.dialog_intensity_analysis' - query_sentiment_label = 'sentiment.query_label' - query_sentiment_score = 'sentiment.query_label_score' - dialog_sentiment_labels = 'sentiment.dialog_labels' - dialog_sentiment_labels_analysis = 'sentiment.dialog_labels_analysis' - - dialog_intent_labels = 'intent.dialog_labels' - dialog_intent_labels_analysis = 'intent.dialog_labels_analysis' - query_intent_label = 'intent.query_label' - query_intent_score = 'intent.query_label_score' - - dialog_topic_labels = 'topic.dialog_labels' - dialog_topic_labels_analysis = 'topic.dialog_labels_analysis' - query_topic_label = 'topic.query_label' - query_topic_score = 'topic.query_label_score' + dialog_sentiment_intensity = 'dialog_sentiment_intensity' + dialog_sentiment_intensity_analysis = 'dialog_sentiment_intensity_analysis' + query_sentiment_label = 'query_sentiment_label' + query_sentiment_score = 'query_sentiment_label_score' + dialog_sentiment_labels = 'dialog_sentiment_labels' + dialog_sentiment_labels_analysis = 'dialog_sentiment_labels_analysis' + + dialog_intent_labels = 'dialog_intent_labels' + dialog_intent_labels_analysis = 'dialog_intent_labels_analysis' + query_intent_label = 'query_intent_label' + query_intent_score = 'query_intent_label_score' + + dialog_topic_labels = 'dialog_topic_labels' + dialog_topic_labels_analysis = 'dialog_topic_labels_analysis' + query_topic_label = 'query_topic_label' + query_topic_score = 'query_topic_label_score' class StatsKeysMeta(type): diff --git a/tests/ops/Aggregator/test_meta_tags_aggregator.py b/tests/ops/Aggregator/test_meta_tags_aggregator.py new file mode 100644 index 000000000..6c715f4f6 --- /dev/null +++ b/tests/ops/Aggregator/test_meta_tags_aggregator.py @@ -0,0 +1,66 @@ +import unittest + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.aggregator import MetaTagsAggregator +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS + + +@SKIPPED_TESTS.register_module() +class MetaTagsAggregatorTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples): + + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for data in new_dataset: + for k in data: + logger.info(f"{k}: {data[k]}") + + self.assertEqual(len(new_dataset), len(samples)) + + def test_default_aggregator(self): + samples = [ + { + Fields.meta: [ + { + MetaKeys.query_sentiment_label: '开心' + } + ], + Fields.meta: [ + { + MetaKeys.query_sentiment_label: '快乐' + } + ], + Fields.meta: [ + { + MetaKeys.query_sentiment_label: '难过' + } + ], + Fields.meta: [ + { + MetaKeys.query_sentiment_label: '不开心' + } + ], + Fields.meta: [ + { + MetaKeys.query_sentiment_label: '愤怒' + } + ] + }, + ] + op = MetaTagsAggregator( + api_model='qwen2.5-72b-instruct', + meta_tag_key=MetaKeys.query_sentiment_label, + ) + self._run_helper(op, samples) + +if __name__ == '__main__': + unittest.main() \ No newline at end of 
file From f2654f16cd8df26d6d0c7bdee96dc1bcb266aa73 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 15:54:47 +0800 Subject: [PATCH 101/118] meta tags aggregator --- .../ops/aggregator/meta_tags_aggregator.py | 2 ++ .../ops/Aggregator/test_meta_tags_aggregator.py | 16 ++++------------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py index d66e5243f..deda5e4aa 100644 --- a/data_juicer/ops/aggregator/meta_tags_aggregator.py +++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py @@ -134,8 +134,10 @@ def __init__(self, self.try_num = try_num def parse_output(self, response): + print(response) pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) matches = pattern.findall(response) + print(matches) tag_map = {tag1: tag2 for tag1, tag2 in matches} return tag_map diff --git a/tests/ops/Aggregator/test_meta_tags_aggregator.py b/tests/ops/Aggregator/test_meta_tags_aggregator.py index 6c715f4f6..6a686deb8 100644 --- a/tests/ops/Aggregator/test_meta_tags_aggregator.py +++ b/tests/ops/Aggregator/test_meta_tags_aggregator.py @@ -32,24 +32,16 @@ def test_default_aggregator(self): Fields.meta: [ { MetaKeys.query_sentiment_label: '开心' - } - ], - Fields.meta: [ + }, { MetaKeys.query_sentiment_label: '快乐' - } - ], - Fields.meta: [ + }, { MetaKeys.query_sentiment_label: '难过' - } - ], - Fields.meta: [ + }, { MetaKeys.query_sentiment_label: '不开心' - } - ], - Fields.meta: [ + }, { MetaKeys.query_sentiment_label: '愤怒' } From 23e5d6f4937da7b3411d94929985fd1a920be607 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 15:57:13 +0800 Subject: [PATCH 102/118] meta tags aggregator --- data_juicer/ops/aggregator/meta_tags_aggregator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py index deda5e4aa..58a5dc018 100644 --- a/data_juicer/ops/aggregator/meta_tags_aggregator.py +++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py @@ -46,8 +46,7 @@ class MetaTagsAggregator(Aggregator): '人工智能 -> 科技\n' '养生 -> 健康\n' '科学创新 -> 科技\n' - '- 另外一种情况没有事先给定合并后的标签,需要生成合理的标签类别,如果有' - '一些比较特别且频次比较低的标签,可以统一归到“其他”这个标签中:' + '- 另外一种情况没有事先给定合并后的标签,需要生成合理的标签类别:' '| 合并前标签 | 频次 |\n' '| ------ | ------ |\n' '| 医疗 | 20 |\n' From 1c747096deb6a7b30a8eb936bfa1018def28d656 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 16:06:10 +0800 Subject: [PATCH 103/118] meta tags aggregator --- .../ops/aggregator/meta_tags_aggregator.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py index 58a5dc018..454602657 100644 --- a/data_juicer/ops/aggregator/meta_tags_aggregator.py +++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py @@ -35,17 +35,17 @@ class MetaTagsAggregator(Aggregator): '| 养生 | 17 |\n' '| 科学创新 | 10 |\n' '\n' - '分析:“信息技术”、“人工智能”、“科学创新”都属于“科技”类别,“医疗' + '## 分析:“信息技术”、“人工智能”、“科学创新”都属于“科技”类别,“医疗' '”和“养生”跟“健康”有关联,“学习”、“气候变化”和“科技”还有“健康”关' '联不强,应该被归为“其他”。\n' - '标签合并:\n' - '医疗 -> 健康\n' - '信息技术 -> 科技\n' - '学习 -> 其他\n' - '气候变化 -> 其他\n' - '人工智能 -> 科技\n' - '养生 -> 健康\n' - '科学创新 -> 科技\n' + '## 标签合并:\n' + '医疗归类为健康\n' + '信息技术归类为科技\n' + '学习归类为其他\n' + '气候变化归类为其他\n' + '人工智能归类为科技\n' + '养生归类为健康\n' + '科学创新归类为科技\n' '- 另外一种情况没有事先给定合并后的标签,需要生成合理的标签类别:' '| 合并前标签 | 频次 |\n' '| 
------ | ------ |\n' @@ -57,25 +57,25 @@ class MetaTagsAggregator(Aggregator): '| 养生 | 17 |\n' '| 科学创新 | 10 |\n' '\n' - '分析:“信息技术”、“人工智能”、“科学创新”这三个标签比较相近,归为' + '## 分析:“信息技术”、“人工智能”、“科学创新”这三个标签比较相近,归为' '同一类,都属于“科技”类别,“医疗”和“养生”都跟“健康”有关系,可以归' '类为“健康”,“学习”和“气候变化”跟其他标签关联度不强,且频次较低,' '统一归类为“其他”。\n' - '标签合并:\n' - '医疗 -> 健康\n' - '信息技术 -> 科技\n' - '学习 -> 其他\n' - '气候变化 -> 其他\n' - '人工智能 -> 科技\n' - '养生 -> 健康\n' - '科学创新 -> 科技\n') + '## 标签合并:\n' + '医疗归类为健康\n' + '信息技术归类为科技\n' + '学习归类为其他\n' + '气候变化归类为其他\n' + '人工智能归类为科技\n' + '养生归类为健康\n' + '科学创新归类为科技\n') DEFAULT_INPUT_TEMPLATE = ('| 合并前标签 | 频次 |\n' '| ------ | ------ |\n' '{tag_strs}') DEFAULT_TAG_TEMPLATE = '| {tag} | {cnt} |' - DEFAULT_OUTPUT_PATTERN = r'\n\s*(.*?)\s*->\s*(.*?)\s*(\Z|\n)' + DEFAULT_OUTPUT_PATTERN = r'\n(.*?)归类为(.*?)(\Z|\n)' def __init__(self, api_model: str = 'gpt-4o', From a9977262e4a51a5a76b0c26a085c74bc9b1f052b Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 16:14:19 +0800 Subject: [PATCH 104/118] meta tags aggregator --- data_juicer/ops/aggregator/meta_tags_aggregator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py index 454602657..a24025392 100644 --- a/data_juicer/ops/aggregator/meta_tags_aggregator.py +++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py @@ -75,7 +75,7 @@ class MetaTagsAggregator(Aggregator): '{tag_strs}') DEFAULT_TAG_TEMPLATE = '| {tag} | {cnt} |' - DEFAULT_OUTPUT_PATTERN = r'\n(.*?)归类为(.*?)(\Z|\n)' + DEFAULT_OUTPUT_PATTERN = r'\n(\w+)归类为(\w+)($|\n)' def __init__(self, api_model: str = 'gpt-4o', From 26428479e8f402753e5577d18ad9a5eedfcbdff9 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 16:17:58 +0800 Subject: [PATCH 105/118] meta tags aggregator --- .../ops/aggregator/meta_tags_aggregator.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py index a24025392..b47528d7a 100644 --- a/data_juicer/ops/aggregator/meta_tags_aggregator.py +++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py @@ -39,13 +39,13 @@ class MetaTagsAggregator(Aggregator): '”和“养生”跟“健康”有关联,“学习”、“气候变化”和“科技”还有“健康”关' '联不强,应该被归为“其他”。\n' '## 标签合并:\n' - '医疗归类为健康\n' - '信息技术归类为科技\n' - '学习归类为其他\n' - '气候变化归类为其他\n' - '人工智能归类为科技\n' - '养生归类为健康\n' - '科学创新归类为科技\n' + '** 医疗归类为健康 **\n' + '** 信息技术归类为科技 **\n' + '** 学习归类为其他 **\n' + '** 气候变化归类为其他 **\n' + '** 人工智能归类为科技 **\n' + '** 养生归类为健康 **\n' + '** 科学创新归类为科技 **\n' '- 另外一种情况没有事先给定合并后的标签,需要生成合理的标签类别:' '| 合并前标签 | 频次 |\n' '| ------ | ------ |\n' @@ -62,20 +62,20 @@ class MetaTagsAggregator(Aggregator): '类为“健康”,“学习”和“气候变化”跟其他标签关联度不强,且频次较低,' '统一归类为“其他”。\n' '## 标签合并:\n' - '医疗归类为健康\n' - '信息技术归类为科技\n' - '学习归类为其他\n' - '气候变化归类为其他\n' - '人工智能归类为科技\n' - '养生归类为健康\n' - '科学创新归类为科技\n') + '** 医疗归类为健康 **\n' + '** 信息技术归类为科技 **\n' + '** 学习归类为其他 **\n' + '** 气候变化归类为其他 **\n' + '** 人工智能归类为科技 **\n' + '** 养生归类为健康 **\n' + '** 科学创新归类为科技 **\n') DEFAULT_INPUT_TEMPLATE = ('| 合并前标签 | 频次 |\n' '| ------ | ------ |\n' '{tag_strs}') DEFAULT_TAG_TEMPLATE = '| {tag} | {cnt} |' - DEFAULT_OUTPUT_PATTERN = r'\n(\w+)归类为(\w+)($|\n)' + DEFAULT_OUTPUT_PATTERN = r'\*\*\s*(\w+)归类为(\w+)\s*\*\*' def __init__(self, api_model: str = 'gpt-4o', From 2dae3b86ee28326aa70227f39ad42cb5962987d6 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 16:19:44 +0800 Subject: [PATCH 106/118] meta tags 
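Patches 103 through 105 above revise the expected answer format three times: plain '旧标签 -> 新标签' lines, then '旧标签归类为新标签' lines, and finally '** 旧标签归类为新标签 **' markers. The bolded form is the easiest to match reliably: the two newline-anchored patterns consume the trailing '\n' of each match, so back-to-back lines alternate between matching and being skipped, which is plausibly what motivated the change. A quick check, with the patterns copied from the hunks and the flags used by `parse_output`:

```python
import re

response = ('## 标签合并:\n'
            '** 医疗归类为健康 **\n'
            '** 信息技术归类为科技 **\n'
            '** 学习归类为其他 **\n')

# Final pattern from PATCH 105; VERBOSE | DOTALL mirrors parse_output above.
pattern = re.compile(r'\*\*\s*(\w+)归类为(\w+)\s*\*\*', re.VERBOSE | re.DOTALL)
print(dict(pattern.findall(response)))
# {'医疗': '健康', '信息技术': '科技', '学习': '其他'}

# The newline-anchored pattern from PATCH 103 eats the trailing '\n', so
# adjacent lines alternate between matching and being skipped:
old = re.compile(r'\n(.*?)归类为(.*?)(\Z|\n)', re.VERBOSE | re.DOTALL)
print([m[:2] for m in old.findall('\n医疗归类为健康\n信息技术归类为科技\n学习归类为其他\n')])
# [('医疗', '健康'), ('学习', '其他')] -- '信息技术' is silently lost
```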
aggregator --- data_juicer/ops/aggregator/meta_tags_aggregator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py index b47528d7a..02f67e9d0 100644 --- a/data_juicer/ops/aggregator/meta_tags_aggregator.py +++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py @@ -203,9 +203,9 @@ def update_dict(key): tag_map = self.meta_map(meta_cnts, rank=rank) for i in range(len(metas)): if isinstance(metas[i], str) and metas[i] in tag_map: - metas[i] = tag_map[metas[i]] + sample[Fields.meta][i] = tag_map[metas[i]] elif is_string_list(metas[i]): - metas[i] = [ + sample[Fields.meta][i] = [ tag_map[t] if t in tag_map else t for t in metas[i] ] From 8bb250989d59aca8861fe3a72aa8aee7f9d3447f Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 16:22:24 +0800 Subject: [PATCH 107/118] meta tags aggregator --- data_juicer/ops/aggregator/meta_tags_aggregator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py index 02f67e9d0..12518e255 100644 --- a/data_juicer/ops/aggregator/meta_tags_aggregator.py +++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py @@ -202,11 +202,12 @@ def update_dict(key): tag_map = self.meta_map(meta_cnts, rank=rank) for i in range(len(metas)): - if isinstance(metas[i], str) and metas[i] in tag_map: - sample[Fields.meta][i] = tag_map[metas[i]] - elif is_string_list(metas[i]): - sample[Fields.meta][i] = [ - tag_map[t] if t in tag_map else t for t in metas[i] + tag = metas[i][self.meta_tag_key] + if isinstance(tag, str) and tag in tag_map: + metas[i][self.meta_tag_key] = tag_map[tag] + elif is_string_list(tag): + metas[i][self.meta_tag_key] = [ + tag_map[t] if t in tag_map else t for t in tag ] return sample From 90303ee2b4530a0c8edfd8cf59271c719c32e01f Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 16:27:18 +0800 Subject: [PATCH 108/118] meta tags aggregator --- .../Aggregator/test_meta_tags_aggregator.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/ops/Aggregator/test_meta_tags_aggregator.py b/tests/ops/Aggregator/test_meta_tags_aggregator.py index 6a686deb8..75d07df9c 100644 --- a/tests/ops/Aggregator/test_meta_tags_aggregator.py +++ b/tests/ops/Aggregator/test_meta_tags_aggregator.py @@ -53,6 +53,65 @@ def test_default_aggregator(self): meta_tag_key=MetaKeys.query_sentiment_label, ) self._run_helper(op, samples) + + + def test_target_tags(self): + samples = [ + { + Fields.meta: [ + { + MetaKeys.query_sentiment_label: '开心' + }, + { + MetaKeys.query_sentiment_label: '快乐' + }, + { + MetaKeys.query_sentiment_label: '难过' + }, + { + MetaKeys.query_sentiment_label: '不开心' + }, + { + MetaKeys.query_sentiment_label: '愤怒' + } + ] + }, + ] + op = MetaTagsAggregator( + api_model='qwen2.5-72b-instruct', + meta_tag_key=MetaKeys.query_sentiment_label, + target_tags=['开心', '难过', '其他'] + ) + self._run_helper(op, samples) + + def test_tag_list(self): + samples = [ + { + Fields.meta: [ + { + MetaKeys.dialog_sentiment_label: ['开心', '平静'] + }, + { + MetaKeys.dialog_sentiment_label: ['快乐', '开心', '幸福'] + }, + { + MetaKeys.dialog_sentiment_label: ['难过'] + }, + { + MetaKeys.dialog_sentiment_label: ['不开心', '没头脑', '不高兴'] + }, + { + MetaKeys.dialog_sentiment_label: ['愤怒', '愤慨'] + } + ] + }, + ] + op = MetaTagsAggregator( + api_model='qwen2.5-72b-instruct', + 
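Patches 106 and 107 above are easy to misread as cosmetic. In the PATCH 106 version the rewrite loop still tested `isinstance(metas[i], str)`, but each `metas[i]` is a dict of meta fields, so the remapping silently never fired; PATCH 107 first pulls out `metas[i][self.meta_tag_key]` and then handles the string and list-of-strings cases. A stripped-down illustration with plain dicts, using `isinstance(tag, list)` as a stand-in for `is_string_list`:

```python
tag_key = 'query_sentiment_label'
metas = [{tag_key: '快乐'}, {tag_key: ['不开心', '愤怒']}]
tag_map = {'快乐': '开心', '不开心': '难过'}

# PATCH 106 shape of the loop: metas[i] is a dict, so this branch is dead code.
for i in range(len(metas)):
    if isinstance(metas[i], str) and metas[i] in tag_map:
        metas[i] = tag_map[metas[i]]  # never reached for dict entries

# PATCH 107 shape: look up the tag field first, then remap str or list of str.
for meta in metas:
    tag = meta[tag_key]
    if isinstance(tag, str) and tag in tag_map:
        meta[tag_key] = tag_map[tag]
    elif isinstance(tag, list):
        meta[tag_key] = [tag_map.get(t, t) for t in tag]

print(metas)
# [{'query_sentiment_label': '开心'},
#  {'query_sentiment_label': ['难过', '愤怒']}]
```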
meta_tag_key=MetaKeys.dialog_sentiment_label, + target_tags=['开心', '难过', '其他'] + ) + self._run_helper(op, samples) if __name__ == '__main__': unittest.main() \ No newline at end of file From e4c6ff134f3a677224532579169f498291570de6 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 16:28:49 +0800 Subject: [PATCH 109/118] meta tags aggregator --- tests/ops/Aggregator/test_meta_tags_aggregator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/ops/Aggregator/test_meta_tags_aggregator.py b/tests/ops/Aggregator/test_meta_tags_aggregator.py index 75d07df9c..7aba225ae 100644 --- a/tests/ops/Aggregator/test_meta_tags_aggregator.py +++ b/tests/ops/Aggregator/test_meta_tags_aggregator.py @@ -89,26 +89,26 @@ def test_tag_list(self): { Fields.meta: [ { - MetaKeys.dialog_sentiment_label: ['开心', '平静'] + MetaKeys.dialog_sentiment_labels: ['开心', '平静'] }, { - MetaKeys.dialog_sentiment_label: ['快乐', '开心', '幸福'] + MetaKeys.dialog_sentiment_labels: ['快乐', '开心', '幸福'] }, { - MetaKeys.dialog_sentiment_label: ['难过'] + MetaKeys.dialog_sentiment_labels: ['难过'] }, { - MetaKeys.dialog_sentiment_label: ['不开心', '没头脑', '不高兴'] + MetaKeys.dialog_sentiment_labels: ['不开心', '没头脑', '不高兴'] }, { - MetaKeys.dialog_sentiment_label: ['愤怒', '愤慨'] + MetaKeys.dialog_sentiment_labels: ['愤怒', '愤慨'] } ] }, ] op = MetaTagsAggregator( api_model='qwen2.5-72b-instruct', - meta_tag_key=MetaKeys.dialog_sentiment_label, + meta_tag_key=MetaKeys.dialog_sentiment_labels, target_tags=['开心', '难过', '其他'] ) self._run_helper(op, samples) From 12f8946b38aff7f606284dd1c554e3cf35dc8408 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 16:47:45 +0800 Subject: [PATCH 110/118] meta tags aggregator --- .../ops/aggregator/meta_tags_aggregator.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py index 12518e255..ec9f5201d 100644 --- a/data_juicer/ops/aggregator/meta_tags_aggregator.py +++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py @@ -70,9 +70,11 @@ class MetaTagsAggregator(Aggregator): '** 养生归类为健康 **\n' '** 科学创新归类为科技 **\n') - DEFAULT_INPUT_TEMPLATE = ('| 合并前标签 | 频次 |\n' + DEFAULT_INPUT_TEMPLATE = ('{target_tag_str}' + '| 合并前标签 | 频次 |\n' '| ------ | ------ |\n' '{tag_strs}') + DEFAULT_TARGET_TAG_TEMPLATE = '合并后的标签应限定在[{target_tags}]中。\n' DEFAULT_TAG_TEMPLATE = '| {tag} | {cnt} |' DEFAULT_OUTPUT_PATTERN = r'\*\*\s*(\w+)归类为(\w+)\s*\*\*' @@ -86,6 +88,7 @@ def __init__(self, response_path: Optional[str] = None, system_prompt: Optional[str] = None, input_template: Optional[str] = None, + target_tag_template: Optional[str] = None, tag_template: Optional[str] = None, output_pattern: Optional[str] = None, try_num: PositiveInt = 3, @@ -102,6 +105,7 @@ def __init__(self, Defaults to 'choices.0.message.content'. :param system_prompt: The system prompt. :param input_template: The input template. + :param target_tag_template: The tap template for target tags. :param tag_template: The tap template for each tag and its frequency. :param output_pattern: The output pattern. 
@@ -115,13 +119,19 @@ def __init__(self, super().__init__(**kwargs) self.meta_tag_key = meta_tag_key - self.target_tags = target_tags self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + target_tag_template = target_tag_template or \ + self.DEFAULT_TARGET_TAG_TEMPLATE self.tag_template = tag_template or self.DEFAULT_TAG_TEMPLATE self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN + self.target_tag_str = '' + if target_tags: + self.target_tag_str = target_tag_template( + target_tags=', '.join(target_tags)) + self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', model=api_model, @@ -143,12 +153,13 @@ def parse_output(self, response): def meta_map(self, meta_cnts, rank=None): model, _ = get_model(self.model_key, rank, self.use_cuda()) + tag_strs = [ self.tag_template.format(tag=k, cnt=meta_cnts[k]) for k in meta_cnts ] - - input_prompt = self.input_template.format(tag_strs='\n'.join(tag_strs)) + input_prompt = self.input_template.format( + target_tag_str=self.target_tag_str, tag_strs='\n'.join(tag_strs)) messages = [{ 'role': 'system', From 09b159966f222aca1b96e38be72880f898103a54 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 16:48:49 +0800 Subject: [PATCH 111/118] meta tags aggregator --- data_juicer/ops/aggregator/meta_tags_aggregator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py index ec9f5201d..60c94924c 100644 --- a/data_juicer/ops/aggregator/meta_tags_aggregator.py +++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py @@ -129,7 +129,7 @@ def __init__(self, self.target_tag_str = '' if target_tags: - self.target_tag_str = target_tag_template( + self.target_tag_str = target_tag_template.format( target_tags=', '.join(target_tags)) self.sampling_params = sampling_params From 203bc642fc75e06f2e6fc74540d4578ab42fd589 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 17:26:11 +0800 Subject: [PATCH 112/118] naive reverse grouper --- .../ops/aggregator/meta_tags_aggregator.py | 2 - data_juicer/ops/grouper/__init__.py | 3 +- .../ops/grouper/naive_reverse_grouper.py | 26 ++++++ .../ops/grouper/test_naive_reverse_grouper.py | 83 +++++++++++++++++++ 4 files changed, 111 insertions(+), 3 deletions(-) create mode 100644 data_juicer/ops/grouper/naive_reverse_grouper.py create mode 100644 tests/ops/grouper/test_naive_reverse_grouper.py diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py index 60c94924c..a60d9096a 100644 --- a/data_juicer/ops/aggregator/meta_tags_aggregator.py +++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py @@ -143,10 +143,8 @@ def __init__(self, self.try_num = try_num def parse_output(self, response): - print(response) pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) matches = pattern.findall(response) - print(matches) tag_map = {tag1: tag2 for tag1, tag2 in matches} return tag_map diff --git a/data_juicer/ops/grouper/__init__.py b/data_juicer/ops/grouper/__init__.py index 048b305e4..f81ba6aec 100644 --- a/data_juicer/ops/grouper/__init__.py +++ b/data_juicer/ops/grouper/__init__.py @@ -1,4 +1,5 @@ from .key_value_grouper import KeyValueGrouper from .naive_grouper import NaiveGrouper +from .naive_reverse_grouper import NaiveReverseGrouper -__all__ = ['NaiveGrouper', 
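PATCH 110 above threads the optional `target_tags` list into the prompt through a `{target_tag_str}` slot, and PATCH 111 fixes a small slip: the template is a plain string, so it needs `.format(...)` rather than being called like a function. Assembling the prompt by hand, with the templates copied from the hunks above:

```python
INPUT_TEMPLATE = ('{target_tag_str}'
                  '| 合并前标签 | 频次 |\n'
                  '| ------ | ------ |\n'
                  '{tag_strs}')
TARGET_TAG_TEMPLATE = '合并后的标签应限定在[{target_tags}]中。\n'
TAG_TEMPLATE = '| {tag} | {cnt} |'

target_tags = ['开心', '难过', '其他']
meta_cnts = {'快乐': 2, '难过': 1}

# PATCH 111's fix: .format(...) on the template string.
# PATCH 110 had `target_tag_template(target_tags=...)`, i.e. calling a str,
# which raises TypeError at runtime.
target_tag_str = TARGET_TAG_TEMPLATE.format(target_tags=', '.join(target_tags))

tag_strs = [TAG_TEMPLATE.format(tag=k, cnt=v) for k, v in meta_cnts.items()]
prompt = INPUT_TEMPLATE.format(target_tag_str=target_tag_str,
                               tag_strs='\n'.join(tag_strs))
print(prompt)
```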
'KeyValueGrouper'] +__all__ = ['KeyValueGrouper', 'NaiveGrouper', 'NaiveReverseGrouper'] diff --git a/data_juicer/ops/grouper/naive_reverse_grouper.py b/data_juicer/ops/grouper/naive_reverse_grouper.py new file mode 100644 index 000000000..804b17d9e --- /dev/null +++ b/data_juicer/ops/grouper/naive_reverse_grouper.py @@ -0,0 +1,26 @@ +from ..base_op import OPERATORS, Grouper, convert_dict_list_to_list_dict + + +@OPERATORS.register_module('naive_reverse_grouper') +class NaiveReverseGrouper(Grouper): + """Split one batched sample to samples. """ + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + def process(self, dataset): + + if len(dataset) == 0: + return dataset + + samples = [] + for sample in dataset: + samples.append(convert_dict_list_to_list_dict(sample)) + + return samples diff --git a/tests/ops/grouper/test_naive_reverse_grouper.py b/tests/ops/grouper/test_naive_reverse_grouper.py new file mode 100644 index 000000000..29c06451d --- /dev/null +++ b/tests/ops/grouper/test_naive_reverse_grouper.py @@ -0,0 +1,83 @@ +import unittest + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.grouper.naive_reverse_grouper import NaiveReverseGrouper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class NaiveReverseGrouperTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples, target): + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for d, t in zip(new_dataset, target): + self.assertEqual(d['text'], t['text']) + + def test_one_batched_sample(self): + + source = [ + { + 'text':[ + "Today is Sunday and it's a happy day!", + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.', + '欢迎来到阿里巴巴!' + ] + } + ] + + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + + op = NaiveReverseGrouper() + self._run_helper(op, source, target) + + + def test_two_batch_sample(self): + + source = [ + { + 'text':[ + "Today is Sunday and it's a happy day!", + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + ] + }, + { + 'text':[ + '欢迎来到阿里巴巴!' + ] + } + ] + + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' 
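`NaiveReverseGrouper` above is the inverse of the existing `NaiveGrouper`: one collapses N samples into a single dict of column lists, the other expands such a dict back into N row dicts. A miniature round trip, with both directions re-implemented here for illustration (the real ops operate on Data-Juicer's `NestedDataset`, and the expansion is presumably what `convert_dict_list_to_list_dict` does, judging from the tests):

```python
def group(samples):
    """NaiveGrouper in miniature: one batched sample from all samples."""
    keys = samples[0].keys()
    return {k: [s[k] for s in samples] for k in keys}


def ungroup(batched):
    """NaiveReverseGrouper in miniature."""
    keys = list(batched)
    n = len(batched[keys[0]])
    return [{k: batched[k][i] for k in keys} for i in range(n)]


samples = [{'text': 'a'}, {'text': 'b'}, {'text': '欢迎来到阿里巴巴!'}]
assert ungroup(group(samples)) == samples  # the round trip is lossless
print(group(samples))  # {'text': ['a', 'b', '欢迎来到阿里巴巴!']}
```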
+ }, + ] + + op = NaiveReverseGrouper() + self._run_helper(op, source, target) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From cf01e7e48ca4e6c99622ac1785ef35a8a99eb3d9 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 17:29:40 +0800 Subject: [PATCH 113/118] naive reverse grouper --- data_juicer/ops/grouper/naive_reverse_grouper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_juicer/ops/grouper/naive_reverse_grouper.py b/data_juicer/ops/grouper/naive_reverse_grouper.py index 804b17d9e..385a83821 100644 --- a/data_juicer/ops/grouper/naive_reverse_grouper.py +++ b/data_juicer/ops/grouper/naive_reverse_grouper.py @@ -21,6 +21,6 @@ def process(self, dataset): samples = [] for sample in dataset: - samples.append(convert_dict_list_to_list_dict(sample)) + samples.extend(convert_dict_list_to_list_dict(sample)) return samples From 0ba6459acac1243ddb3dfd6d89e349ecd172e298 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Fri, 20 Dec 2024 10:51:29 +0800 Subject: [PATCH 114/118] tags specified field --- data_juicer/ops/selector/__init__.py | 4 +- .../selector/tags_specified_field_selector.py | 54 ++++++++++++++++ .../selector/test_tags_specified_selector.py | 63 +++++++++++++++++++ 3 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 data_juicer/ops/selector/tags_specified_field_selector.py create mode 100644 tests/ops/selector/test_tags_specified_selector.py diff --git a/data_juicer/ops/selector/__init__.py b/data_juicer/ops/selector/__init__.py index 22df12987..0339a2c5b 100644 --- a/data_juicer/ops/selector/__init__.py +++ b/data_juicer/ops/selector/__init__.py @@ -1,9 +1,11 @@ from .frequency_specified_field_selector import FrequencySpecifiedFieldSelector from .random_selector import RandomSelector from .range_specified_field_selector import RangeSpecifiedFieldSelector +from .tags_specified_field_selector import TagsSpecifiedFieldSelector from .topk_specified_field_selector import TopkSpecifiedFieldSelector __all__ = [ 'FrequencySpecifiedFieldSelector', 'RandomSelector', - 'RangeSpecifiedFieldSelector', 'TopkSpecifiedFieldSelector' + 'RangeSpecifiedFieldSelector', 'TagsSpecifiedFieldSelector', + 'TopkSpecifiedFieldSelector' ] diff --git a/data_juicer/ops/selector/tags_specified_field_selector.py b/data_juicer/ops/selector/tags_specified_field_selector.py new file mode 100644 index 000000000..6fb32251a --- /dev/null +++ b/data_juicer/ops/selector/tags_specified_field_selector.py @@ -0,0 +1,54 @@ +import numbers +from typing import List + +from ..base_op import OPERATORS, Selector + + +@OPERATORS.register_module('tags_specified_field_selector') +class TagsSpecifiedFieldSelector(Selector): + """Selector to select samples based on the tags of specified + field.""" + + def __init__(self, + field_key: str = '', + target_tags: List[str] = None, + *args, + **kwargs): + """ + Initialization method. + + :param field_key: Selector based on the specified value + corresponding to the target key. The target key + corresponding to multi-level field information need to be + separated by '.'. + :param target_tags: Target tags to be select. 
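PATCH 113's one-word fix deserves a comment: `convert_dict_list_to_list_dict` returns a list of per-row dicts, so `append` nests one sublist per batched sample while `extend` flattens everything into a single stream of samples, which is what the tests above assert. With a stand-in for the helper (its body is not shown in this diff, so the implementation here is an assumption inferred from the tests):

```python
def convert_dict_list_to_list_dict(batched):
    # Assumed behaviour of the helper imported from base_op: turn a dict of
    # column lists into a list of row dicts.
    keys = list(batched)
    n = len(batched[keys[0]])
    return [{k: batched[k][i] for k in keys} for i in range(n)]


batch1 = {'text': ['a', 'b']}
batch2 = {'text': ['c']}

appended, extended = [], []
for b in (batch1, batch2):
    appended.append(convert_dict_list_to_list_dict(b))   # pre-PATCH-113
    extended.extend(convert_dict_list_to_list_dict(b))   # PATCH 113 fix

print(appended)  # [[{'text': 'a'}, {'text': 'b'}], [{'text': 'c'}]]  (nested)
print(extended)  # [{'text': 'a'}, {'text': 'b'}, {'text': 'c'}]  (flat samples)
```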
+ :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.field_key = field_key + self.target_tags = set(target_tags) + + def process(self, dataset): + if len(dataset) <= 1 or not self.field_key: + return dataset + + field_keys = self.field_key.split('.') + assert field_keys[0] in dataset.features.keys( + ), "'{}' not in {}".format(field_keys[0], dataset.features.keys()) + + selected_index = [] + for i, item in enumerate(dataset[field_keys[0]]): + field_value = item + for key in field_keys[1:]: + assert key in field_value.keys(), "'{}' not in {}".format( + key, field_value.keys()) + field_value = field_value[key] + assert field_value is None or isinstance( + field_value, str) or isinstance( + field_value, numbers.Number + ), 'The {} item is not String, Numbers or NoneType'.format(i) + if field_value in self.target_tags: + selected_index.append(i) + + return dataset.select(selected_index) diff --git a/tests/ops/selector/test_tags_specified_selector.py b/tests/ops/selector/test_tags_specified_selector.py new file mode 100644 index 000000000..87c232a2b --- /dev/null +++ b/tests/ops/selector/test_tags_specified_selector.py @@ -0,0 +1,63 @@ +import unittest + +from data_juicer.core.data import NestedDataset as Dataset + +from data_juicer.ops.selector.tags_specified_field_selector import \ + TagsSpecifiedFieldSelector +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class TagsSpecifiedFieldSelectorTest(DataJuicerTestCaseBase): + + def _run_tag_selector(self, dataset: Dataset, target_list, op): + dataset = op.process(dataset) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_tag_select(self): + ds_list = [{ + 'text': 'a', + 'meta': { + 'sentiment': 'happy', + } + }, { + 'text': 'b', + 'meta': { + 'sentiment': 'happy', + } + }, { + 'text': 'c', + 'meta': { + 'sentiment': 'sad', + } + }, { + 'text': 'd', + 'meta': { + 'sentiment': 'angry', + } + }] + tgt_list = [{ + 'text': 'a', + 'meta': { + 'sentiment': 'happy', + } + }, { + 'text': 'b', + 'meta': { + 'sentiment': 'happy', + } + }, { + 'text': 'c', + 'meta': { + 'sentiment': 'sad', + } + }] + dataset = Dataset.from_list(ds_list) + op = TagsSpecifiedFieldSelector( + field_key='meta.sentiment', + target_tags=['happy', 'sad']) + self._run_tag_selector(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() From 9f098bd24c38d36c719c3f264edd23b8e24617c4 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Fri, 20 Dec 2024 12:03:37 +0800 Subject: [PATCH 115/118] doc done --- configs/config_all.yaml | 18 ++++++++++++++++++ .../ops/aggregator/meta_tags_aggregator.py | 2 +- .../ops/grouper/naive_reverse_grouper.py | 2 +- data_juicer/utils/auto_install_mapping.py | 1 + docs/Operators.md | 12 ++++++++---- docs/Operators_ZH.md | 11 +++++++---- 6 files changed, 36 insertions(+), 10 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 462478c23..42d1e779e 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -792,6 +792,9 @@ process: upper_percentile: # the upper bound of the percentile to be sampled lower_rank: # the lower rank of the percentile to be sampled upper_rank: # the upper rank of the percentile to be sampled + - tags_specified_field_selector: # Selector to select samples based on the tags of specified field. + field_key: '__dj__meta__.query_sentiment_label' # the target keys corresponding to multi-level field information need to be separated by '.' 
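The selector's `process` above splits `field_key` on dots, walks each row down that path, and keeps the indices whose leaf value appears in `target_tags`; note that it returns the dataset unchanged when there is at most one row or `field_key` is empty. The same traversal without the HuggingFace `Dataset` API, mirroring the unit test's data:

```python
rows = [
    {'text': 'a', 'meta': {'sentiment': 'happy'}},
    {'text': 'b', 'meta': {'sentiment': 'sad'}},
    {'text': 'c', 'meta': {'sentiment': 'angry'}},
]

field_key = 'meta.sentiment'
target_tags = {'happy', 'sad'}

# Same traversal as TagsSpecifiedFieldSelector.process; the real op also
# asserts that the leaf value is a string, a number, or None.
keys = field_key.split('.')
selected = []
for i, row in enumerate(rows):
    value = row
    for key in keys:
        value = value[key]
    if value in target_tags:
        selected.append(i)

print([rows[i]['text'] for i in selected])  # ['a', 'b']
```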
+      target_tags: ['happy', 'sad']                             # Target tags to be selected.
   - topk_specified_field_selector:                            # selector to select top samples based on the sorted specified field
       field_key: ''                                             # the target keys corresponding to multi-level field information need to be separated by '.'
       top_ratio:                                                # ratio of selected top samples
@@ -800,6 +803,7 @@ process:
 
   # Grouper ops.
   - naive_grouper:                                                    # Group all samples to one batched sample.
+  - naive_reverse_grouper:                                            # Split one batched sample to samples.
   - key_value_grouper:                                                # Group samples to batched samples according values in given keys.
       group_by_keys: null                                               # Group samples according values in the keys. Support for nested keys such as "__dj__stats__.text_len". It is [self.text_key] in default.
@@ -821,6 +825,20 @@ process:
       try_num: 3                                                        # The number of retry attempts when there is an API call error or output parsing error.
       model_params: {}                                                  # Parameters for initializing the API model.
       sampling_params: {}                                               # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
+  - meta_tags_aggregator:                                             # Merge similar meta tags to one tag.
+      api_model: 'gpt-4o'                                               # API model name.
+      meta_tag_key: '__dj__meta__.query_sentiment_label'                # The key of the meta tag to be mapped.
+      target_tags: ['开心', '难过', '其他']                                # The tags that the input tags are supposed to be mapped to.
+      api_endpoint: null                                                # URL endpoint for the API.
+      response_path: null                                               # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+      system_prompt: null                                               # The system prompt.
+      input_template: null                                              # The input template.
+      target_tag_template: null                                         # The tag template for target tags.
+      tag_template: null                                                # The tag template for each tag and its frequency.
+      output_pattern: null                                              # The output pattern.
+      try_num: 3                                                        # The number of retry attempts when there is an API call error or output parsing error.
+      model_params: {}                                                  # Parameters for initializing the API model.
+      sampling_params: {}                                               # Extra parameters passed to the API call. e.g. {'temperature': 0.9, 'top_p': 0.95}
   - most_relavant_entities_aggregator:                                # Extract entities closely related to a given entity from some texts, and sort them in descending order of importance.
       api_model: 'gpt-4o'                                               # API model name.
       entity: '孙悟空'                                                   # The given entity.
diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py
index a60d9096a..808ef73da 100644
--- a/data_juicer/ops/aggregator/meta_tags_aggregator.py
+++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py
@@ -16,7 +16,7 @@
 @OPERATORS.register_module(OP_NAME)
 class MetaTagsAggregator(Aggregator):
     """
-    Merge similar meta tags to one tags.
+    Merge similar meta tags to one tag.
     """
 
     DEFAULT_SYSTEM_PROMPT = ('给定一些标签以及这些标签出现的频次,合并意思相近的标签。\n'
diff --git a/data_juicer/ops/grouper/naive_reverse_grouper.py b/data_juicer/ops/grouper/naive_reverse_grouper.py
index 385a83821..2535205b9 100644
--- a/data_juicer/ops/grouper/naive_reverse_grouper.py
+++ b/data_juicer/ops/grouper/naive_reverse_grouper.py
@@ -3,7 +3,7 @@
 
 @OPERATORS.register_module('naive_reverse_grouper')
 class NaiveReverseGrouper(Grouper):
-    """Split one batched sample to samples. """
+    """Split batched samples to samples. 
""" def __init__(self, *args, **kwargs): """ diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py index 2da2e3616..3b8ec20aa 100644 --- a/data_juicer/utils/auto_install_mapping.py +++ b/data_juicer/utils/auto_install_mapping.py @@ -103,4 +103,5 @@ 'query_intent_detection_mapper': ['transformers'], 'query_sentiment_detection_mapper': ['transformers'], 'query_topic_detection_mapper': ['transformers'], + 'meta_tags_aggregator': ['openai'], } diff --git a/docs/Operators.md b/docs/Operators.md index bd099fb07..963155333 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -14,9 +14,9 @@ The operators in Data-Juicer are categorized into 5 types. | [ Mapper ]( #mapper ) | 70 | Edits and transforms samples | | [ Filter ]( #filter ) | 44 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | -| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | -| [ Grouper ]( #grouper ) | 2 | Group samples to batched samples | -| [ Aggregator ]( #aggregator ) | 3 | Aggregate for batched samples, such as summary or conclusion | +| [ Selector ]( #selector ) | 5 | Selects top samples based on ranking | +| [ Grouper ]( #grouper ) | 3 | Group samples to batched samples | +| [ Aggregator ]( #aggregator ) | 4 | Aggregate for batched samples, such as summary or conclusion | All the specific operators are listed below, each featured with several capability tags. @@ -199,20 +199,24 @@ All the specific operators are listed below, each featured with several capabili | frequency_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects top samples by comparing the frequency of the specified field | [code](../data_juicer/ops/selector/frequency_specified_field_selector.py) | [tests](../tests/ops/selector/test_frequency_specified_field_selector.py) | | random_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects samples randomly | [code](../data_juicer/ops/selector/random_selector.py) | [tests](../tests/ops/selector/test_random_selector.py) | | range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects samples within a specified range by comparing the values of the specified field | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) | +| tags_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Select samples based on the tags of specified + field. 
| [code](../data_juicer/ops/selector/tags_specified_field_selector.py) | [tests](../tests/ops/selector/test_tags_specified_field_selector.py) | | topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects top samples by comparing the values of the specified field | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) | ## Grouper | Operator | Tags | Description | Source code | Unit tests | |------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------| -| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group samples to batched samples according values in given keys. | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) | | naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group all samples to one batched sample. | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) | +| naive_reverse_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Split batched samples to samples. | [code](../data_juicer/ops/grouper/naive_reverse_grouper.py) | [tests](../tests/ops/grouper/test_naive_reverse_grouper.py) | +| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group samples to batched samples according values in given keys. 
| [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) | ## Aggregator | Operator | Tags | Description | Source code | Unit tests | |------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------| | entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Return conclusion of the given entity's attribute from some docs. | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) | +| meta_tags_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Merge similar meta tags to one tag. | [code](../data_juicer/ops/aggregator/meta_tags_aggregator.py) | [tests](../tests/ops/aggregator/test_meta_tags_aggregator.py) | | most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities closely related to a given entity from some texts, and sort them in descending order of importance. | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) | | nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Considering the limitation of input length, nested aggregate contents for each given number of samples. 
| [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 749459ad7..6c22ac0fe 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -14,9 +14,9 @@ Data-Juicer 中的算子分为以下 5 种类型。 | [ Mapper ]( #mapper ) | 70 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 44 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | -| [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | -| [ Grouper ]( #grouper ) | 2 | 将样本分组,每一组组成一个批量样本 | -| [ Aggregator ]( #aggregator ) | 3 | 对批量样本进行汇总,如得出总结或结论 | +| [ Selector ]( #selector ) | 5 | 基于排序选取高质量样本 | +| [ Grouper ]( #grouper ) | 3 | 将样本分组,每一组组成一个批量样本 | +| [ Aggregator ]( #aggregator ) | 4 | 对批量样本进行汇总,如得出总结或结论 | 下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。 @@ -198,20 +198,23 @@ Data-Juicer 中的算子分为以下 5 种类型。 | frequency_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的频率选出前 k 个样本 | [code](../data_juicer/ops/selector/frequency_specified_field_selector.py) | [tests](../tests/ops/selector/test_frequency_specified_field_selector.py) | | random_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 随机筛选 k 个样本 | [code](../data_juicer/ops/selector/random_selector.py) | [tests](../tests/ops/selector/test_random_selector.py) | | range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出指定范围的 k 个样本 | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) | +| tags_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过指定字段的标签值筛选样例 | [code](../data_juicer/ops/selector/tags_specified_field_selector.py) | [tests](../tests/ops/selector/test_tags_specified_field_selector.py) | | topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出前 k 个样本 | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) | ## Grouper | 算子 | 标签 | 描述 | 源码 | 单测样例 | |-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------| +| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将所有样本分为一个组,返回一个batch化的样本 | 
[code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) | +| naive_reverse_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将batch化的样本拆分成普通的样本 | [code](../data_juicer/ops/grouper/naive_reverse_grouper.py) | [tests](../tests/ops/grouper/test_naive_reverse_grouper.py) | | key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 根据给定键的值将样本分组,每一组组成一个批量样本。 | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) | -| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将所有样本分为一个组,返回一个批量样本 | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) | ## Aggregator | 算子 | 标签 | 描述 | 源码 | 单测样例 | |-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------| | entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中总结出给定实体的属性 | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) | +| meta_tags_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将相似的标签合并成同一个标签。 | [code](../data_juicer/ops/aggregator/meta_tags_aggregator.py) | [tests](../tests/ops/aggregator/test_meta_tags_aggregator.py) | | most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中抽取出与给定实体密切相关的实体,按重要性从高到低排序 | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) | | nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 考虑到输入长度的限制,对样本中的内容进行嵌套聚合。 | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | 
[tests](../tests/ops/aggregator/test_nested_aggregator.py) |

From a31b9c5ebbb17f66c409414bee6b956d33686fd0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=81=93=E8=BE=95?=
Date: Thu, 26 Dec 2024 16:56:04 +0800
Subject: [PATCH 116/118] - rename tests/ops/Aggregator into tests/ops/aggregator
 for correct doc linking; - minor fix for OP doc

---
 docs/Operators.md    | 2 +-
 docs/Operators_ZH.md | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/Operators.md b/docs/Operators.md
index 963155333..ea84a360c 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -6,7 +6,7 @@ This page offers a basic description of the operators (OPs) in Data-Juicer. User
 
 ## Overview
 
-The operators in Data-Juicer are categorized into 5 types.
+The operators in Data-Juicer are categorized into 7 types.
 
 | Type | Number | Description |
 |-----------------------------------|:------:|-------------------------------------------------|
diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md
index 6c22ac0fe..40710f68f 100644
--- a/docs/Operators_ZH.md
+++ b/docs/Operators_ZH.md
@@ -6,7 +6,7 @@
 
 ## 概览
 
-Data-Juicer 中的算子分为以下 5 种类型。
+Data-Juicer 中的算子分为以下 7 种类型。
 
 | 类型 | 数量 | 描述 |
 |------------------------------------|:--:|---------------|

From 8241508409fab8c01235d7f651d43949b1b3e08a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=81=93=E8=BE=95?=
Date: Thu, 26 Dec 2024 17:20:53 +0800
Subject: [PATCH 117/118] rename for correct doc linking in test dir

---
 tests/ops/{Aggregator => aggregator}/__init__.py                  | 0
 .../test_entity_attribute_aggregator.py                           | 0
 tests/ops/{Aggregator => aggregator}/test_meta_tags_aggregator.py | 0
 .../test_most_relavant_entities_aggregator.py                     | 0
 tests/ops/{Aggregator => aggregator}/test_nested_aggregator.py    | 0
 5 files changed, 0 insertions(+), 0 deletions(-)
 rename tests/ops/{Aggregator => aggregator}/__init__.py (100%)
 rename tests/ops/{Aggregator => aggregator}/test_entity_attribute_aggregator.py (100%)
 rename tests/ops/{Aggregator => aggregator}/test_meta_tags_aggregator.py (100%)
 rename tests/ops/{Aggregator => aggregator}/test_most_relavant_entities_aggregator.py (100%)
 rename tests/ops/{Aggregator => aggregator}/test_nested_aggregator.py (100%)

diff --git a/tests/ops/Aggregator/__init__.py b/tests/ops/aggregator/__init__.py
similarity index 100%
rename from tests/ops/Aggregator/__init__.py
rename to tests/ops/aggregator/__init__.py
diff --git a/tests/ops/Aggregator/test_entity_attribute_aggregator.py b/tests/ops/aggregator/test_entity_attribute_aggregator.py
similarity index 100%
rename from tests/ops/Aggregator/test_entity_attribute_aggregator.py
rename to tests/ops/aggregator/test_entity_attribute_aggregator.py
diff --git a/tests/ops/Aggregator/test_meta_tags_aggregator.py b/tests/ops/aggregator/test_meta_tags_aggregator.py
similarity index 100%
rename from tests/ops/Aggregator/test_meta_tags_aggregator.py
rename to tests/ops/aggregator/test_meta_tags_aggregator.py
diff --git a/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py b/tests/ops/aggregator/test_most_relavant_entities_aggregator.py
similarity index 100%
rename from tests/ops/Aggregator/test_most_relavant_entities_aggregator.py
rename to tests/ops/aggregator/test_most_relavant_entities_aggregator.py
diff --git a/tests/ops/Aggregator/test_nested_aggregator.py b/tests/ops/aggregator/test_nested_aggregator.py
similarity index 100%
rename from tests/ops/Aggregator/test_nested_aggregator.py
rename to tests/ops/aggregator/test_nested_aggregator.py

From 8740c8939a145f5efac7091e522a6b209cbfbbdd Mon Sep 17 00:00:00 2001
From:
=?UTF-8?q?=E9=81=93=E8=BE=95?= Date: Thu, 26 Dec 2024 17:23:25 +0800 Subject: [PATCH 118/118] fix bad dingtalk link --- README.md | 2 +- README_ZH.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 518e54713..385a9aef1 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ We provide a [playground](http://8.138.149.181/) with a managed JupyterLab. [Try [Platform for AI of Alibaba Cloud (PAI)](https://www.aliyun.com/product/bigdata/learn) has cited our work and integrated Data-Juicer into its data processing products. PAI is an AI Native large model and AIGC engineering platform that provides dataset management, computing power management, model tool chain, model development, model training, model deployment, and AI asset management. For documentation on data processing, please refer to: [PAI-Data Processing for Large Models](https://help.aliyun.com/zh/pai/user-guide/components-related-to-data-processing-for-foundation-models/?spm=a2c4g.11186623.0.0.3e9821a69kWdvX). Data-Juicer is being actively updated and maintained. We will periodically enhance and add more features, data recipes and datasets. -We welcome you to join us (via issues, PRs, [Slack](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8253f30mgpjw) channel, [DingDing](https://qr.dingtalk.com/action/joingroup?spm=a2c22.12281976.0.0.7a8253f30mgpjw&code=v1,k1,C0DI7CwRFrg7gJP5aMC95FUmsNuwuKJboT62BqP5DAk=&_dt_no_comment=1&origin=11) group, ...), in promoting data-model co-development along with research and applications of (multimodal) LLMs! +We welcome you to join us (via issues, PRs, [Slack](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8253f30mgpjw) channel, [DingDing](https://qr.dingtalk.com/action/joingroup?code=v1,k1,YFIXM2leDEk7gJP5aMC95AfYT+Oo/EP/ihnaIEhMyJM=&_dt_no_comment=1&origin=11) group, ...), in promoting data-model co-development along with research and applications of (multimodal) LLMs! ---- diff --git a/README_ZH.md b/README_ZH.md index 366fcb004..54ba1445f 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -27,7 +27,7 @@ Data-Juicer 是一个一站式**多模态**数据处理系统,旨在为大语 [阿里云人工智能平台 PAI](https://www.aliyun.com/product/bigdata/learn) 已引用我们的工作,将Data-Juicer的能力集成到PAI的数据处理产品中。PAI提供包含数据集管理、算力管理、模型工具链、模型开发、模型训练、模型部署、AI资产管理在内的功能模块,为用户提供高性能、高稳定、企业级的大模型工程化能力。数据处理的使用文档请参考:[PAI-大模型数据处理](https://help.aliyun.com/zh/pai/user-guide/components-related-to-data-processing-for-foundation-models/?spm=a2c4g.11186623.0.0.3e9821a69kWdvX)。 -Data-Juicer正在积极更新和维护中,我们将定期强化和新增更多的功能和数据菜谱。热烈欢迎您加入我们(issues/PRs/[Slack频道](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8275bc8g7ypp) /[钉钉群](https://qr.dingtalk.com/action/joingroup?spm=a2c22.12281976.0.0.7a8275bc8g7ypp&code=v1,k1,C0DI7CwRFrg7gJP5aMC95FUmsNuwuKJboT62BqP5DAk=&_dt_no_comment=1&origin=11)/...),一起推进LLM-数据的协同开发和研究! +Data-Juicer正在积极更新和维护中,我们将定期强化和新增更多的功能和数据菜谱。热烈欢迎您加入我们(issues/PRs/[Slack频道](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8275bc8g7ypp) /[钉钉群](https://qr.dingtalk.com/action/joingroup?code=v1,k1,YFIXM2leDEk7gJP5aMC95AfYT+Oo/EP/ihnaIEhMyJM=&_dt_no_comment=1&origin=11)/...),一起推进LLM-数据的协同开发和研究! ----