diff --git a/.coveragerc b/.coveragerc index d4a7a6d63..d95c7fc28 100644 --- a/.coveragerc +++ b/.coveragerc @@ -9,3 +9,6 @@ omit = # avoid measuring code of unittest tests/* + +[report] +ignore_errors = True diff --git a/.github/workflows/deploy_sphinx_docs.yml b/.github/workflows/deploy_sphinx_docs.yml index 9c8ae89a0..5cf0205ae 100644 --- a/.github/workflows/deploy_sphinx_docs.yml +++ b/.github/workflows/deploy_sphinx_docs.yml @@ -12,13 +12,16 @@ on: jobs: pages: runs-on: ubuntu-20.04 + strategy: + matrix: + python-version: [ "3.9", "3.10" ] steps: - name: Checkout uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@master with: - python_version: ${{ matrix.python-version }} + python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/perf-bench.yml b/.github/workflows/perf-bench.yml new file mode 100644 index 000000000..4094070db --- /dev/null +++ b/.github/workflows/perf-bench.yml @@ -0,0 +1,56 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: performance_benchmark + +on: + workflow_dispatch: + push: + branches: + - main + +permissions: + contents: read + +env: + ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true + +jobs: + perf_bench: + runs-on: [GPU, unittest] + environment: Testing + steps: + - uses: actions/checkout@v3 + with: + path: dj-${{ github.run_id }} + + - name: Setup docker compose + working-directory: dj-${{ github.run_id }}/.github/workflows/docker + run: | + docker compose up -d + + - name: Install data-juicer + working-directory: dj-${{ github.run_id }}/.github/workflows/docker + run: | + docker compose exec ray-head pip install -e .\[all\] + + - name: Clean dataset cache + working-directory: dj-${{ github.run_id }}/.github/workflows/docker + run: | + docker compose exec ray-head rm -rf /data/huggingface/dataset + + - name: Run performance benchmark standalone + working-directory: dj-${{ github.run_id }}/.github/workflows/docker + run: | + docker compose exec ray-head bash tests/benchmark_performance/run.sh ${{ secrets.INTERNAL_WANDB_URL }} ${{ secrets.INTERNAL_WANDB_API_KEY }} + + - name: Remove docker compose + working-directory: dj-${{ github.run_id }}/.github/workflows/docker + if: always() + run: | + docker compose down --remove-orphans + + - name: Cleanup workspace + if: always() + run: | + rm -rf dj-${{ github.run_id }} diff --git a/.github/workflows/publish-docker.yml b/.github/workflows/publish-docker.yml index 88d531070..5e9fde84b 100644 --- a/.github/workflows/publish-docker.yml +++ b/.github/workflows/publish-docker.yml @@ -12,11 +12,12 @@ on: env: IMAGE_NAME: datajuicer/data-juicer + ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true jobs: build: - runs-on: ubuntu-latest + runs-on: [docker] permissions: contents: read packages: write @@ -27,7 +28,9 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v3 + with: + path: dj-${{ github.run_id }} # Install the cosign tool except on PR # https://github.com/sigstore/cosign-installer @@ -40,12 +43,12 @@ jobs: # multi-platform images and export cache # https://github.com/docker/setup-buildx-action - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@v2 # Login against a Docker registry except on PR # https://github.com/docker/login-action - name: Log into Docker Hub - uses: docker/login-action@v3 + uses: docker/login-action@v2 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} @@ -64,12 +67,10 @@ jobs: id: build-and-push uses: docker/build-push-action@v6 with: - context: . + context: dj-${{ github.run_id }} push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} - cache-from: type=gha - cache-to: type=gha,mode=max # Sign the resulting Docker image digest except on PRs. # This will only write to the public Rekor transparency log when the Docker diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index f292f6fdd..ce7af4474 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -18,7 +18,7 @@ env: jobs: unittest-single: - runs-on: [self-hosted, linux] + runs-on: [GPU, unittest] environment: Testing steps: - uses: actions/checkout@v3 diff --git a/Dockerfile b/Dockerfile index 3849c4e7f..347794544 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,15 +1,28 @@ # The data-juicer image includes all open-source contents of data-juicer, # and it will be instaled in editable mode. -FROM python:3.8.18 +FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 + +# install python 3.10 +RUN apt-get update \ + && apt-get install -y git curl vim wget python3.10 libpython3.10-dev python3-pip \ + && apt-get install -y libgl1-mesa-glx libglib2.0-0 \ + && ln -sf /usr/bin/python3.10 /usr/bin/python3 \ + && ln -sf /usr/bin/python3.10 /usr/bin/python \ + && apt-get autoclean && rm -rf /var/lib/apt/lists/* \ + && pip install --upgrade pip + +# install 3rd-party system dependencies +RUN apt-get update \ + && apt-get install ffmpeg libsm6 libxext6 software-properties-common build-essential cmake gfortran libopenblas-dev liblapack-dev -y # prepare the java env WORKDIR /opt # download jdk -RUN wget https://aka.ms/download-jdk/microsoft-jdk-17.0.9-linux-x64.tar.gz -O jdk.tar.gz && \ - tar -xzf jdk.tar.gz && \ - rm -rf jdk.tar.gz && \ - mv jdk-17.0.9+8 jdk +RUN wget https://aka.ms/download-jdk/microsoft-jdk-17.0.9-linux-x64.tar.gz -O jdk.tar.gz \ + && tar -xzf jdk.tar.gz \ + && rm -rf jdk.tar.gz \ + && mv jdk-17.0.9+8 jdk # set the environment variable ENV JAVA_HOME=/opt/jdk @@ -17,16 +30,10 @@ ENV JAVA_HOME=/opt/jdk WORKDIR /data-juicer # install requirements which need to be installed from source -RUN pip install git+https://github.com/xinyu1205/recognize-anything.git --default-timeout 1000 - -# install requirements first to better reuse installed library cache -COPY environments/ environments/ -RUN cat environments/* | xargs pip install --default-timeout 1000 +RUN pip install --upgrade setuptools==69.5.1 setuptools_scm \ + && pip install git+https://github.com/xinyu1205/recognize-anything.git --default-timeout 1000 # install data-juicer then COPY . . -RUN pip install -v -e .[all] -RUN pip install -v -e .[sandbox] - -# install 3rd-party system dependencies -RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y +RUN pip install -v -e .[all] --default-timeout 1000 +RUN pip install -v -e .[sandbox] --default-timeout 1000 diff --git a/README.md b/README.md index 32b5b0d1f..d891ac332 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,7 @@ Table of Contents ## Prerequisites -- Recommend Python>=3.8,<=3.10 +- Recommend Python>=3.9,<=3.10 - gcc >= 5 (at least C++14 support) ## Installation @@ -197,6 +197,22 @@ The dependency options are listed below: | `.[tools]` | Install dependencies for dedicated tools, such as quality classifiers. | | `.[sandbox]` | Install all dependencies for sandbox. | +- Install dependencies for specific OPs + +With the growth of the number of OPs, the dependencies of all OPs becomes very heavy. Instead of using the command `pip install -v -e .[sci]` to install all dependencies, +we provide two alternative, lighter options: + + - Automatic Minimal Dependency Installation: During the execution of Data-Juicer, minimal dependencies will be automatically installed. This allows for immediate execution, but may potentially lead to dependency conflicts. + + - Manual Minimal Dependency Installation: To manually install minimal dependencies tailored to a specific execution configuration, run the following command: + ```shell + # only for installation from source + python tools/dj_install.py --config path_to_your_data-juicer_config_file + + # use command line tool + dj-install --config path_to_your_data-juicer_config_file + ``` + ### Using pip - Run the following command to install the latest released `data_juicer` using `pip`: @@ -317,6 +333,11 @@ python tools/analyze_data.py --config configs/demo/analyzer.yaml # use command line tool dj-analyze --config configs/demo/analyzer.yaml + +# you can also use auto mode to avoid writing a recipe. It will analyze a small +# part (e.g. 1000 samples, specified by argument `auto_num`) of your dataset +# with all Filters that produce stats. +dj-analyze --auto --dataset_path xx.jsonl [--auto_num 1000] ``` - **Note:** Analyzer only compute stats of Filter ops. So extra Mapper or Deduplicator ops will be ignored in the analysis process. @@ -386,6 +407,10 @@ python tools/sandbox_starter.py --config configs/demo/sandbox/sandbox.yaml ```shell # run the data processing directly docker run --rm \ # remove container after the processing + --privileged \ + --shm-size 256g \ + --network host \ + --gpus all \ --name dj \ # name of the container -v : \ # mount data or config directory into the container -v ~/.cache/:/root/.cache/ \ # mount the cache directory into the container to reuse caches and models (recommended) @@ -398,6 +423,10 @@ docker run --rm \ # remove container after the processing ```shell # start the container docker run -dit \ # run the container in the background + --privileged \ + --shm-size 256g \ + --network host \ + --gpus all \ --rm \ --name dj \ -v : \ diff --git a/README_ZH.md b/README_ZH.md index 03e349547..01633731b 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -144,7 +144,7 @@ Data-Juicer正在积极更新和维护中,我们将定期强化和新增更多 ## 前置条件 -* 推荐 Python>=3.8,<=3.10 +* 推荐 Python>=3.9,<=3.10 * gcc >= 5 (at least C++14 support) ## 安装 @@ -178,6 +178,21 @@ pip install -v -e .[tools] # 安装部分工具库的依赖 | `.[tools]` | 安装专用工具库(如质量分类器)所需的依赖项 | | `.[sandbox]` | 安装沙盒实验室的基础依赖 | +* 只安装部分算子依赖 + +随着OP数量的增长,所有OP的依赖变得很重。为此,我们提供了两个替代的、更轻量的选项,作为使用命令`pip install -v -e .[sci]`安装所有依赖的替代: + + * 自动最小依赖安装:在执行Data-Juicer的过程中,将自动安装最小依赖。也就是说你可以直接执行,但这种方式可能会导致一些依赖冲突。 + + * 手动最小依赖安装:可以通过如下指令手动安装适合特定执行配置的最小依赖: + ```shell + # 适用于从源码安装 + python tools/dj_install.py --config path_to_your_data-juicer_config_file + + # 使用命令行工具 + dj-install --config path_to_your_data-juicer_config_file + ``` + ### 使用 pip 安装 * 运行以下命令用 `pip` 安装 `data_juicer` 的最新发布版本: @@ -295,6 +310,10 @@ python tools/analyze_data.py --config configs/demo/analyzer.yaml # 使用命令行工具 dj-analyze --config configs/demo/analyzer.yaml + +# 你也可以使用"自动"模式来避免写一个新的数据菜谱。它会使用全部可产出统计信息的 Filter 来分析 +# 你的数据集的一小部分(如1000条样本,可通过 `auto_num` 参数指定) +dj-analyze --auto --dataset_path xx.jsonl [--auto_num 1000] ``` * **注意**:Analyzer 只计算 Filter 算子的状态,其他的算子(例如 Mapper 和 Deduplicator)会在分析过程中被忽略。 @@ -363,6 +382,10 @@ python tools/sandbox_starter.py --config configs/demo/sandbox/sandbox.yaml ```shell # 直接运行数据处理 docker run --rm \ # 在处理结束后将容器移除 + --privileged \ + --shm-size 256g \ + --network host \ + --gpus all \ --name dj \ # 容器名称 -v : \ # 将本地的数据或者配置目录挂载到容器中 -v ~/.cache/:/root/.cache/ \ # 将 cache 目录挂载到容器以复用 cache 和模型资源(推荐) @@ -375,6 +398,10 @@ docker run --rm \ # 在处理结束后将容器移除 ```shell # 启动容器 docker run -dit \ # 在后台启动容器 + --privileged \ + --shm-size 256g \ + --network host \ + --gpus all \ --rm \ --name dj \ -v : \ diff --git a/configs/config_all.yaml b/configs/config_all.yaml index df4bf91ad..8335bb173 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -15,15 +15,18 @@ text_keys: 'text' # the key name of fi suffixes: [] # the suffix of files that will be read. For example: '.txt', 'txt' or ['txt', '.pdf', 'docx'] use_cache: true # whether to use the cache management of Hugging Face datasets. It might take up lots of disk space when using cache ds_cache_dir: null # cache dir for Hugging Face datasets. In default, it\'s the same as the environment variable `HF_DATASETS_CACHE`, whose default value is usually "~/.cache/huggingface/datasets". If this argument is set to a valid path by users, it will override the default cache dir +open_monitor: true # Whether to open the monitor to trace resource utilization for each OP during data processing. It\'s True in default. use_checkpoint: false # whether to use the checkpoint management to save the latest version of dataset to work dir when processing. Rerun the same config will reload the checkpoint and skip ops before it. Cache will be disabled when using checkpoint. If args of ops before the checkpoint are changed, all ops will be rerun from the beginning. temp_dir: null # the path to the temp directory to store intermediate caches when cache is disabled, these cache files will be removed on-the-fly. In default, it's None, so the temp dir will be specified by system. NOTICE: you should be caution when setting this argument because it might cause unexpected program behaviors when this path is set to an unsafe directory. open_tracer: false # whether to open the tracer to trace the changes during process. It might take more time when opening tracer op_list_to_trace: [] # only ops in this list will be traced by tracer. If it's empty, all ops will be traced. Only available when tracer is opened. trace_num: 10 # number of samples to show the differences between datasets before and after each op. Only available when tracer is opened. op_fusion: false # whether to fuse operators that share the same intermediate variables automatically. Op fusion might reduce the memory requirements slightly but speed up the whole process. +fusion_strategy: 'probe' # OP fusion strategy. Support ['greedy', 'probe'] now. 'greedy' means keep the basic OP order and put the fused OP to the last of each fused OP group. 'probe' means Data-Juicer will probe the running speed for each OP at the beginning and reorder the OPs and fused OPs according to their probed speed (fast to slow). It's 'probe' in default. cache_compress: null # the compression method of the cache file, which can be specified in ['gzip', 'zstd', 'lz4']. If this parameter is None, the cache file will not be compressed. We recommend you turn on this argument when your input dataset is larger than tens of GB and your disk space is not enough. keep_stats_in_res_ds: false # whether to keep the computed stats in the result dataset. The intermediate fields to store the stats computed by Filters will be removed if it's False. It's False in default. keep_hashes_in_res_ds: false # whether to keep the computed hashes in the result dataset. The intermediate fields to store the hashes computed by Deduplicators will be removed if it's False. It's False in default. +adaptive_batch_size: false # whether to use adaptive batch sizes for each OP according to the probed results. It's False in default. # for multimodal data processing image_key: 'images' # key name of field to store the list of sample image paths. @@ -76,9 +79,9 @@ process: - clean_copyright_mapper: # remove copyright comments. - expand_macro_mapper: # expand macro definitions in Latex text. - extract_entity_attribute_mapper: # Extract attributes for given entities from the text. + api_model: 'gpt-4o' # API model name. query_entities: ["孙悟空", "猪八戒"] # Entity list to be queried. query_attributes: ["人物性格"] # Attribute list to be queried. - api_model: 'gpt-4o' # API model name. entity_key: '__dj__entity__' # The field name to store the given main entity for attribute extraction. entity_attribute_key: '__dj__attribute__' # The field name to store the given attribute to be extracted. attribute_desc_key: '__dj__attribute_description__' # The field name to store the extracted attribute description. @@ -150,6 +153,18 @@ process: drop_text: false # If drop the text in the output. model_params: {} # Parameters for initializing the API model. sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - extract_support_text_mapper: # extract support sub text for a summary. + api_model: 'gpt-4o' # API model name. + summary_key: '__dj__event_description__' # The field name to store the input summary. Support for nested keys such as "__dj__stats__.text_len". + support_text_key: '__dj__support_text__' # The field name to store the output support text for the summary. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # System prompt for the task. + input_template: null # Template for building the model input. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - fix_unicode_mapper: # fix unicode errors in text. - generate_qa_from_examples_mapper: # mapper to generate question and answer pairs from examples. hf_model: 'Qwen/Qwen2.5-7B-Instruct' # Model name on huggingface to generate question and answer pairs. @@ -209,6 +224,7 @@ process: radius: 2 # radius of blur kernel - image_tagging_mapper: # Mapper to generate image tags. tag_field_name: '__dj__image_tags__' # the field name to store the tags. It's "__dj__image_tags__" in default. + mem_required: '9GB' - nlpaug_en_mapper: # simply augment texts in English based on the nlpaug library sequential: false # whether combine all augmentation methods to a sequence. If it's True, a sample will be augmented by all opened augmentation methods sequentially. If it's False, each opened augmentation method would generate its augmented samples independently. aug_num: 1 # number of augmented samples to be generated. If `sequential` is True, there will be total aug_num augmented samples generated. If it's False, there will be (aug_num * #opened_aug_method) augmented samples generated. @@ -242,7 +258,40 @@ process: sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} - optimize_query_mapper: # optimize query in question-answer pairs. - optimize_response_mapper: # optimize response in question-answer pairs. + - pair_preference_mapper: # construct paired preference samples. + api_model: 'gpt-4o' # API model name. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # System prompt for guiding the generation task. + input_template: null # Template for building the model input. + output_pattern: null # Regular expression for parsing model output. + rejected_key: 'rejected_response' # The field name in the sample to store the generated rejected response. + reason_key: 'reason' # The field name in the sample to store the reason for generating the response. + try_num: 3 # The number of retries for the API call in case of response parsing failure. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. - punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations. + - python_file_mapper: # executing Python lambda function defined in a file. + file_path: '' # The path to the Python file containing the function to be executed. + function_name: 'process_single' # The name of the function defined in the file to be executed. + - python_lambda_mapper: # executing Python lambda function on data samples. + lambda_str: '' # A string representation of the lambda function to be executed on data samples. If empty, the identity function is used. + batched: False # A boolean indicating whether to process input data in batches. + - relation_identity_mapper: # identify relation between two entity in the text. + api_model: 'gpt-4o' # API model name. + source_entity: '孙悟空' # The source entity of the relation to be dentified. + target_entity: '猪八戒' # The target entity of the relation to be identified. + input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. + output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is input_key in default. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt_template: null # System prompt template for the task. Need to specify by entity1 and entity2. + input_template: null # Template for building the model input. + output_pattern_template: null # Regular expression template for parsing model output. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - remove_bibliography_mapper: # remove bibliography from Latex text. - remove_comments_mapper: # remove comments from Latex text, code, etc. doc_type: tex # comment type you want to remove. Only support 'tex' for now. @@ -319,6 +368,11 @@ process: horizontal_flip: false # flip frame image horizontally (left to right). vertical_flip: false # flip frame image vertically (top to bottom). mem_required: '20GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched + - video_extract_frames_mapper: # extract frames from video files according to specified methods + frame_sampling_method: 'all_keyframes' # sampling method of extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former one extracts all key frames and the latter one extract specified number of frames uniformly from the video. Default: "all_keyframes". + frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. + duration: 0 # The duration of each segment in seconds. If 0, frames are extracted from the entire video. If duration > 0, the video is segmented into multiple segments based on duration, and frames are extracted from each segment. + frame_dir: None # Output directory to save extracted frames. If None, a default directory based on the video file path is used. - video_face_blur_mapper: # blur faces detected in videos cv_classifier: '' # OpenCV classifier path for face detection. By default, we will use 'haarcascade_frontalface_alt.xml'. blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] @@ -361,6 +415,7 @@ process: frame_sampling_method: 'all_keyframes' # sampling method of extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former one extracts all key frames and the latter one extract specified number of frames uniformly from the video. Default: "all_keyframes". frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. tag_field_name: '__dj__video_frame_tags__' # the field name to store the tags. It's "__dj__video_frame_tags__" in default. + mem_required: '9GB' - whitespace_normalization_mapper: # normalize different kinds of whitespaces to English whitespace. # Filter ops @@ -539,7 +594,7 @@ process: vertical_flip: false # flip frame image vertically (top to bottom). reduce_mode: avg # reduce mode when one text corresponds to multiple videos in a chunk, must be one of ['avg','max', 'min']. any_or_all: any # keep this sample when any/all videos meet the filter condition - mem_required: '1GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched + mem_required: '1500MB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched - video_motion_score_filter: # Keep samples with video motion scores within a specific range. min_score: 0.25 # the minimum motion score to keep samples max_score: 10000.0 # the maximum motion score to keep samples @@ -593,6 +648,7 @@ process: frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. tag_field_name: '__dj__video_frame_tags__' # the field name to store the tags. It's "__dj__video_frame_tags__" in default. any_or_all: any # keep this sample when any/all videos meet the filter condition + mem_required: '9GB' - words_num_filter: # filter text with number of words out of specific range lang: en # sample in which language tokenization: false # whether to use model to tokenize documents @@ -643,17 +699,6 @@ process: redis_port: 6380 # the port of redis instance, please note that the default port of redis is 6379 which is the same as default port for ray, so we need to modify the default redis config to use it in other port lowercase: false # whether to convert text to lower case ignore_non_character: false # whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations - - ray_redis_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm - redis_address: 'redis://localhost:6379' # the address of the redis instance - tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece] - window_size: 5 # window size of shingling - num_permutations: 256 # number of permutations in minhash computing - jaccard_threshold: 0.7 # the min jaccard similarity threshold in near-duplicate detection. When the jaccard similarity of two sample texts is >= this threshold, they are regarded as similar samples and this op will only keep one of them after deduplication - num_bands: null # number of bands in LSH. Default it's None, and it will be determined by an optimal params computation algorithm by minimize the weighted sum of probs of False Positives and False Negatives - num_rows_per_band: null # number of rows in each band in LSH. Default it's None, and it will be determined by an optimal params computation algorithm - lowercase: true # whether to convert text to lower case - ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing simhash. - tokenizer_model: null # path for the sentencepiece model, used for sentencepiece tokenization. - ray_bts_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece] window_size: 5 # window size of shingling @@ -693,3 +738,55 @@ process: top_ratio: # ratio of selected top samples topk: # number of selected top sample reverse: True # determine the sorting rule, if reverse=True, then sort in descending order + +# Grouper ops. + - naive_grouper: # Group all samples to one batched sample. + - key_value_grouper: # Group samples to batched samples according values in given keys. + group_by_keys: null # Group samples according values in the keys. Support for nested keys such as "__dj__stats__.text_len". It is [self.text_key] in default. + +# Aggregator ops. + - entity_attribute_aggregator: # Return conclusion of the given entity's attribute from some docs. + api_model: 'gpt-4o' # API model name. + entity: '孙悟空' # The given entity. + attribute: '人物经历' # The given attribute. + input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. + output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is same as the input_key in default. + word_limit: 100 # Prompt the output length. + max_token_num: null # The max token num of the total tokens of the sub documents. Without limitation if it is None. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt_template: null # System prompt template for the task. Need to be specified by given entity and attribute. + example_prompt: null # The example part in the system prompt. + input_template: null # The input template. + output_pattern_template: null # The output template. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - most_relavant_entities_aggregator: # Extract entities closely related to a given entity from some texts, and sort them in descending order of importance. + api_model: 'gpt-4o' # API model name. + entity: '孙悟空' # The given entity. + query_entity_type: '人物' # The type of queried relavant entities. + input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. + output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is same as the input_key in default. + max_token_num: null # The max token num of the total tokens of the sub documents. Without limitation if it is None. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt_template: null # System prompt template for the task. Need to be specified by given entity and entity_type. + input_template: null # The input template. + output_pattern: null # The output pattern. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - nested_aggregator: # Considering the limitation of input length, nested aggregate contents for each given number of samples. + api_model: 'gpt-4o' # API model name. + input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. + output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is same as the input_key in default. + max_token_num: null # The max token num of the total tokens of the sub documents. Without limitation if it is None. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # The system prompt. + sub_doc_template: null # The template for input text in each sample. + input_template: null # The input template. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} diff --git a/data_juicer/__init__.py b/data_juicer/__init__.py index 7565f493b..91ce93bae 100644 --- a/data_juicer/__init__.py +++ b/data_juicer/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.2.0' +__version__ = '1.0.1' import os import subprocess diff --git a/data_juicer/analysis/column_wise_analysis.py b/data_juicer/analysis/column_wise_analysis.py index 775b42683..825d9b4dd 100644 --- a/data_juicer/analysis/column_wise_analysis.py +++ b/data_juicer/analysis/column_wise_analysis.py @@ -4,6 +4,7 @@ import matplotlib.pyplot as plt import pandas as pd from tqdm import tqdm +from wordcloud import WordCloud from data_juicer.utils.constant import Fields @@ -145,33 +146,39 @@ def analyze(self, show_percentiles=False, show=False, skip_export=False): else: axes = [None] * num_subcol - # draw histogram - self.draw_hist(axes[0], - data, - os.path.join(self.output_path, - f'{column_name}-hist.png'), - percentiles=percentiles) - - # draw box - self.draw_box(axes[1], - data, - os.path.join(self.output_path, - f'{column_name}-box.png'), - percentiles=percentiles) + if not skip_export: + # draw histogram + self.draw_hist(axes[0], + data, + os.path.join(self.output_path, + f'{column_name}-hist.png'), + percentiles=percentiles) + + # draw box + self.draw_box(axes[1], + data, + os.path.join(self.output_path, + f'{column_name}-box.png'), + percentiles=percentiles) else: # object (string) or string list -- only draw histogram for # this stat if self.save_stats_in_one_file: - axes = subfig.subplots(1, 1) + axes = subfig.subplots(1, num_subcol) else: - axes = None + axes = [None] * num_subcol if not skip_export: self.draw_hist( - axes, data, + axes[0], data, os.path.join(self.output_path, f'{column_name}-hist.png')) + self.draw_wordcloud( + axes[1], data, + os.path.join(self.output_path, + f'{column_name}-wordcloud.png')) + # add a title to the figure of this stat if self.save_stats_in_one_file: subfig.suptitle(f'{data.name}', @@ -297,3 +304,33 @@ def draw_box(self, ax, data, save_path, percentiles=None, show=False): # accumulated overlapped figures in different draw_xxx function # calling ax.clear() + + def draw_wordcloud(self, ax, data, save_path, show=False): + word_list = data.tolist() + word_nums = {} + for w in word_list: + if w in word_nums: + word_nums[w] += 1 + else: + word_nums[w] = 1 + + wc = WordCloud(width=400, height=320) + wc.generate_from_frequencies(word_nums) + + if ax is None: + ax = plt.figure(figsize=(20, 16)) + else: + ax.imshow(wc, interpolation='bilinear') + ax.axis('off') + + if not self.save_stats_in_one_file: + # save into file + wc.to_file(save_path) + + if show: + plt.show() + else: + # if no showing, we need to clear this axes to avoid + # accumulated overlapped figures in different draw_xxx function + # calling + ax.clear() diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 0b0487dc3..c7f0aaf38 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -15,6 +15,7 @@ from loguru import logger from data_juicer.ops.base_op import OPERATORS +from data_juicer.ops.op_fusion import FUSION_STRATEGIES from data_juicer.utils.logger_utils import setup_logger from data_juicer.utils.mm_utils import SpecialTokens @@ -22,7 +23,7 @@ global_parser = None -def init_configs(args: Optional[List[str]] = None): +def init_configs(args: Optional[List[str]] = None, which_entry: object = None): """ initialize the jsonargparse parser and parse configs from one of: 1. POSIX-style commands line args; @@ -31,14 +32,29 @@ def init_configs(args: Optional[List[str]] = None): 4. hard-coded defaults :param args: list of params, e.g., ['--conifg', 'cfg.yaml'], defaut None. + :param which_entry: which entry to init configs (executor/analyzer) :return: a global cfg object used by the Executor or Analyzer """ parser = ArgumentParser(default_env=True, default_config_files=None) - parser.add_argument('--config', - action=ActionConfigFile, - help='Path to a dj basic configuration file.', - required=True) + # required but mutually exclusive args group + required_group = parser.add_mutually_exclusive_group(required=True) + required_group.add_argument('--config', + action=ActionConfigFile, + help='Path to a dj basic configuration file.') + required_group.add_argument('--auto', + action='store_true', + help='Weather to use an auto analyzing ' + 'strategy instead of a specific data ' + 'recipe. If a specific config file is ' + 'given by --config arg, this arg is ' + 'disabled. Only available for Analyzer.') + + parser.add_argument('--auto_num', + type=PositiveInt, + default=1000, + help='The number of samples to be analyzed ' + 'automatically. It\'s 1000 in default.') parser.add_argument( '--hpo_config', @@ -96,7 +112,7 @@ def init_configs(args: Optional[List[str]] = None): parser.add_argument( '--export_path', type=str, - default='./outputs/hello_world.jsonl', + default='./outputs/hello_world/hello_world.jsonl', help='Path to export and save the output processed dataset. The ' 'directory to store the processed dataset will be the work ' 'directory of this process.') @@ -229,6 +245,12 @@ def init_configs(args: Optional[List[str]] = None): help='The compression method of the cache file, which can be' 'specified in ["gzip", "zstd", "lz4"]. If this parameter is' 'None, the cache file will not be compressed.') + parser.add_argument( + '--open_monitor', + type=bool, + default=True, + help='Whether to open the monitor to trace resource utilization for ' + 'each OP during data processing. It\'s True in default.') parser.add_argument( '--use_checkpoint', type=bool, @@ -275,6 +297,22 @@ def init_configs(args: Optional[List[str]] = None): help='Whether to fuse operators that share the same intermediate ' 'variables automatically. Op fusion might reduce the memory ' 'requirements slightly but speed up the whole process.') + parser.add_argument( + '--fusion_strategy', + type=str, + default='probe', + help='OP fusion strategy. Support ["greedy", "probe"] now. "greedy" ' + 'means keep the basic OP order and put the fused OP to the last ' + 'of each fused OP group. "probe" means Data-Juicer will probe ' + 'the running speed for each OP at the beginning and reorder the ' + 'OPs and fused OPs according to their probed speed (fast to ' + 'slow). It\'s "probe" in default.') + parser.add_argument( + '--adaptive_batch_size', + type=bool, + default=False, + help='Whether to use adaptive batch sizes for each OP according to ' + 'the probed results. It\'s False in default.') parser.add_argument( '--process', type=List[Dict], @@ -316,6 +354,14 @@ def init_configs(args: Optional[List[str]] = None): try: cfg = parser.parse_args(args=args) + + # check the entry + from data_juicer.core.analyzer import Analyzer + if not isinstance(which_entry, Analyzer) and cfg.auto: + err_msg = '--auto argument can only be used for analyzer!' + logger.error(err_msg) + raise NotImplementedError(err_msg) + cfg = init_setup_from_cfg(cfg) cfg = update_op_process(cfg, parser) @@ -436,6 +482,11 @@ def init_setup_from_cfg(cfg: Namespace): # The checkpoint mode is not compatible with op fusion for now. if cfg.op_fusion: cfg.use_checkpoint = False + cfg.fusion_strategy = cfg.fusion_strategy.lower() + if cfg.fusion_strategy not in FUSION_STRATEGIES: + raise NotImplementedError( + f'Unsupported OP fusion strategy [{cfg.fusion_strategy}]. ' + f'Should be one of {FUSION_STRATEGIES}.') # update huggingface datasets cache directory only when ds_cache_dir is set from datasets import config @@ -460,6 +511,16 @@ def init_setup_from_cfg(cfg: Namespace): SpecialTokens.image = cfg.image_special_token SpecialTokens.eoc = cfg.eoc_special_token + # add all filters that produce stats + if cfg.auto: + import pkgutil + + import data_juicer.ops.filter as djfilters + cfg.process = [{ + filter_name: {} + } for _, filter_name, _ in pkgutil.iter_modules(djfilters.__path__) + if filter_name not in djfilters.NON_STATS_FILTERS] + # Apply text_key modification during initializing configs # users can freely specify text_key for different ops using `text_key` # otherwise, set arg text_key of each op to text_keys @@ -537,8 +598,13 @@ def sort_op_by_types_and_names(op_name_classes): if 'deduplicator' in name] selector_ops = [(name, c) for (name, c) in op_name_classes if 'selector' in name] + grouper_ops = [(name, c) for (name, c) in op_name_classes + if 'grouper' in name] + aggregator_ops = [(name, c) for (name, c) in op_name_classes + if 'aggregator' in name] ops_sorted_by_types = sorted(mapper_ops) + sorted(filter_ops) + sorted( - deduplicator_ops) + sorted(selector_ops) + deduplicator_ops) + sorted(selector_ops) + sorted(grouper_ops) + \ + sorted(aggregator_ops) return ops_sorted_by_types @@ -603,7 +669,10 @@ def update_op_process(cfg, parser): temp_args = namespace_to_arg_list(temp_cfg, includes=recognized_args, excludes=['config']) - temp_args = ['--config', temp_cfg.config[0].absolute] + temp_args + if temp_cfg.config: + temp_args = ['--config', temp_cfg.config[0].absolute] + temp_args + else: + temp_args = ['--auto'] + temp_args temp_parser.parse_args(temp_args) return cfg @@ -629,6 +698,8 @@ def namespace_to_arg_list(namespace, prefix='', includes=None, excludes=None): def config_backup(cfg: Namespace): + if not cfg.config: + return cfg_path = cfg.config[0].absolute work_dir = cfg.work_dir target_path = os.path.join(work_dir, os.path.basename(cfg_path)) diff --git a/data_juicer/core/adapter.py b/data_juicer/core/adapter.py index aa746a058..5ab6e6ec8 100644 --- a/data_juicer/core/adapter.py +++ b/data_juicer/core/adapter.py @@ -1,6 +1,9 @@ +from datasets import concatenate_datasets from datasets.config import DEFAULT_MAX_BATCH_SIZE from data_juicer.core.monitor import Monitor +from data_juicer.ops import UNFORKABLE +from data_juicer.utils.process_utils import setup_mp class Adapter: @@ -27,28 +30,43 @@ def execute_and_probe(dataset, operators, sample_interval=0.5): if operators is None or len(operators) == 0: return [] + # number of test samples + sample_num = len(dataset) + # resource utilization list resource_util_list = [] # probe for each OP + unforkable_operators = set(UNFORKABLE.modules.keys()) for op in operators: - # set num_proc to 1 for each OP to focus on the influence of batch - # size only. - old_num_proc = op.num_proc - op.num_proc = 1 + # select suitable mp method for each OP + mp_context = ['forkserver', 'spawn'] if ( + op.use_cuda() or op._name in unforkable_operators) else None + setup_mp(mp_context) + # expand the test dataset according to the runtime number of + # processes to ensure enough data for a batch and probe the true + # resource utilization for each OP + expanded_dataset = concatenate_datasets([dataset] * + op.runtime_np()) + + # set the test batch size and save the old one + if op.is_batched_op(): + old_batch_size = op.batch_size + op.batch_size = sample_num - # number of test samples - sample_num = len(dataset) # run single op and monitor the resource utilization - dataset, resource_util_per_op = Monitor.monitor_func( - op.run, args=(dataset, ), sample_interval=sample_interval) + _, resource_util_per_op = Monitor.monitor_func( + op.run, + args=(expanded_dataset, ), + sample_interval=sample_interval) # calculate speed resource_util_per_op[ 'speed'] = sample_num / resource_util_per_op['time'] resource_util_list.append(resource_util_per_op) - # restore to the original num_proc - op.num_proc = old_num_proc + # # restore the batch size + if op.is_batched_op(): + op.batch_size = old_batch_size return resource_util_list @@ -96,11 +114,20 @@ def probe_small_batch(self, dataset, operators): current load and estimated OP speed, returning load factors and speed ranks for each OP. + Notice: the probe should be run with cache enabled. + :param dataset: The dataset to pre-execute small batch on :param operators: The OP list to be pre-execution and probe :return: A list of probe results for each OP and the length of data batch to probe. """ + # record the cache state and enable the cache + from datasets import (disable_caching, enable_caching, + is_caching_enabled) + previous_state = is_caching_enabled() + if not previous_state: + enable_caching() + # take a small batch data_batch = self.take_batch(dataset, self.cfg) # process and monitor the resource utilization @@ -108,6 +135,10 @@ def probe_small_batch(self, dataset, operators): # analyze resource utilization analysis_res = Monitor.analyze_resource_util_list(resource_util_list) + # if the cache is disabled before, disable it again + if not previous_state: + disable_caching() + return analysis_res, len(data_batch) def batch_size_strategy(self, load_analysis_res, base_bs=1, util_th=0.9): diff --git a/data_juicer/core/analyzer.py b/data_juicer/core/analyzer.py index e9a6ef8d2..63e512d41 100644 --- a/data_juicer/core/analyzer.py +++ b/data_juicer/core/analyzer.py @@ -9,8 +9,10 @@ from data_juicer.config import init_configs from data_juicer.format import load_formatter from data_juicer.ops import Filter, load_ops +from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils import cache_utils +from .adapter import Adapter from .exporter import Exporter @@ -31,7 +33,7 @@ def __init__(self, cfg: Optional[Namespace] = None): :param cfg: optional jsonargparse Namespace dict. """ - self.cfg = init_configs() if cfg is None else cfg + self.cfg = init_configs(which_entry=self) if cfg is None else cfg self.work_dir = self.cfg.work_dir @@ -85,10 +87,25 @@ def run(self, if load_data_np is None: load_data_np = self.cfg.np dataset = self.formatter.load_dataset(load_data_np, self.cfg) + if self.cfg.auto: + # if it's auto analysis, only analyze for a minor part of the input + # dataset to save time and computing resource + dataset = dataset.take(min(len(dataset), self.cfg.auto_num)) # extract processes logger.info('Preparing process operators...') - ops = load_ops(self.cfg.process, self.cfg.op_fusion) + ops = load_ops(self.cfg.process) + + if self.cfg.op_fusion: + probe_res = None + if self.cfg.fusion_strategy == 'probe': + logger.info('Probe the OP speed for OP reordering...') + adapter = Adapter(self.cfg) + probe_res, _ = adapter.probe_small_batch(dataset, ops) + + logger.info(f'Start OP fusion and reordering with strategy ' + f'[{self.cfg.fusion_strategy}]...') + ops = fuse_operators(ops, probe_res) # 2. stats precompute only for filter ops logger.info('Computing the stats of dataset...') diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py index 7e51bd1f8..361f6e8a0 100644 --- a/data_juicer/core/data.py +++ b/data_juicer/core/data.py @@ -164,13 +164,16 @@ def __getitem__(self, key): res = super().__getitem__(key) return nested_obj_factory(res) - def process(self, - operators, - *, - work_dir=None, - exporter=None, - checkpointer=None, - tracer=None): + def process( + self, + operators, + *, + work_dir=None, + exporter=None, + checkpointer=None, + tracer=None, + open_monitor=True, + ): if operators is None: return self @@ -179,7 +182,8 @@ def process(self, unforkable_operators = set(UNFORKABLE.modules.keys()) # resource utilization monitor - resource_util_list = [] + if open_monitor: + resource_util_list = [] dataset = self try: @@ -196,12 +200,16 @@ def process(self, 'exporter': exporter, 'tracer': tracer, } - dataset, resource_util_per_op = Monitor.monitor_func( - op.run, args=run_args) + if open_monitor: + dataset, resource_util_per_op = Monitor.monitor_func( + op.run, args=run_args) + else: + dataset = op.run(**run_args) # record processed ops if checkpointer is not None: checkpointer.record(op._op_cfg) - resource_util_list.append(resource_util_per_op) + if open_monitor: + resource_util_list.append(resource_util_per_op) end = time() logger.info(f'OP [{op._name}] Done in {end - start:.3f}s. ' f'Left {len(dataset)} samples.') @@ -215,14 +223,20 @@ def process(self, 'last op...') dataset.cleanup_cache_files() checkpointer.save_ckpt(dataset) - if work_dir: - with open(os.path.join(work_dir, 'monitor.json'), 'w') as out: + if work_dir and open_monitor: + # get the analyzed version + resource_util_list = Monitor.analyze_resource_util_list( + resource_util_list) + monitor_dir = os.path.join(work_dir, 'monitor') + os.makedirs(monitor_dir, exist_ok=True) + with open(os.path.join(monitor_dir, 'monitor.json'), + 'w') as out: json.dump(resource_util_list, out) + Monitor.draw_resource_util_graph(resource_util_list, + monitor_dir) return dataset - def map(self, *args, **kargs): - """Override the map func, which is called by most common operations, - such that the processed samples can be accessed by nested manner.""" + def update_args(self, args, kargs, is_filter=False): if args: args = list(args) # the first positional para is function @@ -248,17 +262,17 @@ def map(self, *args, **kargs): # batched is required for fault-tolerant or batched OP if callable(getattr( called_func.__self__, - 'is_batched_op')) and called_func.__self__.is_batched_op( - ) or not getattr(called_func.__self__, 'turbo', False): + 'is_batched_op')) and called_func.__self__.is_batched_op(): kargs['batched'] = True - kargs['batch_size'] = kargs.pop('batch_size', 1) if hasattr( - called_func.__self__, 'is_batched_op' - ) and called_func.__self__.is_batched_op() else 1 + kargs['batch_size'] = kargs.pop('batch_size', 1) + elif not getattr(called_func.__self__, 'turbo', False): + kargs['batched'] = True + kargs['batch_size'] = 1 else: kargs['batched'] = False - # rank is required for cuda model loading - if callable( + # rank is required for cuda model loading for map + if not is_filter and callable( getattr(called_func.__self__, 'use_cuda')) and called_func.__self__.use_cuda(): kargs['with_rank'] = True @@ -267,6 +281,14 @@ def map(self, *args, **kargs): new_fingerprint = generate_fingerprint(self, *args, **kargs) kargs['new_fingerprint'] = new_fingerprint + return args, kargs + + def map(self, *args, **kargs): + """Override the map func, which is called by most common operations, + such that the processed samples can be accessed by nested manner.""" + + args, kargs = self.update_args(args, kargs) + if cache_utils.CACHE_COMPRESS: decompress(self, kargs['new_fingerprint'], kargs['num_proc'] if 'num_proc' in kargs else 1) @@ -285,38 +307,7 @@ def map(self, *args, **kargs): def filter(self, *args, **kargs): """Override the filter func, which is called by most common operations, such that the processed samples can be accessed by nested manner.""" - if args: - args = list(args) - # the first positional para is function - if args[0] is None: - args[0] = lambda x: nested_obj_factory(x) - else: - args[0] = wrap_func_with_nested_access(args[0]) - called_func = args[0] - else: - if 'function' not in kargs or kargs['function'] is None: - kargs['function'] = lambda x: nested_obj_factory(x) - else: - kargs['function'] = wrap_func_with_nested_access( - kargs['function']) - called_func = kargs['function'] - - # For wrapped function, try to get its unwrapped (bound) method - while not inspect.ismethod(called_func) and hasattr( - called_func, '__wrapped__'): - called_func = called_func.__wrapped__ - - # Batched is always required for fault tolerance - if inspect.ismethod(called_func): - if callable(getattr( - called_func.__self__, - 'is_batched_op')) and called_func.__self__.is_batched_op(): - kargs['batched'] = True - kargs['batch_size'] = kargs.pop('batch_size', 1) - - if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None: - new_fingerprint = generate_fingerprint(self, *args, **kargs) - kargs['new_fingerprint'] = new_fingerprint + args, kargs = self.update_args(args, kargs, is_filter=True) # For filter, it involves a map and a filter operations, so the final # cache files includes two sets with different fingerprint (before and diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py index 472a5e858..f78059247 100644 --- a/data_juicer/core/executor.py +++ b/data_juicer/core/executor.py @@ -11,6 +11,7 @@ from data_juicer.format.load import load_formatter from data_juicer.format.mixture_formatter import MixtureFormatter from data_juicer.ops import OPERATORS, load_ops +from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils import cache_utils from data_juicer.utils.ckpt_utils import CheckpointManager @@ -18,6 +19,7 @@ FrequencySpecifiedFieldSelector from ..ops.selector.topk_specified_field_selector import \ TopkSpecifiedFieldSelector +from .adapter import Adapter from .exporter import Exporter from .tracer import Tracer @@ -43,6 +45,8 @@ def __init__(self, cfg: Optional[Namespace] = None): self.tracer = None self.ckpt_manager = None + self.adapter = Adapter(self.cfg) + # only enable it when using cache if self.cfg.use_cache: logger.info(f'Using cache compression method: ' @@ -158,20 +162,45 @@ def run(self, load_data_np = self.cfg.np dataset = self.formatter.load_dataset(load_data_np, self.cfg) - # 2. extract processes + # 2. extract processes and optimize their orders logger.info('Preparing process operators...') - ops = load_ops(self.cfg.process, self.cfg.op_fusion) + ops = load_ops(self.cfg.process) + + # OP fusion + if self.cfg.op_fusion: + probe_res = None + if self.cfg.fusion_strategy == 'probe': + logger.info('Probe the OP speed for OP reordering...') + probe_res, _ = self.adapter.probe_small_batch(dataset, ops) + + logger.info(f'Start OP fusion and reordering with strategy ' + f'[{self.cfg.fusion_strategy}]...') + ops = fuse_operators(ops, probe_res) + + # adaptive batch size + if self.cfg.adaptive_batch_size: + # calculate the adaptive batch size + bs_per_op = self.adapter.adapt_workloads(dataset, ops) + assert len(bs_per_op) == len(ops) + # update the adaptive batch size + logger.info(f'Adapt batch sizes for each OP to {bs_per_op}') + for i, op in enumerate(ops): + if op.is_batched_op(): + op.batch_size = bs_per_op[i] # 3. data process # - If tracer is open, trace each op after it's processed # - If checkpoint is open, clean the cache files after each process logger.info('Processing data...') tstart = time() - dataset = dataset.process(ops, - work_dir=self.work_dir, - exporter=self.exporter, - checkpointer=self.ckpt_manager, - tracer=self.tracer) + dataset = dataset.process( + ops, + work_dir=self.work_dir, + exporter=self.exporter, + checkpointer=self.ckpt_manager, + tracer=self.tracer, + open_monitor=self.cfg.open_monitor, + ) tend = time() logger.info(f'All OPs are done in {tend - tstart:.3f}s.') diff --git a/data_juicer/core/monitor.py b/data_juicer/core/monitor.py index 7d2f7984c..0210e3732 100644 --- a/data_juicer/core/monitor.py +++ b/data_juicer/core/monitor.py @@ -1,3 +1,4 @@ +import os import time from functools import partial from multiprocessing import get_context @@ -28,6 +29,7 @@ class Monitor: '''python { 'time': 10, + 'sampling interval': 0.5, 'resource': [ { 'timestamp': xxx, @@ -50,6 +52,7 @@ class Monitor: '''python { 'time': 10, + 'sampling interval': 0.5, 'resource': [...], 'resource_analysis': { 'GPU free mem.': { @@ -118,6 +121,24 @@ def monitor_current_resources(): return resource_dict + @staticmethod + def draw_resource_util_graph(resource_util_list, store_dir): + import matplotlib.pyplot as plt + for idx, resource_util_dict in enumerate(resource_util_list): + resource_list = resource_util_dict['resource'] + interval = resource_util_dict['sampling interval'] + for focus_metric in Monitor.DYNAMIC_FIELDS: + fn = f'func_{idx}_{focus_metric.replace(" ", "_")}.jpg' + ylbl = '%' if focus_metric.endswith('util.') else 'MB' + metric_list = [item[focus_metric] for item in resource_list] + plt.plot([i * interval for i in range(len(metric_list))], + metric_list) + plt.title(focus_metric) + plt.xlabel('Time (s)') + plt.ylabel(ylbl) + plt.savefig(os.path.join(store_dir, fn), bbox_inches='tight') + plt.clf() + @staticmethod def analyze_resource_util_list(resource_util_list): """ @@ -184,7 +205,10 @@ def monitor_func(func, args=None, sample_interval=0.5): resource_util_dict = {} # start monitor - ctx = get_context('fork') + start_method = 'fork' + if os.name == 'nt': # for Windows + start_method = 'spawn' + ctx = get_context(start_method) with ctx.Manager() as manager: mdict = manager.dict() mdict['stop'] = False @@ -209,6 +233,9 @@ def monitor_func(func, args=None, sample_interval=0.5): resource_util_dict['resource'] = mdict['resource'] + # record interval + resource_util_dict['sampling interval'] = sample_interval + # calculate speed resource_util_dict['time'] = end - start diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 2966e75e8..48e26827d 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -1,7 +1,12 @@ +from __future__ import annotations import os from functools import partial -import pyarrow as pa +import os +from functools import partial +from typing import Any, Dict, List, Literal, Optional, Union + +import pyarrow from loguru import logger from data_juicer import cuda_device_count @@ -12,6 +17,7 @@ from data_juicer.utils.process_utils import calculate_np rd = LazyLoader('rd', 'ray.data') +ds = LazyLoader('ds', 'ray.data.datasource') def get_abs_path(path, dataset_dir): @@ -33,7 +39,7 @@ def convert_to_absolute_paths(samples, dataset_dir, path_keys): samples[key][idx] = [ get_abs_path(item, dataset_dir) for item in paths ] - return pa.Table.from_pydict(samples) + return pyarrow.Table.from_pydict(samples) # TODO: check path for nestdataset @@ -71,7 +77,7 @@ def get_num_gpus(op, op_proc): def filter_batch(batch, filter_func): - mask = pa.array(filter_func(batch.to_pydict())) + mask = pyarrow.array(filter_func(batch.to_pydict())) return batch.filter(mask) @@ -115,10 +121,12 @@ def _run_single_op(self, op): elif isinstance(op, Filter): columns = self.data.columns() if Fields.stats not in columns: - - def process_batch_arrow(table: pa.Table) -> pa.Table: + def process_batch_arrow(table: pyarrow.Table): new_column_data = [{} for _ in range(len(table))] - new_talbe = table.append_column(Fields.stats, [new_column_data]) + new_talbe = table.append_column( + Fields.stats, + [new_column_data] + ) return new_talbe self.data = self.data.map_batches(process_batch_arrow, @@ -150,3 +158,89 @@ def process_batch_arrow(table: pa.Table) -> pa.Table: import traceback traceback.print_exc() exit(1) + + @classmethod + def read_json(cls, paths: Union[str, List[str]]) -> RayDataset: + # Note: a temp solution for reading json stream + # TODO: replace with ray.data.read_json_stream once it is available + import pyarrow.json as js + try: + js.open_json + return read_json_stream(paths) + except AttributeError: + return rd.read_json(paths) + + +class JSONStreamDatasource(ds.JSONDatasource): + """ + A temp Datasource for reading json stream. + + Note: + + Depends on a customized `pyarrow` with `open_json` method. + """ + + def _read_stream(self, f: 'pyarrow.NativeFile', path: str): + from pyarrow.json import open_json + + try: + reader = open_json( + f, + read_options=self.read_options, + **self.arrow_json_args, + ) + schema = None + while True: + try: + batch = reader.read_next_batch() + table = pyarrow.Table.from_batches([batch], schema=schema) + if schema is None: + schema = table.schema + yield table + except StopIteration: + return + except pyarrow.lib.ArrowInvalid as e: + raise ValueError(f'Failed to read JSON file: {path}.') from e + + +def read_json_stream( + paths: Union[str, List[str]], + *, + filesystem: Optional['pyarrow.fs.FileSystem'] = None, + parallelism: int = -1, + ray_remote_args: Dict[str, Any] = None, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + meta_provider=None, + partition_filter=None, + partitioning=ds.partitioning.Partitioning('hive'), + include_paths: bool = False, + ignore_missing_paths: bool = False, + shuffle: Union[Literal['files'], None] = None, + file_extensions: Optional[List[str]] = ['json', 'jsonl'], + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, + **arrow_json_args, +) -> rd.Dataset: + if meta_provider is None: + meta_provider = ds.file_meta_provider.DefaultFileMetadataProvider() + + datasource = JSONStreamDatasource( + paths, + arrow_json_args=arrow_json_args, + filesystem=filesystem, + open_stream_args=arrow_open_stream_args, + meta_provider=meta_provider, + partition_filter=partition_filter, + partitioning=partitioning, + ignore_missing_paths=ignore_missing_paths, + shuffle=shuffle, + include_paths=include_paths, + file_extensions=file_extensions, + ) + return rd.read_datasource( + datasource, + parallelism=parallelism, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index f146ffc02..cc4b39a0c 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -5,8 +5,11 @@ from data_juicer.config import init_configs from data_juicer.core.ray_data import RayDataset from data_juicer.ops import load_ops +from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils.lazy_loader import LazyLoader +from .adapter import Adapter + ray = LazyLoader('ray', 'ray') rd = LazyLoader('rd', 'ray.data') @@ -33,6 +36,8 @@ def __init__(self, cfg=None): self.work_dir = self.cfg.work_dir + self.adapter = Adapter(self.cfg) + # init ray logger.info('Initing Ray ...') ray.init(self.cfg.ray_address) @@ -56,13 +61,23 @@ def run(self, load_data_np=None): from data_juicer.format.formatter import FORMATTERS dataset = FORMATTERS.modules[obj_name](**args).load_dataset() else: - dataset = rd.read_json(self.cfg.dataset_path) + dataset = RayDataset.read_json(self.cfg.dataset_path) # convert all the path in dataset to absolute path dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg) # 2. extract processes logger.info('Preparing process operators...') - ops = load_ops(self.cfg.process, self.cfg.op_fusion) + ops = load_ops(self.cfg.process) + + if self.cfg.op_fusion: + probe_res = None + if self.cfg.fusion_strategy == 'probe': + logger.info('Probe the OP speed for OP reordering...') + probe_res, _ = self.adapter.probe_small_batch(dataset, ops) + + logger.info(f'Start OP fusion and reordering with strategy ' + f'[{self.cfg.fusion_strategy}]...') + ops = fuse_operators(ops, probe_res) # 3. data process logger.info('Processing data...') diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py index c7ab44c25..e02e10efa 100644 --- a/data_juicer/ops/__init__.py +++ b/data_juicer/ops/__init__.py @@ -1,6 +1,6 @@ -from . import deduplicator, filter, mapper, selector -from .base_op import (OPERATORS, UNFORKABLE, Deduplicator, Filter, Mapper, - Selector) +from . import aggregator, deduplicator, filter, grouper, mapper, selector +from .base_op import (OPERATORS, UNFORKABLE, Aggregator, Deduplicator, Filter, + Grouper, Mapper, Selector) from .load import load_ops __all__ = [ @@ -9,4 +9,6 @@ 'Mapper', 'Deduplicator', 'Selector', + 'Grouper', + 'Aggregator', ] diff --git a/data_juicer/ops/aggregator/__init__.py b/data_juicer/ops/aggregator/__init__.py new file mode 100644 index 000000000..4afe2974a --- /dev/null +++ b/data_juicer/ops/aggregator/__init__.py @@ -0,0 +1,8 @@ +from .entity_attribute_aggregator import EntityAttributeAggregator +from .most_relavant_entities_aggregator import MostRelavantEntitiesAggregator +from .nested_aggregator import NestedAggregator + +__all__ = [ + 'NestedAggregator', 'EntityAttributeAggregator', + 'MostRelavantEntitiesAggregator' +] diff --git a/data_juicer/ops/aggregator/entity_attribute_aggregator.py b/data_juicer/ops/aggregator/entity_attribute_aggregator.py new file mode 100644 index 000000000..96fbbb63f --- /dev/null +++ b/data_juicer/ops/aggregator/entity_attribute_aggregator.py @@ -0,0 +1,200 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Aggregator +from data_juicer.utils.common_utils import (avg_split_string_list_under_limit, + is_string_list, nested_access, + nested_set) +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import get_model, prepare_model + +from .nested_aggregator import NestedAggregator + +torch = LazyLoader('torch', 'torch') +vllm = LazyLoader('vllm', 'vllm') + +OP_NAME = 'entity_attribute_aggregator' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class EntityAttributeAggregator(Aggregator): + """ + Return conclusion of the given entity's attribute from some docs. + """ + + DEFAULT_SYSTEM_TEMPLATE = ( + '给定与`{entity}`相关的一些文档,总结`{entity}`的`{attribute}`。\n' + '要求:\n' + '- 尽量使用原文专有名词\n' + '- 联系上下文,自动忽略上下文不一致的细节错误\n' + '- 只对文档中与`{entity}`的`{attribute}`有关的内容进行总结\n' + '- 字数限制在**{word_limit}字以内**\n' + '- 要求输出格式如下:\n' + '# {entity}\n' + '## {attribute}\n' + '...\n' + '{example}') + + DEFAULT_EXAMPLE_PROMPT = ('- 例如,根据相关文档总结`孙悟空`的`出身背景`,**100字**以内的样例如下:\n' + '`孙悟空`的`出身背景`总结:\n' + '# 孙悟空\n' + '## 出身背景\n' + '号称齐天大圣,花果山水帘洞的美猴王、西行取经队伍中的大师兄。' + '师父是唐僧玄奘,曾拜菩提祖师学艺。' + '亲生父母未知,自石头中孕育而生。自认斗战胜佛,最怕观世音菩萨和紧箍咒。\n') + + DEFAULT_INPUT_TEMPLATE = ('`{entity}`的相关文档:\n' + '{sub_docs}\n\n' + '`{entity}`的`{attribute}`总结:\n') + + DEFAULT_OUTPUT_PATTERN_TEMPLATE = r'\#\s*{entity}\s*\#\#\s*{attribute}\s*(.*?)\Z' # noqa: E501 + + def __init__(self, + api_model: str = 'gpt-4o', + entity: str = None, + attribute: str = None, + input_key: str = None, + output_key: str = None, + word_limit: PositiveInt = 100, + max_token_num: Optional[PositiveInt] = None, + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt_template: Optional[str] = None, + example_prompt: Optional[str] = None, + input_template: Optional[str] = None, + output_pattern_template: Optional[str] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param api_model: API model name. + :param entity: The given entity. + :param attribute: The given attribute. + :param input_key: The input field key in the samples. Support for + nested keys such as "__dj__stats__.text_len". It is text_key + in default. + :param output_key: The output field key in the samples. Support for + nested keys such as "__dj__stats__.text_len". It is same as the + input_key in default. + :param word_limit: Prompt the output length. + :param max_token_num: The max token num of the total tokens of the + sub documents. Without limitation if it is None. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt_template: The system prompt template. + :param example_prompt: The example part in the system prompt. + :param input_template: The input template. + :param output_pattern_template: The output template. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + if entity is None or attribute is None: + raise ValueError('The entity and attribute cannot be None!') + + self.entity = entity + self.attribute = attribute + self.input_key = input_key or self.text_key + self.output_key = output_key or self.input_key + self.word_limit = word_limit + self.max_token_num = max_token_num + + system_prompt_template = system_prompt_template or \ + self.DEFAULT_SYSTEM_TEMPLATE + self.example_prompt = example_prompt or self.DEFAULT_EXAMPLE_PROMPT + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + output_pattern_template = output_pattern_template or \ + self.DEFAULT_OUTPUT_PATTERN_TEMPLATE + self.system_prompt = system_prompt_template.format( + entity=self.entity, + attribute=self.attribute, + word_limit=self.word_limit, + example=self.example_prompt) + self.output_pattern = output_pattern_template.format( + entity=entity, attribute=attribute) + + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + return_processor=True, + **model_params) + + self.try_num = try_num + self.nested_sum = NestedAggregator(model=api_model, + max_token_num=max_token_num, + api_endpoint=api_endpoint, + response_path=response_path, + try_num=try_num, + model_params=model_params, + sampling_params=sampling_params) + + def parse_output(self, response): + pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) + matches = pattern.findall(response) + if matches: + result = matches[0].strip() + else: + result = '' + + return result + + def attribute_summary(self, sub_docs, rank=None): + if not sub_docs: + return '' + + model, tokenizer = get_model(self.model_key, rank, self.use_cuda()) + token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs] + group_docs = avg_split_string_list_under_limit(sub_docs, token_nums, + self.max_token_num) + results = [] + for docs in group_docs: + doc_str = '\n\n'.join(docs) + input_prompt = self.input_template.format(entity=self.entity, + attribute=self.attribute, + sub_docs=doc_str) + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': input_prompt + }] + result = '' + for i in range(self.try_num): + try: + response = model(messages, **self.sampling_params) + result = self.parse_output(response) + if len(result) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + results.append(result) + + return self.nested_sum.recursive_summary(results) + + def process_single(self, sample=None, rank=None): + + # if not batched sample + sub_docs = nested_access(sample, self.input_key) + if not is_string_list(sub_docs): + return sample + + sample = nested_set(sample, self.output_key, + self.attribute_summary(sub_docs, rank=rank)) + + return sample diff --git a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py new file mode 100644 index 000000000..69e1a209c --- /dev/null +++ b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py @@ -0,0 +1,183 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Aggregator +from data_juicer.utils.common_utils import (is_string_list, nested_access, + nested_set) +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..common import split_text_by_punctuation + +torch = LazyLoader('torch', 'torch') +vllm = LazyLoader('vllm', 'vllm') + +OP_NAME = 'most_relavant_entities_aggregator' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class MostRelavantEntitiesAggregator(Aggregator): + """ + Extract entities closely related to a given entity from some texts, + and sort them in descending order of importance. + """ + + DEFAULT_SYSTEM_TEMPLATE = ( + '给定与`{entity}`相关的一些文档,' + '总结一些与`{entity}`最为相关的`{entity_type}`。\n' + '要求:\n' + '- 不用包含与{entity}为同一{entity_type}的{entity_type}。\n' + '- 请按照人物的重要性进行排序,**越重要人物在列表越前面**。\n' + '- 你的返回格式如下:\n' + '## 分析\n' + '你对各个{entity_type}与{entity}关联度的分析\n' + '## 列表\n' + '人物1, 人物2, 人物3, ...') + + DEFAULT_INPUT_TEMPLATE = ('`{entity}`的相关文档:\n' + '{sub_docs}\n\n' + '与`{entity}`最相关的一些`{entity_type}`:\n') + + DEFAULT_OUTPUT_PATTERN = r'\#\#\s*列表\s*(.*?)\Z' + + def __init__(self, + api_model: str = 'gpt-4o', + entity: str = None, + query_entity_type: str = None, + input_key: str = None, + output_key: str = None, + max_token_num: Optional[PositiveInt] = None, + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt_template: Optional[str] = None, + input_template: Optional[str] = None, + output_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param api_model: API model name. + :param entity: The given entity. + :param query_entity_type: The type of queried relavant entities. + :param input_key: The input field key in the samples. Support for + nested keys such as "__dj__stats__.text_len". It is text_key + in default. + :param output_key: The output field key in the samples. Support for + nested keys such as "__dj__stats__.text_len". It is same as the + input_key in default. + :param max_token_num: The max token num of the total tokens of the + sub documents. Without limitation if it is None. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt_template: The system prompt template. + :param input_template: The input template. + :param output_pattern: The output pattern. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + if entity is None or query_entity_type is None: + raise ValueError( + 'The entity and query_entity_type cannot be None!') + + self.entity = entity + self.query_entity_type = query_entity_type + self.input_key = input_key or self.text_key + self.output_key = output_key or self.input_key + self.max_token_num = max_token_num + + system_prompt_template = system_prompt_template or \ + self.DEFAULT_SYSTEM_TEMPLATE + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN + self.system_prompt = system_prompt_template.format( + entity=entity, entity_type=query_entity_type) + + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + return_processor=True, + **model_params) + + self.try_num = try_num + + def parse_output(self, response): + pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) + matches = pattern.findall(response) + if matches: + result = matches[0].strip() + else: + result = '' + result = split_text_by_punctuation(result) + + return result + + def query_most_relavant_entities(self, sub_docs, rank=None): + if not sub_docs: + return '' + + model, tokenizer = get_model(self.model_key, rank, self.use_cuda()) + token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs] + if self.max_token_num is None: + final_docs = sub_docs + else: + final_docs = [] + total_num = 0 + for token_num, doc in zip(token_nums, sub_docs): + total_num += token_num + if total_num > self.max_token_num: + break + final_docs.append(doc) + + doc_str = '\n\n'.join(final_docs) + input_prompt = self.input_template.format( + entity=self.entity, + entity_type=self.query_entity_type, + sub_docs=doc_str) + + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': input_prompt + }] + result = [] + for i in range(self.try_num): + try: + response = model(messages, **self.sampling_params) + result = self.parse_output(response) + if len(result) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + return result + + def process_single(self, sample=None, rank=None): + + # if not batched sample + sub_docs = nested_access(sample, self.input_key) + if not is_string_list(sub_docs): + return sample + + sample = nested_set( + sample, self.output_key, + self.query_most_relavant_entities(sub_docs, rank=rank)) + + return sample diff --git a/data_juicer/ops/aggregator/nested_aggregator.py b/data_juicer/ops/aggregator/nested_aggregator.py new file mode 100644 index 000000000..124eb1470 --- /dev/null +++ b/data_juicer/ops/aggregator/nested_aggregator.py @@ -0,0 +1,179 @@ +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Aggregator +from data_juicer.utils.common_utils import (avg_split_string_list_under_limit, + is_string_list, nested_access) +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import get_model, prepare_model + +torch = LazyLoader('torch', 'torch') +vllm = LazyLoader('vllm', 'vllm') + +OP_NAME = 'nested_aggregator' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class NestedAggregator(Aggregator): + """ + Considering the limitation of input length, nested aggregate + contents for each given number of samples. + """ + + DEFAULT_SYSTEM_PROMPT = ('给定一些文档碎片,将这些文档整合成一个文档总结。\n' + '要求:\n' + '- 总结的长度与文档碎片的平均长度基本一致\n' + '- 不要包含主观看法\n' + '- 注意要尽可能保留文本的专有名词\n' + '- 只输出文档总结不要输出其他内容\n' + '- 参考如下样例:\n' + '文档碎片:\n' + '唐僧师徒四人行至白虎岭,遇上了变化多端的白骨精。\n\n' + '文档碎片:\n' + '白骨精首次变身少女送斋,被孙悟空识破打死,唐僧责怪悟空。\n\n' + '文档碎片:\n' + '妖怪再变老妇寻女,又被悟空击毙,师傅更加不满,念紧箍咒惩罚。\n\n' + '文档碎片:\n' + '不甘心的白骨精第三次化作老公公来诱骗,依旧逃不过金睛火眼。\n\n' + '文档碎片:\n' + '最终,在观音菩萨的帮助下,真相大白,唐僧明白了自己的误解。\n\n' + '\n' + '文档总结:\n' + '唐僧师徒在白虎岭三遇白骨精变化诱惑,悟空屡次识破击毙妖怪却遭误解,最终观音相助真相大白。') + + DEFAULT_INPUT_TEMPLATE = ('{sub_docs}\n\n' + '文档总结:\n') + + DEFAULT_SUB_DOC_TEMPLATE = '文档碎片:\n{text}\n' + + def __init__(self, + api_model: str = 'gpt-4o', + input_key: str = None, + output_key: str = None, + max_token_num: Optional[PositiveInt] = None, + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + sub_doc_template: Optional[str] = None, + input_template: Optional[str] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param api_model: API model name. + :param input_key: The input field key in the samples. Support for + nested keys such as "__dj__stats__.text_len". It is text_key + in default. + :param output_key: The output field key in the samples. Support for + nested keys such as "__dj__stats__.text_len". It is same as the + input_key in default. + :param max_token_num: The max token num of the total tokens of the + sub documents. Without limitation if it is None. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: The system prompt. + :param sub_doc_template: The template for input text in each sample. + :param input_template: The input template. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.input_key = input_key or self.text_key + self.output_key = output_key or self.input_key + self.max_token_num = max_token_num + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.sub_doc_template = sub_doc_template or \ + self.DEFAULT_SUB_DOC_TEMPLATE + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + return_processor=True, + **model_params) + + self.try_num = try_num + + def parse_output(self, response): + + def if_match(text): + quotes = [("'", "'"), ('"', '"'), ('“', '”'), ('‘', '’'), + ('`', '`')] + if len(text) < 2: + return False + if (text[0], text[-1]) in quotes: + return True + else: + return False + + text = response.strip() + while if_match(text): + text = text[1:-1].strip() + return text + + def recursive_summary(self, sub_docs, rank=None): + if not sub_docs: + return '' + if len(sub_docs) == 1: + return sub_docs[0] + model, tokenizer = get_model(self.model_key, rank, self.use_cuda()) + token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs] + group_docs = avg_split_string_list_under_limit(sub_docs, token_nums, + self.max_token_num) + # merge every two if every single sub doc is a group + group_num = len(group_docs) + if group_num == len(sub_docs): + group_docs = [ + group_docs[i] + + group_docs[i + 1] if i + 1 < group_num else group_docs[i] + for i in range(0, group_num, 2) + ] + results = [] + for docs in group_docs: + doc_strs = [self.sub_doc_template.format(text=d) for d in docs] + input_prompt = self.input_template.format( + sub_docs='\n'.join(doc_strs)) + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': input_prompt + }] + result = '' + for i in range(self.try_num): + try: + response = model(messages, **self.sampling_params) + result = self.parse_output(response) + if len(result) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + results.append(result) + return self.recursive_summary(results) + + def process_single(self, sample=None, rank=None): + + # if not batched sample + sub_docs = nested_access(sample, self.input_key) + if not is_string_list(sub_docs): + return sample + + sample[self.output_key] = self.recursive_summary(sub_docs, rank=rank) + + return sample diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 13f3b61ae..2091a867e 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -70,7 +70,7 @@ def wrapper(samples, *args, **kwargs): return wrapper -def catch_map_single_exception(method): +def catch_map_single_exception(method, return_sample=True): """ For single-map sample-level fault tolerance. The input sample is expected batch_size = 1. @@ -92,8 +92,11 @@ def wrapper(sample, *args, **kwargs): if is_batched(sample): try: sample = convert_dict_list_to_list_dict(sample)[0] - res_sample = method(sample, *args, **kwargs) - return convert_list_dict_to_dict_list([res_sample]) + res = method(sample, *args, **kwargs) + if return_sample: + return convert_list_dict_to_dict_list([res]) + else: + return [res] except Exception as e: from loguru import logger logger.error( @@ -128,6 +131,11 @@ def __init__(self, *args, **kwargs): to be processed :param video_key: the key name of field that stores sample video list to be processed + :param query_key: the key name of field that stores sample queris + :param response_key: the key name of field that stores responses + :param history_key: the key name of field that stores history of + queries and responses + :param index_key: index the samples before process if not None """ # init data keys self.text_key = kwargs.get('text_key', 'text') @@ -139,6 +147,8 @@ def __init__(self, *args, **kwargs): self.response_key = kwargs.get('response_key', 'response') self.history_key = kwargs.get('history_key', 'history') + self.index_key = kwargs.get('index_key', None) + self.batch_size = kwargs.get('batch_size', 1000) # whether the model can be accelerated using cuda @@ -166,9 +176,8 @@ def __init__(self, *args, **kwargs): method = wrap_func_with_nested_access(method) setattr(self, name, method) - @classmethod - def is_batched_op(cls): - return cls._batched_op + def is_batched_op(self): + return self._batched_op def process(self, *args, **kwargs): raise NotImplementedError @@ -214,6 +223,14 @@ def run(self, dataset): from data_juicer.core.data import NestedDataset if not isinstance(dataset, NestedDataset): dataset = NestedDataset(dataset) + if self.index_key is not None: + + def add_index(sample, idx): + sample[self.index_key] = idx + return sample + + dataset = dataset.map(add_index, with_indices=True) + return dataset def empty_history(self): @@ -234,6 +251,10 @@ def __init__(self, *args, **kwargs): to be processed :param video_key: the key name of field that stores sample video list to be processed + :param query_key: the key name of field that stores sample queris + :param response_key: the key name of field that stores responses + :param history_key: the key name of field that stores history of + queries and responses """ super(Mapper, self).__init__(*args, **kwargs) @@ -257,11 +278,22 @@ def process_batched(self, samples, *args, **kwargs): keys = samples.keys() first_key = next(iter(keys)) num_samples = len(samples[first_key]) + + new_keys = {} for i in range(num_samples): this_sample = {key: samples[key][i] for key in keys} res_sample = self.process_single(this_sample, *args, **kwargs) - for key in keys: - samples[key][i] = res_sample[key] + res_keys = res_sample.keys() + for key in res_keys: + if key not in keys: + if key not in new_keys: + new_keys.update({key: []}) + new_keys[key].append(res_sample[key]) + else: + samples[key][i] = res_sample[key] + + for k, v in new_keys.items(): + samples[k] = v return samples @@ -303,6 +335,10 @@ def __init__(self, *args, **kwargs): to be processed :param video_key: the key name of field that stores sample video list to be processed + :param query_key: the key name of field that stores sample queris + :param response_key: the key name of field that stores responses + :param history_key: the key name of field that stores history of + queries and responses """ super(Filter, self).__init__(*args, **kwargs) self.stats_export_path = kwargs.get('stats_export_path', None) @@ -315,7 +351,8 @@ def __init__(self, *args, **kwargs): else: self.compute_stats = catch_map_single_exception( self.compute_stats_single) - self.process = catch_map_single_exception(self.process_single) + self.process = catch_map_single_exception(self.process_single, + return_sample=False) # set the process method is not allowed to be overridden def __init_subclass__(cls, **kwargs): @@ -410,6 +447,10 @@ def __init__(self, *args, **kwargs): to be processed :param video_key: the key name of field that stores sample video list to be processed + :param query_key: the key name of field that stores sample queris + :param response_key: the key name of field that stores responses + :param history_key: the key name of field that stores history of + queries and responses """ super(Deduplicator, self).__init__(*args, **kwargs) @@ -469,6 +510,10 @@ def __init__(self, *args, **kwargs): to be processed :param video_key: the key name of field that stores sample video list to be processed + :param query_key: the key name of field that stores sample queris + :param response_key: the key name of field that stores responses + :param history_key: the key name of field that stores history of + queries and responses """ super(Selector, self).__init__(*args, **kwargs) @@ -487,3 +532,90 @@ def run(self, dataset, *, exporter=None, tracer=None): if tracer: tracer.trace_filter(self._name, dataset, new_dataset) return new_dataset + + +class Grouper(OP): + + def __init__(self, *args, **kwargs): + """ + Base class that group samples. + + :param text_key: the key name of field that stores sample texts + to be processed + :param image_key: the key name of field that stores sample image list + to be processed + :param audio_key: the key name of field that stores sample audio list + to be processed + :param video_key: the key name of field that stores sample video list + to be processed + :param query_key: the key name of field that stores sample queris + :param response_key: the key name of field that stores responses + :param history_key: the key name of field that stores history of + queries and responses + """ + super(Grouper, self).__init__(*args, **kwargs) + + def process(self, dataset): + """ + Dataset --> dataset. + + :param dataset: input dataset + :return: dataset of batched samples. + """ + raise NotImplementedError + + def run(self, dataset, *, exporter=None, tracer=None): + dataset = super(Grouper, self).run(dataset) + batched_samples = self.process(dataset) + from data_juicer.core.data import NestedDataset + new_dataset = NestedDataset.from_list(batched_samples) + if tracer: + tracer.trace_filter(self._name, dataset, new_dataset) + return new_dataset + + +class Aggregator(OP): + + def __init__(self, *args, **kwargs): + """ + Base class that group samples. + + :param text_key: the key name of field that stores sample texts + to be processed + :param image_key: the key name of field that stores sample image list + to be processed + :param audio_key: the key name of field that stores sample audio list + to be processed + :param video_key: the key name of field that stores sample video list + to be processed + :param query_key: the key name of field that stores sample queris + :param response_key: the key name of field that stores responses + :param history_key: the key name of field that stores history of + queries and responses + """ + super(Aggregator, self).__init__(*args, **kwargs) + self.process = catch_map_single_exception(self.process_single) + + def process_single(self, sample): + """ + For sample level, batched sample --> sample, + the input must be the output of some Grouper OP. + + :param sample: batched sample to aggregate + :return: aggregated sample + """ + raise NotImplementedError + + def run(self, dataset, *, exporter=None, tracer=None): + dataset = super(Aggregator, self).run(dataset) + new_dataset = dataset.map( + self.process, + num_proc=self.runtime_np(), + with_rank=self.use_cuda(), + batch_size=self.batch_size, + desc=self._name + '_process', + ) + if tracer: + tracer.trace_mapper(self._name, dataset, new_dataset, + self.text_key) + return new_dataset diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index 3e9f55f47..71c5e4863 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -6,7 +6,6 @@ from .ray_document_deduplicator import RayDocumentDeduplicator from .ray_image_deduplicator import RayImageDeduplicator from .ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator -from .ray_redis_minhash_deduplicator import RayRedisMinhashDeduplicator from .ray_video_deduplicator import RayVideoDeduplicator from .video_deduplicator import VideoDeduplicator @@ -14,6 +13,5 @@ 'DocumentDeduplicator', 'DocumentMinhashDeduplicator', 'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator', 'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator', - 'RayImageDeduplicator', 'RayRedisMinhashDeduplicator', - 'RayBTSMinhashDeduplicator', 'VideoDeduplicator', + 'RayImageDeduplicator', 'RayBTSMinhashDeduplicator', 'VideoDeduplicator', ] diff --git a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py deleted file mode 100644 index ee5478e3b..000000000 --- a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py +++ /dev/null @@ -1,380 +0,0 @@ -import random -import time -import uuid -from collections import defaultdict -from typing import Optional - -import numpy as np -import pandas as pd -import pyarrow as pa -import regex -from loguru import logger -from pydantic import Field, PositiveInt -from typing_extensions import Annotated - -from data_juicer.utils.constant import HashKeys -from data_juicer.utils.lazy_loader import LazyLoader -from data_juicer.utils.model_utils import prepare_sentencepiece_model - -from ..base_op import OPERATORS, Deduplicator -from ..common.helper_func import split_on_whitespace -from .document_minhash_deduplicator import (MAX_HASH, MERSENNE_PRIME, - optimal_param, sha1_hash32) - -redis = LazyLoader('redis', 'redis') - - -def retry_on_busy(func): - - def wrapper(*args, **kwargs): - max_retries = 10 - for attempt in range(max_retries): - try: - return func(*args, **kwargs) - except Exception as e: - if 'BUSY' in str(e) and attempt < max_retries - 1: - time.sleep(random.uniform(0.1, 0.3) * (2**attempt)) - else: - raise - - return wrapper - - -class RedisUnionFind: - - def __init__(self, - prefix: str, - redis_address: str = 'redis://localhost:6379'): - self.prefix = prefix - self.redis_address = redis_address - self.redis = redis.from_url(url=redis_address) - self.set_key = f'{prefix}_UF_SET' - self.rank_key = f'{prefix}_UF_RANK' - self.incur_id_key = f'{prefix}_UF_INCURID' - - # Lua scripts - self.union_script = self.redis.register_script(""" - local function find(x) - local path = {} - while true do - local parent = redis.call('HGET', KEYS[1], x) - if not parent then - return nil - end - if parent == x then - break - end - table.insert(path, x) - x = parent - end - for _, node in ipairs(path) do - redis.call('HSET', KEYS[1], node, x) - end - return x - end - - local root_x = find(ARGV[1]) - local root_y = find(ARGV[2]) - if not root_x then - redis.call('HSET', KEYS[1], ARGV[1], ARGV[1]) - redis.call('HSET', KEYS[2], ARGV[1], 0) - root_x = ARGV[1] - end - if not root_y then - redis.call('HSET', KEYS[1], ARGV[2], ARGV[2]) - redis.call('HSET', KEYS[2], ARGV[2], 0) - root_y = ARGV[2] - end - if root_x == root_y then - return root_x - end - local rank_x = tonumber(redis.call('HGET', KEYS[2], root_x)) - local rank_y = tonumber(redis.call('HGET', KEYS[2], root_y)) - if rank_x < rank_y then - redis.call('HSET', KEYS[1], root_x, root_y) - return root_y - elseif rank_x > rank_y then - redis.call('HSET', KEYS[1], root_y, root_x) - return root_x - else - redis.call('HSET', KEYS[1], root_y, root_x) - redis.call('HINCRBY', KEYS[2], root_x, 1) - return root_x - end - """) - - def get_uid(self): - return int(self.redis.incr(self.incur_id_key)) - - @retry_on_busy - def union(self, x, y): - return self.union_script(keys=[self.set_key, self.rank_key], - args=[x, y]) - - def is_ancestor(self, x): - ancestor = self.redis.hget(self.set_key, x) - return ancestor is None or int(ancestor) == x - - def __reduce__(self): - return (RedisUnionFind, (self.prefix, self.redis_address)) - - def clean(self): - self.redis.delete(self.set_key, self.rank_key, self.incur_id_key) - - -OP_NAME = 'ray_redis_minhash_deduplicator' - - -@OPERATORS.register_module(OP_NAME) -class RayRedisMinhashDeduplicator(Deduplicator): - """ - A basic exact matching deduplicator for RAY. - Although its functionality is deduplication, - it is implemented as Filter sub-class. - """ - - def __init__( - self, - tokenization: str = 'space', - window_size: PositiveInt = 5, - lowercase: bool = True, - ignore_pattern: Optional[str] = None, - num_permutations: PositiveInt = 256, - jaccard_threshold: Annotated[float, Field(ge=0, le=1)] = 0.7, - num_bands: Optional[PositiveInt] = None, - num_rows_per_band: Optional[PositiveInt] = None, - tokenizer_model: Optional[str] = None, - redis_address: str = 'redis://localhost:6380', - *args, - **kwargs, - ): - """ - Initialization method. - - :param tokenization: tokenization method for sample texts. It - should be one of [space, punctuation, character, - sentencepiece]. For English-like languages, we recommend - to use 'space', for Chinese-like languages, we recommend - to use 'character', and for multiple languages, we recommend - to use 'sentencepiece'. If using 'sentencepiece', please - provided the model path in the 'tokenizer_model' field. - :param window_size: window size of shingling - :param lowercase: whether to convert text to lower case first - :param ignore_pattern: whether to ignore sub-strings with - specific pattern when computing minhash - :param num_permutations: number of permutations in minhash - computing - :param jaccard_threshold: the min jaccard similarity threshold - in near-duplicate detection. When the jaccard similarity of - two sample texts is >= this threshold, they are regarded as - similar samples and this op will only keep one of them after - deduplication - :param num_bands: number of bands in LSH. Default it's None, and - it will be determined by an optimal params computation - algorithm by minimize the weighted sum of probs of False - Positives and False Negatives - :param num_rows_per_band: number of rows in each band in LSH. - Default it's None, and it will be determined by an optimal - params computation algorithm - :param tokenizer_model: path for the sentencepiece model, used for - sentencepiece tokenization. - :param redis_address: address of your redis instance, e.g. - 'redis://localhost:6379' - """ - super().__init__(*args, **kwargs) - # about minhash computation - self.tokenization = tokenization - self.window_size = window_size - self.lowercase = lowercase - self.ignore_pattern = ignore_pattern - if self.ignore_pattern: - self.ignore_pattern = regex.compile(self.ignore_pattern) - - # check parameters - if self.ignore_pattern and self.tokenization == 'punctuation': - logger.warning('Be careful that tokenization with punctuations ' - 'won\'t work if the ignore pattern includes ' - 'punctuations.') - self.punctuation_pattern = regex.compile(r'\p{P}') - - if self.tokenization == 'sentencepiece': - if tokenizer_model is None: - raise ValueError("To use 'sentencepiece' tokenization, " - "'tokenizer_model' is required.") - self.tokenizer = prepare_sentencepiece_model(tokenizer_model) - else: - self.tokenizer = None - - # about deduplication - self.num_permutation = num_permutations - self.jaccard_threshold = jaccard_threshold - self.num_bands = num_bands - self.num_rows_per_band = num_rows_per_band - - # initialize deduplication parameters - # check number of bands and rows - if self.num_bands is None or self.num_rows_per_band is None: - self.num_bands, self.num_rows_per_band = optimal_param( - self.jaccard_threshold, - self.num_permutation, - ) - - # compute hash ranges and create hash tables - self.hash_ranges = [(i * self.num_rows_per_band, - (i + 1) * self.num_rows_per_band) - for i in range(self.num_bands)] - self.hash_tables = [defaultdict(set) for _ in range(self.num_bands)] - - # generate permutations - gen = np.random.RandomState(seed=42) - self.perm_a, self.perm_b = np.array( - [( - gen.randint(1, MERSENNE_PRIME, dtype=np.uint64), - gen.randint(0, MERSENNE_PRIME, dtype=np.uint64), - ) for _ in range(self.num_permutation)], - dtype=np.uint64, - ).T - self.redis_address = redis_address - - def run(self, dataset): - from ray.data.aggregate import AggregateFn - - union_find = RedisUnionFind(prefix=uuid.uuid4().hex[:8], - redis_address=self.redis_address) - - def add_uid_column(table: pa.Table) -> pa.Table: - new_column_data = [union_find.get_uid() for _ in range(len(table))] - new_table = table.append_column(HashKeys.uid, [new_column_data]) - return new_table - - def calculate_minhash(table: pa.Table) -> pa.Table: - ids = table.column(HashKeys.uid).to_pandas() - texts = table.column(self.text_key).to_pandas() - hashes = texts.apply(lambda x: self.compute_minhash(x)) - hashes = pa.Array.from_pandas(hashes).flatten() - - repeated_ids = pa.Array.from_pandas(ids.repeat(self.num_bands)) - - return pa.Table.from_arrays([repeated_ids, hashes], - names=[HashKeys.uid, HashKeys.minhash]) - - def _is_null(r): - return pd.isnull(r) - - class UnionFn(AggregateFn): - - def __init__(self, union_find): - union_find = union_find - - def accumulate(cur, row): - if _is_null(row): - return cur - elif _is_null(cur): - return row[HashKeys.uid] - else: - root = union_find.union(row[HashKeys.uid], cur) - return int(root) - - def merge(a, b): - if _is_null(a): - return b - if _is_null(b): - return a - root = union_find.union(a, b) - return int(root) - - super().__init__( - init=lambda k: None, - accumulate_row=accumulate, - merge=merge, - name='union', - ) - - def filter_with_union_find(table: pa.Table) -> pa.Table: - uids = table.column(HashKeys.uid).to_pandas() - mask = pa.Array.from_pandas( - uids.apply(lambda x: union_find.is_ancestor(x))) - return table.filter(mask) - - dataset_with_id = dataset.map_batches( - add_uid_column, batch_format='pyarrow').materialize() - dataset_with_id.map_batches(calculate_minhash, - batch_format='pyarrow').groupby( - HashKeys.minhash).aggregate( - UnionFn(union_find)).materialize() - result = dataset_with_id.map_batches(filter_with_union_find, - batch_format='pyarrow').materialize() - logger.info(f'Keep {result.count()} samples after MinHash dedup.') - union_find.clean() - return result - - def compute_minhash(self, text): - """ - Compute minhash values for the sample. - - :param sample: input sample - :return: sample with minhash value. - """ - if self.lowercase: - text = text.lower() - if self.ignore_pattern: - text = self.ignore_pattern.sub('', text) - - # get tokens for different tokenization method - tokens = set() - if self.tokenization == 'character': - tokens = { - str.encode(text[i:i + self.window_size]) - for i in range(len(text) - self.window_size) - } - elif self.tokenization == 'punctuation': - tokens = self.punctuation_pattern.split(text) - tokens = { - str.encode(' '.join(tokens[i:i + self.window_size])) - for i in range(len(tokens) - self.window_size) - } - elif self.tokenization == 'space': - tokens = split_on_whitespace(text) - tokens = { - str.encode(' '.join(tokens[i:i + self.window_size])) - for i in range(len(tokens) - self.window_size) - } - elif self.tokenization == 'sentencepiece': - tokens = self.tokenizer.encode(text, out_type=str) - tokens = { - str.encode(''.join(tokens[i:i + self.window_size])) - for i in range(len(tokens) - self.window_size) - } - else: - raise NotImplementedError( - f'Unimplemented tokenization method [{self.tokenization}]') - - # # compute minhash value - # hv = np.array([sha1_hash32(token) for token in tokens], - # dtype=np.uint64) - # phv = np.bitwise_and( - # ((hv * np.tile(self.perm_a, - # (len(hv), 1)).T).T + self.perm_b) % MERSENNE_PRIME, - # MAX_HASH) - # hash_values = np.vstack([ - # phv, - # np.ones(self.num_permutation, dtype=np.uint64) * MAX_HASH - # ]).min(axis=0) - if len(tokens) > 0: - hv = np.array( - [sha1_hash32(token) for token in tokens], - dtype=np.uint64 - ) - phv = ( - (hv[:, None] * self.perm_a[None, :] - + self.perm_b) % MERSENNE_PRIME - ).astype(np.uint32) - hash_values = phv.min(axis=0) - else: - hash_values = np.full_like(self.perm_a, MAX_HASH, dtype=np.uint32) - return [ - bytes(hash_values[start:end].byteswap().data) + - start.to_bytes(4, byteorder='little') - for start, end in self.hash_ranges - # groupby minhash||brand_id - ] diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index dad6818e1..8cb986b2b 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -63,3 +63,10 @@ 'VideoTaggingFromFramesFilter', 'VideoWatermarkFilter', 'WordRepetitionFilter', 'WordsNumFilter' ] + +NON_STATS_FILTERS = [ + 'specified_field_filter', + 'specified_numeric_field_filter', + 'suffix_filter', + 'video_tagging_from_frames_filter', +] diff --git a/data_juicer/ops/filter/flagged_words_filter.py b/data_juicer/ops/filter/flagged_words_filter.py index dfadb0737..406ae1a23 100644 --- a/data_juicer/ops/filter/flagged_words_filter.py +++ b/data_juicer/ops/filter/flagged_words_filter.py @@ -24,6 +24,8 @@ class FlaggedWordFilter(Filter): """Filter to keep samples with flagged-word ratio less than a specific max value.""" + _batched_op = True + def __init__(self, lang: str = 'en', tokenization: bool = False, @@ -72,53 +74,59 @@ def __init__(self, self.model_key = prepare_model(model_type='sentencepiece', lang=lang) - def compute_stats_single(self, sample, context=False): + def compute_stats_batched(self, samples, context=False): # check if it's computed already - if StatsKeys.flagged_words_ratio in sample[Fields.stats]: - return sample - - # try to get words from context + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] words_key = f'{InterVars.words}-{self.model_key}' - if context and words_key in sample[Fields.context]: - words = sample[Fields.context][words_key] - else: - tokenizer = get_model(self.model_key) - words = get_words_from_document( - sample[self.text_key], - token_func=tokenizer.encode_as_pieces if tokenizer else None) - if context: - sample[Fields.context][words_key] = words - - # try to get refined words from context - refined_words_key = f'{InterVars.refined_words}-True-SPECIAL_CHARS-' \ - f'{self.use_words_aug}-' \ - f'{self.words_aug_group_sizes}-' \ - f'{self.words_aug_join_char}' - if context and refined_words_key in sample[Fields.context]: - words = sample[Fields.context][refined_words_key] - else: - words = words_refinement( - words, - lower_case=True, - strip_chars=SPECIAL_CHARACTERS, - use_words_aug=self.use_words_aug, - words_aug_group_sizes=self.words_aug_group_sizes, - words_aug_join_char=self.words_aug_join_char) - if context: - sample[Fields.context][refined_words_key] = words - - flagged_words_ratio = (len( - [word - for word in words if word in self.FLAGGED_WORDS[self.lang]]) / - len(words)) if len(words) != 0 else 0.0 - - if flagged_words_ratio > 1.0: - flagged_words_ratio = 1.0 - - sample[Fields.stats][ - StatsKeys.flagged_words_ratio] = flagged_words_ratio - return sample - - def process_single(self, sample): - return sample[Fields.stats][ - StatsKeys.flagged_words_ratio] <= self.max_ratio + tokenizer = get_model(self.model_key) + for idx, stat in enumerate(samples_stats): + if StatsKeys.flagged_words_ratio in stat: + continue + if context and words_key in samples[Fields.context][idx]: + words = samples[Fields.context][idx][words_key] + else: + words = get_words_from_document( + samples_list[idx], + token_func=tokenizer.encode_as_pieces + if tokenizer else None) + if context: + samples[Fields.context][idx][words_key] = words + # try to get refined words from context + refined_words_key = f'{InterVars.refined_words}' \ + '-True-SPECIAL_CHARS-' \ + f'{self.use_words_aug}-' \ + f'{self.words_aug_group_sizes}-' \ + f'{self.words_aug_join_char}' + if context and refined_words_key in samples[Fields.context][idx]: + words = samples[Fields.context][idx][refined_words_key] + else: + words = words_refinement( + words, + lower_case=True, + strip_chars=SPECIAL_CHARACTERS, + use_words_aug=self.use_words_aug, + words_aug_group_sizes=self.words_aug_group_sizes, + words_aug_join_char=self.words_aug_join_char) + if context: + samples[Fields.context][idx][refined_words_key] = words + + flagged_words_ratio = (len([ + word for word in words if word in self.FLAGGED_WORDS[self.lang] + ]) / len(words)) if len(words) != 0 else 0.0 + + if flagged_words_ratio > 1.0: + flagged_words_ratio = 1.0 + + samples_stats[idx][ + StatsKeys.flagged_words_ratio] = flagged_words_ratio + + return samples + + def process_batched(self, samples): + return list( + map( + lambda stat: stat[StatsKeys.flagged_words_ratio] <= self. + max_ratio, + samples[Fields.stats], + )) diff --git a/data_juicer/ops/filter/image_aesthetics_filter.py b/data_juicer/ops/filter/image_aesthetics_filter.py index bbaba15eb..723845a5d 100644 --- a/data_juicer/ops/filter/image_aesthetics_filter.py +++ b/data_juicer/ops/filter/image_aesthetics_filter.py @@ -46,7 +46,7 @@ def __init__(self, :param args: Extra positional arguments. :param kwargs: Extra keyword arguments. """ - + kwargs.setdefault('mem_required', '1500MB') super().__init__(*args, **kwargs) if hf_scorer_model == '': hf_scorer_model = \ diff --git a/data_juicer/ops/filter/image_aspect_ratio_filter.py b/data_juicer/ops/filter/image_aspect_ratio_filter.py index e069a1943..d3b3785ee 100644 --- a/data_juicer/ops/filter/image_aspect_ratio_filter.py +++ b/data_juicer/ops/filter/image_aspect_ratio_filter.py @@ -14,6 +14,8 @@ class ImageAspectRatioFilter(Filter): AspectRatio = W / H. """ + _batched_op = True + def __init__(self, min_ratio: float = 0.333, max_ratio: float = 3.0, @@ -40,43 +42,53 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats_single(self, sample, context=False): - # check if it's computed already - if StatsKeys.aspect_ratios in sample[Fields.stats]: - return sample - - # there is no image in this sample - if self.image_key not in sample or not sample[self.image_key]: - sample[Fields.stats][StatsKeys.aspect_ratios] = np.array( - [], dtype=np.float64) - return sample - - # load images - loaded_image_keys = sample[self.image_key] - sample, images = load_data_with_context(sample, context, - loaded_image_keys, load_image) - - # compute aspect ratios for each image with W/H - aspect_ratios = { - key: (images[key].width / images[key].height) - for key in images - } - sample[Fields.stats][StatsKeys.aspect_ratios] = [ - aspect_ratios[key] for key in loaded_image_keys - ] - return sample - - def process_single(self, sample): - aspect_ratios = sample[Fields.stats][StatsKeys.aspect_ratios] - keep_bools = np.array([ - self.min_ratio <= aspect_ratio <= self.max_ratio - for aspect_ratio in aspect_ratios - ]) - if len(keep_bools) <= 0: - return True - - # different strategies - if self.any: - return keep_bools.any() - else: - return keep_bools.all() + def compute_stats_batched(self, samples, context=False): + image_list = samples[self.image_key] + samples_stats = samples[Fields.stats] + + for i, stat in enumerate(samples_stats): + # check if it's computed already + if StatsKeys.aspect_ratios in stat: + continue + + # there is no image in this sample + loaded_image_keys = image_list[i] + if not loaded_image_keys: + stat[StatsKeys.aspect_ratios] = np.array([], dtype=np.float64) + continue + + # load images + samples, images = load_data_with_context(samples, context, + loaded_image_keys, + load_image) + + # compute aspect ratios for each image with W/H + aspect_ratios = { + key: (images[key].width / images[key].height) + for key in images + } + stat[StatsKeys.aspect_ratios] = [ + aspect_ratios[key] for key in loaded_image_keys + ] + + return samples + + def process_batched(self, samples): + + def process_single(values): + keep_bools = np.array([ + self.min_ratio <= value <= self.max_ratio for value in values + ]) + if len(keep_bools) <= 0: + return True + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() + + return map( + lambda stat: process_single(stat[StatsKeys.aspect_ratios]), + samples[Fields.stats], + ) diff --git a/data_juicer/ops/filter/image_nsfw_filter.py b/data_juicer/ops/filter/image_nsfw_filter.py index 603a48518..aea409ec4 100644 --- a/data_juicer/ops/filter/image_nsfw_filter.py +++ b/data_juicer/ops/filter/image_nsfw_filter.py @@ -41,6 +41,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '1GB') super().__init__(*args, **kwargs) self.score_threshold = score_threshold if any_or_all not in ['any', 'all']: diff --git a/data_juicer/ops/filter/image_shape_filter.py b/data_juicer/ops/filter/image_shape_filter.py index 064929111..b265add30 100644 --- a/data_juicer/ops/filter/image_shape_filter.py +++ b/data_juicer/ops/filter/image_shape_filter.py @@ -15,6 +15,8 @@ class ImageShapeFilter(Filter): """Filter to keep samples with image shape (w, h) within specific ranges. """ + _batched_op = True + def __init__(self, min_width: int = 1, max_width: int = sys.maxsize, diff --git a/data_juicer/ops/filter/image_size_filter.py b/data_juicer/ops/filter/image_size_filter.py index f4ab8f760..fd8b7bcef 100644 --- a/data_juicer/ops/filter/image_size_filter.py +++ b/data_juicer/ops/filter/image_size_filter.py @@ -12,6 +12,8 @@ class ImageSizeFilter(Filter): specific range. """ + _batched_op = True + def __init__(self, min_size: str = '0', max_size: str = '1TB', diff --git a/data_juicer/ops/filter/image_text_matching_filter.py b/data_juicer/ops/filter/image_text_matching_filter.py index dc36cd68a..6881eccf5 100644 --- a/data_juicer/ops/filter/image_text_matching_filter.py +++ b/data_juicer/ops/filter/image_text_matching_filter.py @@ -52,6 +52,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '1500MB') super().__init__(*args, **kwargs) self.min_score = min_score self.max_score = max_score diff --git a/data_juicer/ops/filter/image_text_similarity_filter.py b/data_juicer/ops/filter/image_text_similarity_filter.py index ac23330c3..9a3f9361b 100644 --- a/data_juicer/ops/filter/image_text_similarity_filter.py +++ b/data_juicer/ops/filter/image_text_similarity_filter.py @@ -19,6 +19,7 @@ class ImageTextSimilarityFilter(Filter): within a specific range.""" _accelerator = 'cuda' + _batched_op = True def __init__(self, hf_clip: str = 'openai/clip-vit-base-patch32', @@ -52,6 +53,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '1500MB') super().__init__(*args, **kwargs) self.min_score = min_score self.max_score = max_score diff --git a/data_juicer/ops/filter/image_watermark_filter.py b/data_juicer/ops/filter/image_watermark_filter.py index 0d9eead6a..b752736a4 100644 --- a/data_juicer/ops/filter/image_watermark_filter.py +++ b/data_juicer/ops/filter/image_watermark_filter.py @@ -45,6 +45,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '500MB') super().__init__(*args, **kwargs) self.prob_threshold = prob_threshold if any_or_all not in ['any', 'all']: diff --git a/data_juicer/ops/filter/perplexity_filter.py b/data_juicer/ops/filter/perplexity_filter.py index 6a1e6e67e..6a6d74e16 100644 --- a/data_juicer/ops/filter/perplexity_filter.py +++ b/data_juicer/ops/filter/perplexity_filter.py @@ -45,6 +45,7 @@ def compute_stats_batched(self, samples, context=False): samples_list = samples[self.text_key] samples_stats = samples[Fields.stats] words_key = f'{InterVars.words}-{self.sp_model_key}' + tokenizer = get_model(self.sp_model_key) for idx, stat in enumerate(samples_stats): # check if it's computed already @@ -54,7 +55,6 @@ def compute_stats_batched(self, samples, context=False): if context and words_key in samples[Fields.context][idx]: words = samples[Fields.context][idx][words_key] else: - tokenizer = get_model(self.sp_model_key) words = get_words_from_document( samples_list[idx], token_func=tokenizer.encode_as_pieces diff --git a/data_juicer/ops/filter/phrase_grounding_recall_filter.py b/data_juicer/ops/filter/phrase_grounding_recall_filter.py index 98a2dfb1f..9dec0dc3c 100644 --- a/data_juicer/ops/filter/phrase_grounding_recall_filter.py +++ b/data_juicer/ops/filter/phrase_grounding_recall_filter.py @@ -114,6 +114,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '1GB') super().__init__(*args, **kwargs) self.min_recall = min_recall self.max_recall = max_recall diff --git a/data_juicer/ops/filter/video_aesthetics_filter.py b/data_juicer/ops/filter/video_aesthetics_filter.py index 5e674162d..f65334f56 100644 --- a/data_juicer/ops/filter/video_aesthetics_filter.py +++ b/data_juicer/ops/filter/video_aesthetics_filter.py @@ -73,7 +73,7 @@ def __init__(self, :param args: Extra positional arguments. :param kwargs: Extra keyword arguments. """ - + kwargs.setdefault('mem_required', '1500MB') super().__init__(*args, **kwargs) if hf_scorer_model == '': hf_scorer_model = \ diff --git a/data_juicer/ops/filter/video_frames_text_similarity_filter.py b/data_juicer/ops/filter/video_frames_text_similarity_filter.py index 6b3e92641..da793ccf4 100644 --- a/data_juicer/ops/filter/video_frames_text_similarity_filter.py +++ b/data_juicer/ops/filter/video_frames_text_similarity_filter.py @@ -74,6 +74,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '1500MB') super().__init__(*args, **kwargs) self.min_score = min_score self.max_score = max_score diff --git a/data_juicer/ops/filter/video_nsfw_filter.py b/data_juicer/ops/filter/video_nsfw_filter.py index 27bafe1d0..a1dd9d214 100644 --- a/data_juicer/ops/filter/video_nsfw_filter.py +++ b/data_juicer/ops/filter/video_nsfw_filter.py @@ -65,6 +65,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '1GB') super().__init__(*args, **kwargs) self.score_threshold = score_threshold if frame_sampling_method not in ['all_keyframes', 'uniform']: diff --git a/data_juicer/ops/filter/video_tagging_from_frames_filter.py b/data_juicer/ops/filter/video_tagging_from_frames_filter.py index 7c41b5521..8872aab32 100644 --- a/data_juicer/ops/filter/video_tagging_from_frames_filter.py +++ b/data_juicer/ops/filter/video_tagging_from_frames_filter.py @@ -61,6 +61,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '9GB') super().__init__(*args, **kwargs) if contain not in ['any', 'all']: raise ValueError(f'the containing type [{contain}] is not ' diff --git a/data_juicer/ops/filter/video_watermark_filter.py b/data_juicer/ops/filter/video_watermark_filter.py index 2b7e30f8f..959c91e23 100644 --- a/data_juicer/ops/filter/video_watermark_filter.py +++ b/data_juicer/ops/filter/video_watermark_filter.py @@ -69,6 +69,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '500MB') super().__init__(*args, **kwargs) self.prob_threshold = prob_threshold if frame_sampling_method not in ['all_keyframes', 'uniform']: diff --git a/data_juicer/ops/grouper/__init__.py b/data_juicer/ops/grouper/__init__.py new file mode 100644 index 000000000..048b305e4 --- /dev/null +++ b/data_juicer/ops/grouper/__init__.py @@ -0,0 +1,4 @@ +from .key_value_grouper import KeyValueGrouper +from .naive_grouper import NaiveGrouper + +__all__ = ['NaiveGrouper', 'KeyValueGrouper'] diff --git a/data_juicer/ops/grouper/key_value_grouper.py b/data_juicer/ops/grouper/key_value_grouper.py new file mode 100644 index 000000000..3d786319f --- /dev/null +++ b/data_juicer/ops/grouper/key_value_grouper.py @@ -0,0 +1,51 @@ +from typing import List, Optional + +from data_juicer.utils.common_utils import dict_to_hash, nested_access + +from ..base_op import OPERATORS, Grouper, convert_list_dict_to_dict_list +from .naive_grouper import NaiveGrouper + + +@OPERATORS.register_module('key_value_grouper') +class KeyValueGrouper(Grouper): + """Group samples to batched samples according values in given keys. """ + + def __init__(self, + group_by_keys: Optional[List[str]] = None, + *args, + **kwargs): + """ + Initialization method. + + :param group_by_keys: group samples according values in the keys. + Support for nested keys such as "__dj__stats__.text_len". + It is [self.text_key] in default. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + self.group_by_keys = group_by_keys or [self.text_key] + self.naive_grouper = NaiveGrouper() + + def process(self, dataset): + + if len(dataset) == 0: + return dataset + + sample_map = {} + for sample in dataset: + cur_dict = {} + for key in self.group_by_keys: + cur_dict[key] = nested_access(sample, key) + sample_key = dict_to_hash(cur_dict) + if sample_key in sample_map: + sample_map[sample_key].append(sample) + else: + sample_map[sample_key] = [sample] + + batched_samples = [ + convert_list_dict_to_dict_list(sample_map[k]) for k in sample_map + ] + + return batched_samples diff --git a/data_juicer/ops/grouper/naive_grouper.py b/data_juicer/ops/grouper/naive_grouper.py new file mode 100644 index 000000000..4633dc48e --- /dev/null +++ b/data_juicer/ops/grouper/naive_grouper.py @@ -0,0 +1,24 @@ +from ..base_op import OPERATORS, Grouper, convert_list_dict_to_dict_list + + +@OPERATORS.register_module('naive_grouper') +class NaiveGrouper(Grouper): + """Group all samples to one batched sample. """ + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + def process(self, dataset): + + if len(dataset) == 0: + return dataset + + batched_sample = convert_list_dict_to_dict_list(dataset) + + return [batched_sample] diff --git a/data_juicer/ops/load.py b/data_juicer/ops/load.py index cf10cc51a..e0a4fb0b8 100644 --- a/data_juicer/ops/load.py +++ b/data_juicer/ops/load.py @@ -1,15 +1,12 @@ from .base_op import OPERATORS -from .op_fusion import fuse_operators -def load_ops(process_list, op_fusion=False): +def load_ops(process_list): """ Load op list according to the process list from config file. :param process_list: A process list. Each item is an op name and its arguments. - :param op_fusion: whether to fuse ops that share the same intermediate - variables. :return: The op instance list. """ ops = [] @@ -19,10 +16,7 @@ def load_ops(process_list, op_fusion=False): ops.append(OPERATORS.modules[op_name](**args)) new_process_list.append(process) - # detect filter groups - if op_fusion: - new_process_list, ops = fuse_operators(new_process_list, ops) - + # store the OP configs into each OP for op_cfg, op in zip(new_process_list, ops): op._op_cfg = op_cfg diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 41bf092a3..9b86b83dc 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -14,6 +14,7 @@ from .extract_event_mapper import ExtractEventMapper from .extract_keyword_mapper import ExtractKeywordMapper from .extract_nickname_mapper import ExtractNicknameMapper +from .extract_support_text_mapper import ExtractSupportTextMapper from .fix_unicode_mapper import FixUnicodeMapper from .generate_qa_from_examples_mapper import GenerateQAFromExamplesMapper from .generate_qa_from_text_mapper import GenerateQAFromTextMapper @@ -28,7 +29,11 @@ from .optimize_qa_mapper import OptimizeQAMapper from .optimize_query_mapper import OptimizeQueryMapper from .optimize_response_mapper import OptimizeResponseMapper +from .pair_preference_mapper import PairPreferenceMapper from .punctuation_normalization_mapper import PunctuationNormalizationMapper +from .python_file_mapper import PythonFileMapper +from .python_lambda_mapper import PythonLambdaMapper +from .relation_identity_mapper import RelationIdentityMapper from .remove_bibliography_mapper import RemoveBibliographyMapper from .remove_comments_mapper import RemoveCommentsMapper from .remove_header_mapper import RemoveHeaderMapper @@ -49,6 +54,7 @@ from .video_captioning_from_summarizer_mapper import \ VideoCaptioningFromSummarizerMapper from .video_captioning_from_video_mapper import VideoCaptioningFromVideoMapper +from .video_extract_frames_mapper import VideoExtractFramesMapper from .video_face_blur_mapper import VideoFaceBlurMapper from .video_ffmpeg_wrapped_mapper import VideoFFmpegWrappedMapper from .video_remove_watermark_mapper import VideoRemoveWatermarkMapper @@ -67,20 +73,23 @@ 'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper', 'ExpandMacroMapper', 'ExtractEntityAttributeMapper', 'ExtractEntityRelationMapper', 'ExtractEventMapper', - 'ExtractKeywordMapper', 'ExtractNicknameMapper', 'FixUnicodeMapper', + 'ExtractKeywordMapper', 'ExtractNicknameMapper', + 'ExtractSupportTextMapper', 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper', 'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper', 'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper', 'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', - 'PunctuationNormalizationMapper', 'RemoveBibliographyMapper', - 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', - 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', - 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', - 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', - 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper', - 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', - 'VideoCaptioningFromVideoMapper', 'VideoFFmpegWrappedMapper', + 'PairPreferenceMapper', 'PunctuationNormalizationMapper', + 'PythonFileMapper', 'PythonLambdaMapper', 'RelationIdentityMapper', + 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', + 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', + 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper', + 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper', + 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper', + 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', + 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper', + 'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index 69b860e33..8480ee899 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -4,14 +4,13 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.model_utils import get_model, prepare_model OP_NAME = 'calibrate_qa_mapper' # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class CalibrateQAMapper(Mapper): """ @@ -107,7 +106,7 @@ def process_single(self, sample, rank=None): 'content': self.build_input(sample) }] parsed_q, parsed_a = None, None - for i in range(self.try_num): + for _ in range(self.try_num): try: output = client(messages, **self.sampling_params) parsed_q, parsed_a = self.parse_output(output) diff --git a/data_juicer/ops/mapper/calibrate_query_mapper.py b/data_juicer/ops/mapper/calibrate_query_mapper.py index 88098d7f8..48ae0c4f7 100644 --- a/data_juicer/ops/mapper/calibrate_query_mapper.py +++ b/data_juicer/ops/mapper/calibrate_query_mapper.py @@ -1,11 +1,10 @@ -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE +from data_juicer.ops.base_op import OPERATORS from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper OP_NAME = 'calibrate_query_mapper' # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class CalibrateQueryMapper(CalibrateQAMapper): """ diff --git a/data_juicer/ops/mapper/calibrate_response_mapper.py b/data_juicer/ops/mapper/calibrate_response_mapper.py index db56af317..1d6456c2b 100644 --- a/data_juicer/ops/mapper/calibrate_response_mapper.py +++ b/data_juicer/ops/mapper/calibrate_response_mapper.py @@ -1,11 +1,10 @@ -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE +from data_juicer.ops.base_op import OPERATORS from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper OP_NAME = 'calibrate_response_mapper' # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class CalibrateResponseMapper(CalibrateQAMapper): """ diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py index 1fab935f9..0fc76b11f 100644 --- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -1,11 +1,10 @@ import re -from itertools import chain from typing import Dict, List, Optional from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -13,41 +12,43 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractEntityAttributeMapper(Mapper): """ Extract attributes for given entities from the text """ - _batched_op = True - DEFAULT_SYSTEM_PROMPT_TEMPLATE = ( '给定一段文本,从文本中总结{entity}的{attribute},并且从原文摘录最能说明该{attribute}的代表性示例。\n' '要求:\n' '- 摘录的示例应该简短。\n' '- 遵循如下的回复格式:\n' + '# {entity}\n' '## {attribute}:\n' - '{entity}的{attribute}描述...\n' - '### 代表性示例1:\n' - '说明{entity}该{attribute}的原文摘录1...\n' - '### 代表性示例2:\n' - '说明{entity}该{attribute}的原文摘录2...\n' + '...\n' + '### 代表性示例摘录1:\n' + '```\n' + '...\n' + '```\n' + '### 代表性示例摘录2:\n' + '```\n' + '...\n' + '```\n' '...\n') DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n' DEFAULT_ATTR_PATTERN_TEMPLATE = r'\#\#\s*{attribute}:\s*(.*?)(?=\#\#\#|\Z)' - DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例(\d+):\s*(.*?)(?=\#\#\#|\Z)' + DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例摘录(\d+):\s*```\s*(.*?)```\s*(?=\#\#\#|\Z)' # noqa: E501 def __init__(self, + api_model: str = 'gpt-4o', query_entities: List[str] = [], query_attributes: List[str] = [], - api_model: str = 'gpt-4o', *, - entity_key: str = Fields.main_entity, - attribute_key: str = Fields.attribute, - attribute_desc_key: str = Fields.attribute_description, - support_text_key: str = Fields.attribute_support_text, + entity_key: str = Fields.main_entities, + attribute_key: str = Fields.attributes, + attribute_desc_key: str = Fields.attribute_descriptions, + support_text_key: str = Fields.attribute_support_texts, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt_template: Optional[str] = None, @@ -61,9 +62,9 @@ def __init__(self, **kwargs): """ Initialization method. + :param api_model: API model name. :param query_entities: Entity list to be queried. :param query_attributes: Attribute list to be queried. - :param api_model: API model name. :param entity_key: The field name to store the given main entity for attribute extraction. It's "__dj__entity__" in default. :param entity_attribute_key: The field name to store the given @@ -136,7 +137,7 @@ def parse_output(self, raw_output, attribute_name): return attribute, demos - def _process_single_sample(self, text='', rank=None): + def _process_single_text(self, text='', rank=None): client = get_model(self.model_key, rank=rank) entities, attributes, descs, demo_lists = [], [], [], [] @@ -154,7 +155,7 @@ def _process_single_sample(self, text='', rank=None): }] desc, demos = '', [] - for i in range(self.try_num): + for _ in range(self.try_num): try: output = client(messages, **self.sampling_params) desc, demos = self.parse_output(output, attribute) @@ -169,31 +170,17 @@ def _process_single_sample(self, text='', rank=None): return entities, attributes, descs, demo_lists - def process_batched(self, samples, rank=None): - - sample_num = len(samples[self.text_key]) + def process_single(self, sample, rank=None): - entities, attributes, descs, demo_lists = [], [], [], [] - for text in samples[self.text_key]: - res = self._process_single_sample(text, rank=rank) - cur_ents, cur_attrs, cur_descs, cur_demos = res - entities.append(cur_ents) - attributes.append(cur_attrs) - descs.append(cur_descs) - demo_lists.append(cur_demos) + res = self._process_single_text(sample[self.text_key], rank=rank) + entities, attributes, descs, demo_lists = res if self.drop_text: - samples.pop(self.text_key) - - for key in samples: - samples[key] = [[samples[key][i]] * len(descs[i]) - for i in range(sample_num)] - samples[self.entity_key] = entities - samples[self.attribute_key] = attributes - samples[self.attribute_desc_key] = descs - samples[self.support_text_key] = demo_lists + sample.pop(self.text_key) - for key in samples: - samples[key] = list(chain(*samples[key])) + sample[self.entity_key] = entities + sample[self.attribute_key] = attributes + sample[self.attribute_desc_key] = descs + sample[self.support_text_key] = demo_lists - return samples + return sample diff --git a/data_juicer/ops/mapper/extract_entity_relation_mapper.py b/data_juicer/ops/mapper/extract_entity_relation_mapper.py index 4b026f2a4..6350101ac 100644 --- a/data_juicer/ops/mapper/extract_entity_relation_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_relation_mapper.py @@ -9,7 +9,7 @@ from loguru import logger from pydantic import NonNegativeInt, PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.common_utils import is_float from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -20,7 +20,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractEntityRelationMapper(Mapper): """ @@ -319,7 +318,7 @@ def process_single(self, sample, rank=None): messages = [{'role': 'user', 'content': input_prompt}] entities, relations = [], [] - for i in range(self.try_num): + for _ in range(self.try_num): try: result = self.light_rag_extraction(messages, rank=rank) entities, relations = self.parse_output(result) diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py index 208684b2c..fddf4fed1 100644 --- a/data_juicer/ops/mapper/extract_event_mapper.py +++ b/data_juicer/ops/mapper/extract_event_mapper.py @@ -5,7 +5,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -15,7 +15,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractEventMapper(Mapper): """ @@ -134,7 +133,7 @@ def _process_single_sample(self, text='', rank=None): }] event_list, character_list = [], [] - for i in range(self.try_num): + for _ in range(self.try_num): try: output = client(messages, **self.sampling_params) event_list, character_list = self.parse_output(output) diff --git a/data_juicer/ops/mapper/extract_keyword_mapper.py b/data_juicer/ops/mapper/extract_keyword_mapper.py index cb1814768..24e3e127e 100644 --- a/data_juicer/ops/mapper/extract_keyword_mapper.py +++ b/data_juicer/ops/mapper/extract_keyword_mapper.py @@ -6,7 +6,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -16,7 +16,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractKeywordMapper(Mapper): """ @@ -173,7 +172,7 @@ def process_single(self, sample, rank=None): messages = [{'role': 'user', 'content': input_prompt}] keywords = [] - for i in range(self.try_num): + for _ in range(self.try_num): try: result = client(messages, **self.sampling_params) keywords = self.parse_output(result) diff --git a/data_juicer/ops/mapper/extract_nickname_mapper.py b/data_juicer/ops/mapper/extract_nickname_mapper.py index b11cbab57..20aeb94db 100644 --- a/data_juicer/ops/mapper/extract_nickname_mapper.py +++ b/data_juicer/ops/mapper/extract_nickname_mapper.py @@ -4,7 +4,7 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.model_utils import get_model, prepare_model @@ -12,7 +12,6 @@ # TODO: LLM-based inference. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractNicknameMapper(Mapper): """ @@ -143,7 +142,7 @@ def process_single(self, sample, rank=None): 'content': input_prompt }] nickname_relations = [] - for i in range(self.try_num): + for _ in range(self.try_num): try: output = client(messages, **self.sampling_params) nickname_relations = self.parse_output(output) diff --git a/data_juicer/ops/mapper/extract_support_text_mapper.py b/data_juicer/ops/mapper/extract_support_text_mapper.py new file mode 100644 index 000000000..34bdbe653 --- /dev/null +++ b/data_juicer/ops/mapper/extract_support_text_mapper.py @@ -0,0 +1,132 @@ +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.common_utils import nested_access, nested_set +from data_juicer.utils.constant import Fields +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'extract_support_text_mapper' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class ExtractSupportTextMapper(Mapper): + """ + Extract support sub text for a summary. + """ + + DEFAULT_SYSTEM_PROMPT = ('你将扮演一个文本摘录助手的角色。你的主要任务是基于给定' + '的文章(称为“原文”)以及对原文某个部分的简短描述或总结' + '(称为“总结”),准确地识别并提取出与该总结相对应的原文' + '片段。\n' + '要求:\n' + '- 你需要尽可能精确地匹配到最符合总结内容的那部分内容\n' + '- 如果存在多个可能的答案,请选择最贴近总结意思的那个\n' + '- 下面是一个例子帮助理解这一过程:\n' + '### 原文:\n' + '《红楼梦》是中国古典小说四大名著之一,由清代作家曹雪芹创' + '作。它讲述了贾宝玉、林黛玉等人的爱情故事及四大家族的兴衰' + '历程。书中通过复杂的人物关系展现了封建社会的各种矛盾冲突' + '。其中关于贾府内部斗争的部分尤其精彩,特别是王熙凤与尤二' + '姐之间的争斗,生动描绘了权力争夺下的女性形象。此外,《红' + '楼梦》还以其精美的诗词闻名,这些诗词不仅增添了文学色彩,' + '也深刻反映了人物的性格特点和命运走向。\n\n' + '### 总结:\n' + '描述了书中的两个女性角色之间围绕权力展开的竞争。\n\n' + '### 原文摘录:\n' + '其中关于贾府内部斗争的部分尤其精彩,特别是王熙凤与尤二姐' + '之间的争斗,生动描绘了权力争夺下的女性形象。') + DEFAULT_INPUT_TEMPLATE = ('### 原文:\n{text}\n\n' + '### 总结:\n{summary}\n\n' + '### 原文摘录:\n') + + def __init__(self, + api_model: str = 'gpt-4o', + *, + summary_key: str = Fields.event_description, + support_text_key: str = Fields.support_text, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + input_template: Optional[str] = None, + try_num: PositiveInt = 3, + drop_text: bool = False, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param api_model: API model name. + :param summary_key: The field name to store the input summary. + Support for nested keys such as "__dj__stats__.text_len". + It's "__dj__event_description__" in default. + :param support_text_key: The field name to store the output + support text for the summary. It's "__dj__support_text__" in + default. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: System prompt for the task. + :param input_template: Template for building the model input. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param drop_text: If drop the text in the output. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.summary_key = summary_key + self.support_text_key = support_text_key + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + self.drop_text = drop_text + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + summary = nested_access(sample, self.summary_key) + if not isinstance(summary, str): + logger.warning('Unvalid input summary!') + return sample + + input_prompt = self.input_template.format(text=sample[self.text_key], + summary=summary) + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': input_prompt + }] + + support_text = '' + for i in range(self.try_num): + try: + response = client(messages, **self.sampling_params) + support_text = response.strip() + if len(support_text) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + # default to summary if return None + if not support_text: + support_text = summary + + sample = nested_set(sample, self.support_text_key, support_text) + return sample diff --git a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py index 6f5ad7dab..0c0d084b3 100644 --- a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py +++ b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py @@ -9,7 +9,7 @@ from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.model_utils import get_model, prepare_model -from ..base_op import OPERATORS, UNFORKABLE, Mapper +from ..base_op import OPERATORS, Mapper torch = LazyLoader('torch', 'torch') vllm = LazyLoader('vllm', 'vllm') @@ -19,7 +19,6 @@ # TODO: Extend LLM-based OPs into API-based implementation. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class GenerateQAFromExamplesMapper(Mapper): """ diff --git a/data_juicer/ops/mapper/generate_qa_from_text_mapper.py b/data_juicer/ops/mapper/generate_qa_from_text_mapper.py index 248dba428..0f3a1cfef 100644 --- a/data_juicer/ops/mapper/generate_qa_from_text_mapper.py +++ b/data_juicer/ops/mapper/generate_qa_from_text_mapper.py @@ -3,7 +3,7 @@ from loguru import logger -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.model_utils import get_model, prepare_model @@ -14,7 +14,6 @@ # TODO: Extend LLM-based OPs into API-based implementation. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class GenerateQAFromTextMapper(Mapper): """ diff --git a/data_juicer/ops/mapper/image_captioning_mapper.py b/data_juicer/ops/mapper/image_captioning_mapper.py index 0bc486193..98bb3ad7c 100644 --- a/data_juicer/ops/mapper/image_captioning_mapper.py +++ b/data_juicer/ops/mapper/image_captioning_mapper.py @@ -81,6 +81,8 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '16GB') + super().__init__(*args, **kwargs) if keep_candidate_mode not in [ diff --git a/data_juicer/ops/mapper/image_diffusion_mapper.py b/data_juicer/ops/mapper/image_diffusion_mapper.py index c53d6f56d..53e315844 100644 --- a/data_juicer/ops/mapper/image_diffusion_mapper.py +++ b/data_juicer/ops/mapper/image_diffusion_mapper.py @@ -91,6 +91,7 @@ def __init__(self, :param hf_img2seq: model name on huggingface to generate caption if caption_key is None. """ + kwargs.setdefault('mem_required', '8GB') super().__init__(*args, **kwargs) self._init_parameters = self.remove_extra_parameters(locals()) self.strength = strength diff --git a/data_juicer/ops/mapper/image_tagging_mapper.py b/data_juicer/ops/mapper/image_tagging_mapper.py index d47fbf0ef..e3fc46f1b 100644 --- a/data_juicer/ops/mapper/image_tagging_mapper.py +++ b/data_juicer/ops/mapper/image_tagging_mapper.py @@ -36,6 +36,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '9GB') super().__init__(*args, **kwargs) self.model_key = prepare_model( model_type='recognizeAnything', diff --git a/data_juicer/ops/mapper/optimize_qa_mapper.py b/data_juicer/ops/mapper/optimize_qa_mapper.py index 3563a112b..974730ec5 100644 --- a/data_juicer/ops/mapper/optimize_qa_mapper.py +++ b/data_juicer/ops/mapper/optimize_qa_mapper.py @@ -3,7 +3,7 @@ from loguru import logger -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.model_utils import get_model, prepare_model @@ -14,7 +14,6 @@ # TODO: Extend LLM-based OPs into API-based implementation. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class OptimizeQAMapper(Mapper): """ diff --git a/data_juicer/ops/mapper/optimize_query_mapper.py b/data_juicer/ops/mapper/optimize_query_mapper.py index dd227b4c1..9ccd84bb1 100644 --- a/data_juicer/ops/mapper/optimize_query_mapper.py +++ b/data_juicer/ops/mapper/optimize_query_mapper.py @@ -1,11 +1,10 @@ -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE +from data_juicer.ops.base_op import OPERATORS from data_juicer.ops.mapper.optimize_qa_mapper import OptimizeQAMapper OP_NAME = 'optimize_query_mapper' # TODO: Extend LLM-based OPs into API-based implementation. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class OptimizeQueryMapper(OptimizeQAMapper): """ diff --git a/data_juicer/ops/mapper/optimize_response_mapper.py b/data_juicer/ops/mapper/optimize_response_mapper.py index 158159a9d..f6026b8dc 100644 --- a/data_juicer/ops/mapper/optimize_response_mapper.py +++ b/data_juicer/ops/mapper/optimize_response_mapper.py @@ -1,11 +1,10 @@ -from data_juicer.ops.base_op import OPERATORS, UNFORKABLE +from data_juicer.ops.base_op import OPERATORS from data_juicer.ops.mapper.optimize_qa_mapper import OptimizeQAMapper OP_NAME = 'optimize_response_mapper' # TODO: Extend LLM-based OPs into API-based implementation. -@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class OptimizeResponseMapper(OptimizeQAMapper): """ diff --git a/data_juicer/ops/mapper/pair_preference_mapper.py b/data_juicer/ops/mapper/pair_preference_mapper.py new file mode 100644 index 000000000..f839fb5d3 --- /dev/null +++ b/data_juicer/ops/mapper/pair_preference_mapper.py @@ -0,0 +1,131 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'pair_preference_mapper' + + +# TODO: Extend LLM-based OPs into API-based implementation. +@OPERATORS.register_module(OP_NAME) +class PairPreferenceMapper(Mapper): + """ + Mapper to construct paired preference samples. + """ + + # avoid leading whitespace + DEFAULT_SYSTEM_PROMPT = ( + '你的任务是根据参考信息修改问答对中的回答,在语言风格、事实性、人物身份、立场等任一方面与原回答相反。' + '必须按照以下标记格式输出,不要输出其他多余内容。\n' + '【回答】\n' + '生成的新回答\n' + '【原因】\n' + '生成该回答的原因') + DEFAULT_INPUT_TEMPLATE = ('【参考信息】\n' + '{reference}\n' + '\n' + '以下是原始问答对:\n' + '【问题】\n' + '{query}\n' + '【回答】\n' + '{response}') + DEFAULT_OUTPUT_PATTERN = r'.*?【回答】\s*(.*?)\s*【原因】\s*(.*)' + + def __init__(self, + api_model: str = 'gpt-4o', + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + input_template: Optional[str] = None, + output_pattern: Optional[str] = None, + rejected_key: str = 'rejected_response', + reason_key: str = 'reason', + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param api_model: API model name. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: System prompt for guiding the generation task. + :param input_template: Template for building the model input. It must + contain placeholders '{query}' and '{reponse}', and can optionally + include '{reference}'. + :param output_pattern: Regular expression for parsing model output. + :param rejected_key: The field name in the sample to store the + generated rejected response. Defaults to 'rejected_response'. + :param reason_key: The field name in the sample to store the reason for + generating the response. Defaults to 'reason'. + :param try_num: The number of retries for the API call in case of + response parsing failure. Defaults to 3. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN + + self.rejected_key = rejected_key + self.reason_key = reason_key + + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + self.try_num = try_num + self.sampling_params = sampling_params + + def build_input(self, sample): + mapping = { + 'query': sample[self.query_key], + 'response': sample[self.response_key], + 'reference': sample.get(self.text_key, '') + } + return self.input_template.format_map(mapping) + + def parse_output(self, raw_output): + logger.debug(raw_output) + match = re.match(self.output_pattern, raw_output, re.DOTALL) + if match: + return match.group(1).strip(), match.group(2).strip() + else: + return ('', '') + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': self.build_input(sample) + }] + + parsed_rejected, parsed_reason = '', '' + for _ in range(self.try_num): + try: + output = client(messages, **self.sampling_params) + parsed_rejected, parsed_reason = self.parse_output(output) + if parsed_rejected and parsed_reason: + break + except Exception as e: + logger.warning(f'Exception: {e}') + sample[self.rejected_key] = parsed_rejected + sample[self.reason_key] = parsed_reason + + return sample diff --git a/data_juicer/ops/mapper/python_file_mapper.py b/data_juicer/ops/mapper/python_file_mapper.py new file mode 100644 index 000000000..b74fd96a1 --- /dev/null +++ b/data_juicer/ops/mapper/python_file_mapper.py @@ -0,0 +1,97 @@ +import importlib.util +import inspect +import os + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'python_file_mapper' + + +@OPERATORS.register_module(OP_NAME) +class PythonFileMapper(Mapper): + """Mapper for executing Python function defined in a file.""" + + def __init__(self, + file_path: str = '', + function_name: str = 'process_single', + batched: bool = False, + **kwargs): + """ + Initialization method. + + :param file_path: The path to the Python file containing the function + to be executed. + :param function_name: The name of the function defined in the file + to be executed. + :param batched: A boolean indicating whether to process input data in + batches. + :param kwargs: Additional keyword arguments passed to the parent class. + """ + self._batched_op = bool(batched) + super().__init__(**kwargs) + + self.file_path = file_path + self.function_name = function_name + if not file_path: + self.func = lambda sample: sample + else: + self.func = self._load_function() + + def _load_function(self): + if not os.path.isfile(self.file_path): + raise FileNotFoundError( + f"The file '{self.file_path}' does not exist.") + + if not self.file_path.endswith('.py'): + raise ValueError( + f"The file '{self.file_path}' is not a Python file.") + + # Load the module from the file + module_name = os.path.splitext(os.path.basename(self.file_path))[0] + spec = importlib.util.spec_from_file_location(module_name, + self.file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Fetch the specified function from the module + if not hasattr(module, self.function_name): + raise ValueError( + f"Function '{self.function_name}' not found in '{self.file_path}'." # noqa: E501 + ) + + func = getattr(module, self.function_name) + + if not callable(func): + raise ValueError( + f"The attribute '{self.function_name}' is not callable.") + + # Check that the function has exactly one argument + argspec = inspect.getfullargspec(func) + if len(argspec.args) != 1: + raise ValueError( + f"The function '{self.function_name}' must take exactly one argument" # noqa: E501 + ) + + return func + + def process_single(self, sample): + """Invoke the loaded function with the provided sample.""" + result = self.func(sample) + + if not isinstance(result, dict): + raise ValueError( + f'Function must return a dictionary, got {type(result).__name__} instead.' # noqa: E501 + ) + + return result + + def process_batched(self, samples): + """Invoke the loaded function with the provided samples.""" + result = self.func(samples) + + if not isinstance(result, dict): + raise ValueError( + f'Function must return a dictionary, got {type(result).__name__} instead.' # noqa: E501 + ) + + return result diff --git a/data_juicer/ops/mapper/python_lambda_mapper.py b/data_juicer/ops/mapper/python_lambda_mapper.py new file mode 100644 index 000000000..e90c77f48 --- /dev/null +++ b/data_juicer/ops/mapper/python_lambda_mapper.py @@ -0,0 +1,74 @@ +import ast + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'python_lambda_mapper' + + +@OPERATORS.register_module(OP_NAME) +class PythonLambdaMapper(Mapper): + """Mapper for executing Python lambda function on data samples.""" + + def __init__(self, lambda_str: str = '', batched: bool = False, **kwargs): + """ + Initialization method. + + :param lambda_str: A string representation of the lambda function to be + executed on data samples. If empty, the identity function is used. + :param batched: A boolean indicating whether to process input data in + batches. + :param kwargs: Additional keyword arguments passed to the parent class. + """ + self._batched_op = bool(batched) + super().__init__(**kwargs) + + # Parse and validate the lambda function + if not lambda_str: + self.lambda_func = lambda sample: sample + else: + self.lambda_func = self._create_lambda(lambda_str) + + def _create_lambda(self, lambda_str: str): + # Parse input string into an AST and check for a valid lambda function + try: + node = ast.parse(lambda_str, mode='eval') + + # Check if the body of the expression is a lambda + if not isinstance(node.body, ast.Lambda): + raise ValueError( + 'Input string must be a valid lambda function.') + + # Check that the lambda has exactly one argument + if len(node.body.args.args) != 1: + raise ValueError( + 'Lambda function must have exactly one argument.') + + # Compile the AST to code + compiled_code = compile(node, '', 'eval') + # Safely evaluate the compiled code allowing built-in functions + func = eval(compiled_code, {'__builtins__': __builtins__}) + return func + except Exception as e: + raise ValueError(f'Invalid lambda function: {e}') + + def process_single(self, sample): + # Process the input through the lambda function and return the result + result = self.lambda_func(sample) + + # Check if the result is a valid + if not isinstance(result, dict): + raise ValueError(f'Lambda function must return a dictionary, ' + f'got {type(result).__name__} instead.') + + return result + + def process_batched(self, samples): + # Process the input through the lambda function and return the result + result = self.lambda_func(samples) + + # Check if the result is a valid + if not isinstance(result, dict): + raise ValueError(f'Lambda function must return a dictionary, ' + f'got {type(result).__name__} instead.') + + return result diff --git a/data_juicer/ops/mapper/relation_identity_mapper.py b/data_juicer/ops/mapper/relation_identity_mapper.py new file mode 100644 index 000000000..29994d744 --- /dev/null +++ b/data_juicer/ops/mapper/relation_identity_mapper.py @@ -0,0 +1,155 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.common_utils import nested_access, nested_set +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'relation_identity_mapper' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class RelationIdentityMapper(Mapper): + """ + identify relation between two entity in the text. + """ + + DEFAULT_SYSTEM_PROMPT_TEMPLATE = ( + '给定关于{entity1}和{entity2}的文本信息。' + '判断{entity1}和{entity2}之间的关系。\n' + '要求:\n' + '- 关系用一个或多个词语表示,必要时可以加一个形容词来描述这段关系\n' + '- 输出关系时不要参杂任何标点符号\n' + '- 需要你进行合理的推理才能得出结论\n' + '- 如果两个人物身份是同一个人,输出关系为:另一个身份\n' + '- 输出格式为:\n' + '分析推理:...\n' + '所以{entity2}是{entity1}的:...\n' + '- 注意输出的是{entity2}是{entity1}的什么关系,而不是{entity1}是{entity2}的什么关系') + DEFAULT_INPUT_TEMPLATE = '关于{entity1}和{entity2}的文本信息:\n```\n{text}\n```\n' + DEFAULT_OUTPUT_PATTERN_TEMPLATE = r""" + \s*分析推理:\s*(.*?)\s* + \s*所以{entity2}是{entity1}的:\s*(.*?)\Z + """ + + def __init__(self, + api_model: str = 'gpt-4o', + source_entity: str = None, + target_entity: str = None, + input_key: str = None, + output_key: str = None, + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt_template: Optional[str] = None, + input_template: Optional[str] = None, + output_pattern_template: Optional[str] = None, + try_num: PositiveInt = 3, + drop_text: bool = False, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param api_model: API model name. + :param source_entity: The source entity of the relation to be + identified. + :param target_entity: The target entity of the relation to be + identified. + :param input_key: The input field key in the samples. Support for + nested keys such as "__dj__stats__.text_len". It is text_key + in default. + :param output_key: The output field key in the samples. Support + for nested keys such as "__dj__stats__.text_len". It is + input_key in default. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt_template: System prompt template for the task. + :param input_template: Template for building the model input. + :param output_pattern_template: Regular expression template for + parsing model output. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param drop_text: If drop the text in the output. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + if source_entity is None or target_entity is None: + logger.warning('source_entity and target_entity cannot be None') + + self.source_entity = source_entity + self.target_entity = target_entity + + self.input_key = input_key or self.text_key + self.output_key = output_key or self.input_key + + system_prompt_template = system_prompt_template or \ + self.DEFAULT_SYSTEM_PROMPT_TEMPLATE + self.system_prompt = system_prompt_template.format( + entity1=source_entity, entity2=target_entity) + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + output_pattern_template = output_pattern_template or \ + self.DEFAULT_OUTPUT_PATTERN_TEMPLATE + self.output_pattern = output_pattern_template.format( + entity1=source_entity, entity2=target_entity) + + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + self.drop_text = drop_text + + def parse_output(self, raw_output): + pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) + matches = pattern.findall(raw_output) + + relation = '' + + for match in matches: + _, relation = match + relation = relation.strip() + + return relation + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + text = nested_access(sample, self.input_key) + input_prompt = self.input_template.format(entity1=self.source_entity, + entity2=self.target_entity, + text=text) + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': input_prompt + }] + relation = '' + for i in range(self.try_num): + try: + output = client(messages, **self.sampling_params) + relation = self.parse_output(output) + if len(relation) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + sample = nested_set(sample, self.output_key, relation) + if self.drop_text: + sample.pop(self.text_key) + + return sample diff --git a/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py b/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py index 4833409a4..75ffb9b3a 100644 --- a/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py @@ -32,6 +32,7 @@ def __init__(self, keep_original_sample: bool = True, *args, **kwargs): :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '30GB') super().__init__(*args, **kwargs) AUTOINSTALL.check([ 'transformers', 'transformers_stream_generator', 'einops', diff --git a/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py b/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py index dbf614510..d4c664c5f 100644 --- a/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py @@ -108,6 +108,7 @@ def __init__( :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '20GB') super().__init__(*args, **kwargs) if keep_candidate_mode not in [ diff --git a/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py b/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py index b2f4c8139..67eb7e234 100644 --- a/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py @@ -81,6 +81,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '40GB') super().__init__(*args, **kwargs) AUTOINSTALL.check([ 'torch', diff --git a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py index 04cd641ab..737626260 100644 --- a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py @@ -108,6 +108,7 @@ def __init__( :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '20GB') super().__init__(*args, **kwargs) if keep_candidate_mode not in [ diff --git a/data_juicer/ops/mapper/video_extract_frames_mapper.py b/data_juicer/ops/mapper/video_extract_frames_mapper.py new file mode 100644 index 000000000..4eb522abe --- /dev/null +++ b/data_juicer/ops/mapper/video_extract_frames_mapper.py @@ -0,0 +1,173 @@ +import json +import os +import os.path as osp + +from pydantic import PositiveInt + +from data_juicer.utils.constant import Fields +from data_juicer.utils.file_utils import dict_to_hash +from data_juicer.utils.mm_utils import ( + SpecialTokens, close_video, extract_key_frames, + extract_key_frames_by_seconds, extract_video_frames_uniformly, + extract_video_frames_uniformly_by_seconds, load_data_with_context, + load_video) + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_extract_frames_mapper' + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoExtractFramesMapper(Mapper): + """Mapper to extract frames from video files according to specified methods. + Extracted Frames Data Format: + The data format for the extracted frames is a dictionary mapping + video key to extracted frames directory where the extracted + frames are saved. The dictionary follows the structure: + { + "video_key_1": "/${frame_dir}/video_key_1_filename/", + "video_key_2": "/${frame_dir}/video_key_2_filename/", + ... + } + """ + + _batched_op = True + + def __init__( + self, + frame_sampling_method: str = 'all_keyframes', + frame_num: PositiveInt = 3, + duration: float = 0, + frame_dir: str = None, + frame_key=Fields.video_frames, + *args, + **kwargs, + ): + """ + Initialization method. + :param frame_sampling_method: sampling method of extracting frame + videos from the videos. Should be one of + ["all_keyframes", "uniform"]. + The former one extracts all key frames (the number + of which depends on the duration of the video) and the latter + one extract specified number of frames uniformly from the video. + If "duration" > 0, frame_sampling_method acts on every segment. + Default: "all_keyframes". + :param frame_num: the number of frames to be extracted uniformly from + the video. Only works when frame_sampling_method is "uniform". If + it's 1, only the middle frame will be extracted. If it's 2, only + the first and the last frames will be extracted. If it's larger + than 2, in addition to the first and the last frames, other frames + will be extracted uniformly within the video duration. + If "duration" > 0, frame_num is the number of frames per segment. + :param duration: The duration of each segment in seconds. + If 0, frames are extracted from the entire video. + If duration > 0, the video is segmented into multiple segments + based on duration, and frames are extracted from each segment. + :param frame_dir: Output directory to save extracted frames. + If None, a default directory based on the video file path is used. + :param frame_key: The name of field to save generated frames info. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + if frame_sampling_method not in ['all_keyframes', 'uniform']: + raise ValueError( + f'Frame sampling method ' + f'[{frame_sampling_method}] is not supported. ' + f'Can only be one of ["all_keyframes", "uniform"].') + + self.frame_dir = frame_dir + self.frame_sampling_method = frame_sampling_method + self.frame_num = frame_num + self.duration = duration + self.frame_key = frame_key + self.frame_fname_template = 'frame_{}.jpg' + + def _get_default_frame_dir(self, original_filepath): + original_dir = os.path.dirname(original_filepath) + dir_token = f'/{Fields.multimodal_data_output_dir}/' + if dir_token in original_dir: + original_dir = original_dir.split(dir_token)[0] + saved_dir = os.path.join( + original_dir, f'{Fields.multimodal_data_output_dir}/{OP_NAME}') + original_filename = osp.splitext(osp.basename(original_filepath))[0] + hash_val = dict_to_hash(self._init_parameters) + + return osp.join(saved_dir, + f'{original_filename}__dj_hash_#{hash_val}#') + + def process_single(self, sample, context=False): + # check if it's generated already + if self.frame_key in sample: + return sample + + # there is no videos in this sample + if self.video_key not in sample or not sample[self.video_key]: + return [] + + # load videos + loaded_video_keys = sample[self.video_key] + sample, videos = load_data_with_context(sample, context, + loaded_video_keys, load_video) + video_to_frame_dir = {} + text = sample[self.text_key] + offset = 0 + + for chunk in text.split(SpecialTokens.eoc): + video_count = chunk.count(SpecialTokens.video) + # no video or no text + if video_count == 0 or len(chunk) == 0: + continue + else: + for video_key in loaded_video_keys[offset:offset + + video_count]: + video = videos[video_key] + # extract frame videos + if self.frame_sampling_method == 'all_keyframes': + if self.duration: + frames = extract_key_frames_by_seconds( + video, self.duration) + else: + frames = extract_key_frames(video) + elif self.frame_sampling_method == 'uniform': + if self.duration: + frames = extract_video_frames_uniformly_by_seconds( + video, self.frame_num, duration=self.duration) + else: + frames = extract_video_frames_uniformly( + video, self.frame_num) + else: + raise ValueError(f'Not support sampling method \ + `{self.frame_sampling_method}`.') + frames = [frame.to_image() for frame in frames] + + if self.frame_dir: + frame_dir = osp.join( + self.frame_dir, + osp.splitext(osp.basename(video_key))[0]) + else: + # video path as frames directory + frame_dir = self._get_default_frame_dir(video_key) + os.makedirs(frame_dir, exist_ok=True) + video_to_frame_dir[video_key] = frame_dir + + for i, frame in enumerate(frames): + frame_path = osp.join( + frame_dir, self.frame_fname_template.format(i)) + if not os.path.exists(frame_path): + frame.save(frame_path) + + offset += video_count + + if not context: + for vid_key in videos: + close_video(videos[vid_key]) + + sample[self.frame_key] = json.dumps(video_to_frame_dir) + + return sample diff --git a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py index 763a3381c..2c32093a5 100644 --- a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py @@ -37,6 +37,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '500MB') super().__init__(*args, **kwargs) AUTOINSTALL.check(['torchaudio']) self.model_key = prepare_model(model_type='huggingface', diff --git a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py index 26227738b..d4995d3f6 100644 --- a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py @@ -55,6 +55,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ + kwargs.setdefault('mem_required', '9GB') super().__init__(*args, **kwargs) if frame_sampling_method not in ['all_keyframes', 'uniform']: raise ValueError( diff --git a/data_juicer/ops/op_fusion.py b/data_juicer/ops/op_fusion.py index 26aaa556e..489f90ab0 100644 --- a/data_juicer/ops/op_fusion.py +++ b/data_juicer/ops/op_fusion.py @@ -1,5 +1,6 @@ from typing import List +import numpy as np from loguru import logger from data_juicer.utils.constant import Fields, InterVars @@ -23,47 +24,49 @@ INTER_SAMPLED_FRAMES = Registry(InterVars.sampled_frames) # all -ALL_INTER_VARS = [INTER_LINES, INTER_WORDS, LOADED_IMAGES, LOADED_VIDEOS] +ALL_INTER_VARS = [ + INTER_LINES, INTER_WORDS, LOADED_IMAGES, LOADED_VIDEOS, + INTER_SAMPLED_FRAMES +] +# supported fusion strategies +FUSION_STRATEGIES = {'greedy', 'probe'} -def fuse_operators(process_list, ops): + +def fuse_operators(ops, probe_res=None): """ Fuse the input ops list and return the fused ops list. - :param process_list: the list of original process definition, including op - names and args. :param ops: the corresponding list of op objects. + :param probe_res: the probed speed for each OP from Monitor. :return: a list of fused op objects. """ + if probe_res is None: + probe_res = [None for _ in range(len(ops))] # detect filter groups and try to fuse them - fused_op_def = [] fused_ops = [] filter_group = [] in_group = False - for process, op in zip(process_list, ops): + for op, op_probe in zip(ops, probe_res): if isinstance(op, Filter): if not in_group: in_group = True - filter_group.append((process, op)) + filter_group.append((op, op_probe)) elif in_group: # got a filter group, try to fuse them - fused_group_def, fused_group = fuse_filter_group(filter_group) - fused_op_def.extend(fused_group_def) + fused_group = fuse_filter_group(filter_group) fused_ops.extend(fused_group) filter_group = [] in_group = False # and add the current non-filter op into fused_ops - fused_op_def.append(process) fused_ops.append(op) else: # not a filter and not in a filter group, skip - fused_op_def.append(process) fused_ops.append(op) if in_group and len(filter_group) > 0: # the final filter group, try to fuse them - fused_group_def, fused_group = fuse_filter_group(filter_group) - fused_op_def.extend(fused_group_def) + fused_group = fuse_filter_group(filter_group) fused_ops.extend(fused_group) - return fused_op_def, fused_ops + return fused_ops def fuse_filter_group(original_filter_group): @@ -74,25 +77,25 @@ def fuse_filter_group(original_filter_group): definitions and objects. :return: the fused definitions and objects of the input filter group. """ - fused_group_def = [] fused_group = [] + group_speed = [] all_intermediate_vars = ALL_INTER_VARS all_fused_filters = { inter_vars: [] for inter_vars in all_intermediate_vars } # group these filters by their intermediate vars - for process, op in original_filter_group: - op_name, op_args = list(process.items())[0] + for op, probe_res in original_filter_group: + op_name = op._name for inter_vars in all_intermediate_vars: if op_name in inter_vars.modules: - all_fused_filters[inter_vars].append((process, op)) + all_fused_filters[inter_vars].append((op, probe_res)) break else: # first apply other filters to decrease the number of samples, so # we add them into the fused_group list directly - fused_group_def.append(process) fused_group.append(op) + group_speed.append(probe_res['speed'] if probe_res else 0) # try to fuse ops for each type of intermediate vars for inter_vars in all_intermediate_vars: @@ -102,40 +105,59 @@ def fuse_filter_group(original_filter_group): pass elif len(inter_vars_filter) > 1: # more than 1 ops share the same intermediate var, try to fuse them - defs, ops = zip(*inter_vars_filter) + ops, probe_res_list = zip(*inter_vars_filter) # new definition: new name and a definition list of fused op list - fused_filter_def = { - 'OpFusion:(%s)' % ','.join([ - list(process.items())[0][0] for process in defs - ]): - list(defs) - } + fused_filter_name = 'OpFusion:(%s)' % ','.join( + [op._name for op in ops]) logger.info(f'Ops are fused into one op ' - f'{list(fused_filter_def.keys())[0]}.') + f'{fused_filter_name}.') # use these ops to create a FusedFilter object, and add the fused # definition and op into the fused group - fused_filter = FusedFilter(ops) - fused_group_def.append(fused_filter_def) + fused_filter = FusedFilter(fused_filter_name, ops) + fused_filter._op_cfg = { + fused_filter_name: [op._op_cfg for op in ops] + } + fused_filter_speed = sum([ + 1.0 / probe_res['speed'] for probe_res in probe_res_list + if probe_res + ]) + if fused_filter_speed > 0: + fused_filter_speed = 1.0 / fused_filter_speed fused_group.append(fused_filter) + group_speed.append(fused_filter_speed) else: # only 1 op for this type of intermediate var, add it to the fused # group directly without fusion - fused_group_def.append(inter_vars_filter[0][0]) - fused_group.append(inter_vars_filter[0][1]) + fused_group.append(inter_vars_filter[0][0]) + probe_res = inter_vars_filter[0][1] + group_speed.append(probe_res['speed'] if probe_res else 0) + + # reorder according to the probed speed results in group_speed + # 'greedy': all speed data in group_speed will be 0, which will keep the + # current order of fused group + # 'probe': OPs in fused group will be reordered according to the speed data + # in group_speed in descending order + fused_group = [ + op for op, _ in sorted( + zip(fused_group, group_speed), key=lambda it: it[1], reverse=True) + ] - return fused_group_def, fused_group + return fused_group class FusedFilter(Filter): """A fused operator for filters.""" - def __init__(self, fused_filters: List): + _batched_op = True + + def __init__(self, name: str, fused_filters: List): """ Initialization method. :param fused_filters: a list of filters to be fused. """ super().__init__() + self._name = name self.fused_filters = fused_filters # set accelerator to 'cuda' if there exists any ops whose accelerator # is 'cuda' @@ -144,30 +166,40 @@ def __init__(self, fused_filters: List): if 'cuda' in accelerator_methods: self.accelerator = 'cuda' - def compute_stats_single(self, sample, rank=None): + # update num_proc with the min num_proc of all fusible filters + self.num_proc = min([op.runtime_np() for op in self.fused_filters]) + + def compute_stats_batched(self, samples, rank=None): import av # context for the intermediate vars - sample[Fields.context] = {} + num_samples = len(samples[Fields.stats]) + samples[Fields.context] = [{} for _ in range(num_samples)] for op in self.fused_filters: # open the context for these fused ops if op.accelerator == 'cuda': - sample = op.compute_stats(sample, rank=rank, context=True) + samples = op.compute_stats_batched(samples, + rank=rank, + context=True) else: - sample = op.compute_stats(sample, context=True) + samples = op.compute_stats_batched(samples, context=True) # clean up the contexts after processing # check if there are containers that need to be closed - for context_key in sample[Fields.context]: - if isinstance(sample[Fields.context][context_key], - av.container.InputContainer): - sample[Fields.context][context_key].streams.video[0].close() - sample[Fields.context][context_key].close() - _ = sample.pop(Fields.context) - return sample - - def process_single(self, sample): + for ctx in samples[Fields.context]: + for context_key in ctx: + if isinstance(ctx[context_key], av.container.InputContainer): + ctx[context_key].streams.video[0].close() + ctx[context_key].close() + _ = samples.pop(Fields.context) + return samples + + def process_batched(self, samples): # Only return True when all filters return True + res = None for op in self.fused_filters: - if not op.process(sample): - return False - return True + this_res = np.array(list(op.process_batched(samples))) + if res is not None: + res = np.logical_and(res, this_res) + else: + res = this_res + return res diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py index 96a54b437..5ea9091b0 100644 --- a/data_juicer/utils/auto_install_mapping.py +++ b/data_juicer/utils/auto_install_mapping.py @@ -10,99 +10,90 @@ 'simhash': ['simhash-pybind'], } -# Packages to corresponding ops that require them -PKG_TO_OPS = { - 'torch': [ - 'image_aesthetics_filter', 'image_nsfw_filter', - 'image_text_matching_filter', 'image_text_similarity_filter', - 'image_watermark_filter', 'phrase_grounding_recall_filter', - 'video_aesthetics_filter', 'video_frames_text_similarity_filter', - 'video_nsfw_filter', 'video_tagging_from_frames_filter', - 'video_watermark_filter', 'generate_qa_from_text_mapper', - 'generate_qa_from_examples_mapper', 'image_captioning_mapper', - 'image_diffusion_mapper', 'image_tagging_mapper', - 'optimize_query_mapper', 'optimize_response_mapper', - 'optimize_qa_mapper', 'video_captioning_from_frames_mapper', - 'video_captioning_from_summarizer_mapper', - 'video_captioning_from_video_mapper', - 'video_tagging_from_audio_mapper', 'video_tagging_from_frames_mapper' +# Extra packages required by each op +OPS_TO_PKG = { + 'video_aesthetics_filter': + ['simple-aesthetics-predictor', 'torch', 'transformers'], + 'document_simhash_deduplicator': ['simhash-pybind'], + 'nlpcda_zh_mapper': ['nlpcda'], + 'image_aesthetics_filter': + ['simple-aesthetics-predictor', 'torch', 'transformers'], + 'video_nsfw_filter': ['torch', 'transformers'], + 'video_face_blur_mapper': ['opencv-python'], + 'stopwords_filter': ['sentencepiece'], + 'fix_unicode_mapper': ['ftfy'], + 'token_num_filter': ['transformers'], + 'optimize_qa_mapper': ['torch', 'transformers', 'vllm'], + 'video_motion_score_filter': ['opencv-python'], + 'image_tagging_mapper': ['ram', 'torch'], + 'video_resize_aspect_ratio_mapper': ['ffmpeg-python'], + 'video_captioning_from_audio_mapper': [ + 'accelerate', 'einops', 'tiktoken', 'transformers', + 'transformers_stream_generator' ], - 'torchaudio': [ - 'video_captioning_from_summarizer_mapper', - 'video_tagging_from_audio_mapper' + 'clean_html_mapper': ['selectolax'], + 'video_tagging_from_audio_mapper': ['torch', 'torchaudio', 'transformers'], + 'image_deduplicator': ['imagededup'], + 'image_diffusion_mapper': + ['diffusers', 'simhash-pybind', 'torch', 'transformers'], + 'image_text_similarity_filter': ['torch', 'transformers'], + 'alphanumeric_filter': ['transformers'], + 'image_nsfw_filter': ['torch', 'transformers'], + 'image_watermark_filter': ['torch', 'transformers'], + 'ray_image_deduplicator': ['imagededup'], + 'video_captioning_from_frames_mapper': + ['simhash-pybind', 'torch', 'transformers'], + 'video_tagging_from_frames_filter': ['torch'], + 'video_resize_resolution_mapper': ['ffmpeg-python'], + 'optimize_query_mapper': ['torch', 'transformers', 'vllm'], + 'sentence_split_mapper': ['nltk'], + 'image_text_matching_filter': ['torch', 'transformers'], + 'phrase_grounding_recall_filter': ['nltk', 'torch', 'transformers'], + 'video_split_by_scene_mapper': ['scenedetect[opencv]'], + 'image_face_blur_mapper': ['opencv-python'], + 'image_face_ratio_filter': ['opencv-python'], + 'document_minhash_deduplicator': ['scipy'], + 'flagged_words_filter': ['sentencepiece'], + 'language_id_score_filter': ['fasttext-wheel'], + 'words_num_filter': ['sentencepiece'], + 'chinese_convert_mapper': ['opencc'], + 'video_frames_text_similarity_filter': ['torch', 'transformers'], + 'generate_qa_from_text_mapper': ['torch', 'transformers', 'vllm'], + 'video_ffmpeg_wrapped_mapper': ['ffmpeg-python'], + 'image_captioning_mapper': ['simhash-pybind', 'torch', 'transformers'], + 'video_ocr_area_ratio_filter': ['easyocr'], + 'video_captioning_from_video_mapper': + ['simhash-pybind', 'torch', 'transformers'], + 'video_remove_watermark_mapper': ['opencv-python'], + 'text_action_filter': ['spacy-pkuseg'], + 'nlpaug_en_mapper': ['nlpaug'], + 'word_repetition_filter': ['sentencepiece'], + 'video_watermark_filter': ['torch'], + 'video_captioning_from_summarizer_mapper': [ + 'accelerate', 'einops', 'simhash-pybind', 'tiktoken', 'torch', + 'torchaudio', 'transformers', 'transformers_stream_generator' ], - 'easyocr': ['video_ocr_area_ratio_filter'], - 'fasttext-wheel': ['language_id_score_filter'], - 'kenlm': ['perplexity_filter'], - 'sentencepiece': [ - 'flagged_words_filter', 'perplexity_filter', 'stopwords_filter', - 'word_repetition_filter', 'words_num_filter' - ], - 'scipy': ['document_minhash_deduplicator'], - 'ftfy': ['fix_unicode_mapper'], - 'simhash-pybind': [ - 'document_simhash_deduplicator', 'image_captioning_mapper', - 'image_diffusion_mapper', 'video_captioning_from_frames_mapper', - 'video_captioning_from_summarizer_mapper', - 'video_captioning_from_video_mapper' - ], - 'selectolax': ['clean_html_mapper'], - 'nlpaug': ['nlpaug_en_mapper'], - 'nlpcda': ['nlpcda'], - 'nltk': ['phrase_grounding_recall_filter', 'sentence_split_mapper'], - 'transformers': [ - 'alphanumeric_filter', 'image_aesthetics_filter', 'image_nsfw_filter', - 'image_text_matching_filter', 'image_text_similarity_filter', - 'image_watermark_filter', 'phrase_grounding_recall_filter', - 'token_num_filter', 'video_aesthetics_filter', - 'video_frames_text_similarity_filter', 'video_nsfw_filter', - 'generate_qa_from_text_mapper', 'generate_qa_from_examples_mapper', - 'image_captioning_mapper', 'image_diffusion_mapper', - 'optimize_query_mapper', 'optimize_response_mapper', - 'optimize_qa_mapper', 'video_captioning_from_audio_mapper', - 'video_captioning_from_frames_mapper', - 'video_captioning_from_summarizer_mapper', - 'video_captioning_from_video_mapper', 'video_tagging_from_audio_mapper' - ], - 'transformers_stream_generator': [ - 'video_captioning_from_audio_mapper', - 'video_captioning_from_summarizer_mapper' - ], - 'einops': [ - 'video_captioning_from_audio_mapper', - 'video_captioning_from_summarizer_mapper' - ], - 'accelerate': [ - 'video_captioning_from_audio_mapper', - 'video_captioning_from_summarizer_mapper' - ], - 'tiktoken': [ - 'video_captioning_from_audio_mapper', - 'video_captioning_from_summarizer_mapper' - ], - 'opencc': ['chinese_convert_mapper'], - 'imagededup': ['image_deduplicator', 'ray_image_deduplicator'], - 'spacy-pkuseg': ['text_action_filter', 'text_entity_dependency_filter'], - 'diffusers': ['image_diffusion_mapper'], - 'simple-aesthetics-predictor': - ['image_aesthetics_filter', 'video_aesthetics_filter'], - 'scenedetect[opencv]': ['video_split_by_scene_mapper'], - 'ffmpeg-python': [ - 'audio_ffmpeg_wrapped_mapper', 'video_ffmpeg_wrapped_mapper', - 'video_resize_aspect_ratio_mapper', 'video_resize_resolution_mapper' - ], - 'opencv-python': [ - 'image_face_ratio_filter', 'video_motion_score_filter', - 'image_face_blur_mapper', 'video_face_blur_mapper', - 'video_remove_watermark_mapper' - ], - 'vllm': [ - 'generate_qa_from_text_mapper', - 'generate_qa_from_examples_mapper', - 'optimize_query_mapper', - 'optimize_response_mapper', - 'optimize_qa_mapper', - ], - 'rouge': ['generate_qa_from_examples_mapper'], - 'ram': ['image_tagging_mapper', 'video_tagging_from_frames_mapper'] + 'audio_ffmpeg_wrapped_mapper': ['ffmpeg-python'], + 'perplexity_filter': ['kenlm', 'sentencepiece'], + 'generate_qa_from_examples_mapper': + ['rouge', 'torch', 'transformers', 'vllm'], + 'video_tagging_from_frames_mapper': ['ram', 'torch'], + 'text_entity_dependency_filter': ['spacy-pkuseg'], + 'optimize_response_mapper': ['torch', 'transformers', 'vllm'], + 'text_chunk_mapper': ['transformers', 'dashscope', 'openai'], + 'entity_attribute_aggregator': ['transformers', 'dashscope', 'openai'], + 'most_relavant_entities_aggregator': + ['transformers', 'dashscope', 'openai'], + 'nested_aggregator': ['transformers', 'dashscope', 'openai'], + 'calibrate_qa_mapper': ['openai'], + 'calibrate_query_mapper': ['openai'], + 'calibrate_response_mapper': ['openai'], + 'extract_entity_attribute_mapper': ['openai'], + 'extract_entity_relation_mapper': ['openai'], + 'extract_event_mapper': ['openai'], + 'extract_keyword_mapper': ['openai'], + 'extract_nickname_mapper': ['openai'], + 'extract_support_text_mapper': ['openai'], + 'pair_preference_mapper': ['openai'], + 'relation_identity_mapper': ['openai'], } diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py index 959831c5d..bd649bb96 100644 --- a/data_juicer/utils/common_utils.py +++ b/data_juicer/utils/common_utils.py @@ -1,6 +1,8 @@ +import hashlib import sys import numpy as np +from loguru import logger def stats_to_number(s, reverse=True): @@ -21,6 +23,122 @@ def stats_to_number(s, reverse=True): return sys.maxsize +def dict_to_hash(input_dict: dict, hash_length=None): + """ + hash a dict to a string with length hash_length + + :param input_dict: the given dict + """ + sorted_items = sorted(input_dict.items()) + dict_string = str(sorted_items).encode() + hasher = hashlib.sha256() + hasher.update(dict_string) + hash_value = hasher.hexdigest() + if hash_length: + hash_value = hash_value[:hash_length] + return hash_value + + +def nested_access(data, path, digit_allowed=True): + """ + Access nested data using a dot-separated path. + + :param data: A dictionary or a list to access the nested data from. + :param path: A dot-separated string representing the path to access. + This can include numeric indices when accessing list + elements. + :param digit_allowed: Allow transfering string to digit. + :return: The value located at the specified path, or raises a KeyError + or IndexError if the path does not exist. + """ + keys = path.split('.') + for key in keys: + # Convert string keys to integers if they are numeric + key = int(key) if key.isdigit() and digit_allowed else key + try: + data = data[key] + except Exception: + logger.warning(f'Unaccessible dot-separated path: {path}!') + return None + return data + + +def nested_set(data: dict, path: str, val): + """ + Set the val to the nested data in the dot-separated path. + + :param data: A dictionary with nested format. + :param path: A dot-separated string representing the path to set. + This can include numeric indices when setting list + elements. + :return: The nested data after the val set. + """ + keys = path.split('.') + cur = data + for key in keys[:-1]: + if key not in cur: + cur[key] = {} + cur = cur[key] + cur[keys[-1]] = val + return data + + +def is_string_list(var): + """ + return if the var is list of string. + + :param var: input variance + """ + return isinstance(var, list) and all(isinstance(it, str) for it in var) + + +def avg_split_string_list_under_limit(str_list: list, + token_nums: list, + max_token_num=None): + """ + Split the string list to several sub str_list, such that the total + token num of each sub string list is less than max_token_num, keeping + the total token nums of sub string lists are similar. + + :param str_list: input string list. + :param token_nums: token num of each string list. + :param max_token_num: max token num of each sub string list. + """ + if max_token_num is None: + return [str_list] + + if len(str_list) != len(token_nums): + logger.warning('The length of str_list and token_nums must be equal!') + return [str_list] + + total_num = sum(token_nums) + if total_num <= max_token_num: + return [str_list] + + group_num = total_num // max_token_num + 1 + avg_num = total_num / group_num + res = [] + cur_list = [] + cur_sum = 0 + for text, token_num in zip(str_list, token_nums): + if token_num > max_token_num: + logger.warning( + 'Token num is greater than max_token_num in one sample!') + if cur_sum + token_num > max_token_num and cur_list: + res.append(cur_list) + cur_list = [] + cur_sum = 0 + cur_list.append(text) + cur_sum += token_num + if cur_sum > avg_num: + res.append(cur_list) + cur_list = [] + cur_sum = 0 + if cur_list: + res.append(cur_list) + return res + + def is_float(s): try: float(s) diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 1fe8d7002..ccfcf599c 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -16,6 +16,7 @@ class Fields(object): context = DEFAULT_PREFIX + 'context__' suffix = DEFAULT_PREFIX + 'suffix__' + video_frames = DEFAULT_PREFIX + 'video_frames__' # video_frame_tags video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__' video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__' @@ -32,14 +33,14 @@ class Fields(object): event_description = DEFAULT_PREFIX + 'event_description__' # # a list of characters relevant to the event relevant_characters = DEFAULT_PREFIX + 'relevant_characters__' - # # the given main entity for attribute extraction - main_entity = DEFAULT_PREFIX + 'main_entity__' - # # the given attribute to be extracted - attribute = DEFAULT_PREFIX + 'attribute__' - # # the extracted attribute description - attribute_description = DEFAULT_PREFIX + 'attribute_description__' - # # extract from raw data for support the attribute - attribute_support_text = DEFAULT_PREFIX + 'attribute_support_text__' + # # the given main entities for attribute extraction + main_entities = DEFAULT_PREFIX + 'main_entities__' + # # the given attributes to be extracted + attributes = DEFAULT_PREFIX + 'attributes__' + # # the extracted attribute descriptions + attribute_descriptions = DEFAULT_PREFIX + 'attribute_descriptions__' + # # extract from raw datas for support the attribute + attribute_support_texts = DEFAULT_PREFIX + 'attribute_support_texts__' # # the nickname relationship nickname = DEFAULT_PREFIX + 'nickname__' # # the entity for knowledge graph @@ -64,6 +65,8 @@ class Fields(object): relation_strength = DEFAULT_PREFIX + 'relation_strength__' # # the keyword in a text keyword = DEFAULT_PREFIX + 'keyword__' + # # support text + support_text = DEFAULT_PREFIX + 'support_text__' class StatsKeysMeta(type): diff --git a/data_juicer/utils/file_utils.py b/data_juicer/utils/file_utils.py index e2fc241cd..7a8618660 100644 --- a/data_juicer/utils/file_utils.py +++ b/data_juicer/utils/file_utils.py @@ -1,6 +1,5 @@ import asyncio import copy -import hashlib import os import re import shutil @@ -10,6 +9,7 @@ from datasets.utils.extract import ZstdExtractor as Extractor +from data_juicer.utils.common_utils import dict_to_hash from data_juicer.utils.constant import DEFAULT_PREFIX, Fields @@ -127,22 +127,6 @@ def add_suffix_to_filename(filename, suffix): return new_name -def dict_to_hash(input_dict, hash_length=None): - """ - hash a dict to a string with length hash_length - - :param input_dict: the given dict - """ - sorted_items = sorted(input_dict.items()) - dict_string = str(sorted_items).encode() - hasher = hashlib.sha256() - hasher.update(dict_string) - hash_value = hasher.hexdigest() - if hash_length: - hash_value = hash_value[:hash_length] - return hash_value - - def create_directory_if_not_exists(directory_path): """ create a directory if not exists, this function is process safe diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py index a1c09668d..49e5046ab 100644 --- a/data_juicer/utils/mm_utils.py +++ b/data_juicer/utils/mm_utils.py @@ -1,5 +1,6 @@ import base64 import datetime +import io import os import re import shutil @@ -321,14 +322,17 @@ def cut_video_by_seconds( container = input_video # create the output video - output_container = load_video(output_video, 'w') + if output_video: + output_container = load_video(output_video, 'w') + else: + output_buffer = io.BytesIO() + output_container = av.open(output_buffer, mode='w', format='mp4') # add the video stream into the output video according to input video input_video_stream = container.streams.video[0] codec_name = input_video_stream.codec_context.name fps = input_video_stream.base_rate - output_video_stream = output_container.add_stream(codec_name, - rate=str(fps)) + output_video_stream = output_container.add_stream(codec_name, rate=fps) output_video_stream.width = input_video_stream.codec_context.width output_video_stream.height = input_video_stream.codec_context.height output_video_stream.pix_fmt = input_video_stream.codec_context.pix_fmt @@ -391,6 +395,11 @@ def cut_video_by_seconds( if isinstance(input_video, str): close_video(container) close_video(output_container) + + if not output_video: + output_buffer.seek(0) + return output_buffer + if not os.path.exists(output_video): logger.warning(f'This video could not be successfully cut in ' f'[{start_seconds}, {end_seconds}] seconds. ' @@ -431,8 +440,7 @@ def process_each_frame(input_video: Union[str, av.container.InputContainer], codec_name = input_video_stream.codec_context.name fps = input_video_stream.base_rate - output_video_stream = output_container.add_stream(codec_name, - rate=str(fps)) + output_video_stream = output_container.add_stream(codec_name, rate=fps) output_video_stream.pix_fmt = input_video_stream.codec_context.pix_fmt output_video_stream.width = input_video_stream.codec_context.width output_video_stream.height = input_video_stream.codec_context.height @@ -465,6 +473,39 @@ def process_each_frame(input_video: Union[str, av.container.InputContainer], if isinstance(input_video, str) else input_video.name) +def extract_key_frames_by_seconds( + input_video: Union[str, av.container.InputContainer], + duration: float = 1): + """Extract key frames by seconds. + :param input_video: input video path or av.container.InputContainer. + :param duration: duration of each video split in seconds. + """ + # load the input video + if isinstance(input_video, str): + container = load_video(input_video) + elif isinstance(input_video, av.container.InputContainer): + container = input_video + else: + raise ValueError(f'Unsupported type of input_video. Should be one of ' + f'[str, av.container.InputContainer], but given ' + f'[{type(input_video)}].') + + video_duration = get_video_duration(container) + timestamps = np.arange(0, video_duration, duration).tolist() + + all_key_frames = [] + for i in range(1, len(timestamps)): + output_buffer = cut_video_by_seconds(container, None, + timestamps[i - 1], timestamps[i]) + if output_buffer: + cut_inp_container = av.open(output_buffer, format='mp4', mode='r') + key_frames = extract_key_frames(cut_inp_container) + all_key_frames.extend(key_frames) + close_video(cut_inp_container) + + return all_key_frames + + def extract_key_frames(input_video: Union[str, av.container.InputContainer]): """ Extract key frames from the input video. If there is no keyframes in the @@ -518,6 +559,43 @@ def get_key_frame_seconds(input_video: Union[str, return ts +def extract_video_frames_uniformly_by_seconds( + input_video: Union[str, av.container.InputContainer], + frame_num: PositiveInt, + duration: float = 1): + """Extract video frames uniformly by seconds. + :param input_video: input video path or av.container.InputContainer. + :param frame_num: the number of frames to be extracted uniformly from + each video split by duration. + :param duration: duration of each video split in seconds. + """ + # load the input video + if isinstance(input_video, str): + container = load_video(input_video) + elif isinstance(input_video, av.container.InputContainer): + container = input_video + else: + raise ValueError(f'Unsupported type of input_video. Should be one of ' + f'[str, av.container.InputContainer], but given ' + f'[{type(input_video)}].') + + video_duration = get_video_duration(container) + timestamps = np.arange(0, video_duration, duration).tolist() + + all_frames = [] + for i in range(1, len(timestamps)): + output_buffer = cut_video_by_seconds(container, None, + timestamps[i - 1], timestamps[i]) + if output_buffer: + cut_inp_container = av.open(output_buffer, format='mp4', mode='r') + key_frames = extract_video_frames_uniformly(cut_inp_container, + frame_num=frame_num) + all_frames.extend(key_frames) + close_video(cut_inp_container) + + return all_frames + + def extract_video_frames_uniformly( input_video: Union[str, av.container.InputContainer], frame_num: PositiveInt, diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index eb521e619..94b4440eb 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -11,6 +11,7 @@ from loguru import logger from data_juicer import cuda_device_count +from data_juicer.utils.common_utils import nested_access from data_juicer.utils.lazy_loader import AUTOINSTALL, LazyLoader from .cache_utils import DATA_JUICER_MODELS_CACHE as DJMC @@ -51,6 +52,11 @@ 'punkt.*.pickle': 'https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/' 'data_juicer/models/', + + # ram + 'ram_plus_swin_large_14m.pth': + 'http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/models/' + 'ram_plus_swin_large_14m.pth', } @@ -162,30 +168,11 @@ def __call__(self, messages, **kwargs): stream=stream, stream_cls=stream_cls) result = response.json() - return self._nested_access(result, self.response_path) + return nested_access(result, self.response_path) except Exception as e: logger.exception(e) return '' - @staticmethod - def _nested_access(data, path): - """ - Access nested data using a dot-separated path. - - :param data: A dictionary or a list to access the nested data from. - :param path: A dot-separated string representing the path to access. - This can include numeric indices when accessing list - elements. - :return: The value located at the specified path, or raises a KeyError - or IndexError if the path does not exist. - """ - keys = path.split('.') - for key in keys: - # Convert string keys to integers if they are numeric - key = int(key) if key.isdigit() else key - data = data[key] - return data - @staticmethod def _filter_arguments(func, args_dict): """ @@ -216,19 +203,18 @@ def prepare_api_model(model, return_processor=False, processor_config=None, **model_params): - """ - Creates an instance of the APIModel for interacting with OpenAI-like APIs. + """Creates a callable API model for interacting with OpenAI-compatible API. + The callable supports custom response parsing and works with proxy servers + that may be incompatible. - :param model: The name of the model to be used for making API calls. + :param model: The name of the model to interact with. :param endpoint: The URL endpoint for the API. If provided as a relative path, it will be appended to the base URL (defined by the `OPENAI_BASE_URL` environment variable or through an additional `base_url` parameter). By default, it is set to '/chat/completions' for OpenAI compatibility. - :param response_path: A dot-separated string specifying the path to - extract desired content from the API response. The default value is - 'choices.0.message.content', which corresponds to the typical - structure of an OpenAI API response. + :param response_path: The dot-separated path to extract desired content + from the API response. Defaults to 'choices.0.message.content'. :param return_processor: A boolean flag indicating whether to return a processor along with the model. The processor can be used for tasks like tokenization or encoding. Defaults to False. @@ -274,8 +260,8 @@ def get_processor(): "- For custom models: Use the 'processor_config' parameter to configure a Hugging Face processor." # noqa: E501 ) - if processor_config is not None \ - and 'pretrained_model_name_or_path' in processor_config: + if processor_config is not None and \ + 'pretrained_model_name_or_path' in processor_config: processor = transformers.AutoProcessor.from_pretrained( **processor_config) else: diff --git a/demos/role_playing_system_prompt/README_ZH.md b/demos/role_playing_system_prompt/README_ZH.md new file mode 100644 index 000000000..956c335bb --- /dev/null +++ b/demos/role_playing_system_prompt/README_ZH.md @@ -0,0 +1,49 @@ +# 为LLM构造角色扮演的system prompt + +在该Demo中,我们展示了如何通过Data-Juicer的菜谱,生成让LLM扮演剧本中给定角色的system prompt。我们这里以《莲花楼》为例。 + +## 数据准备 +将《莲花楼》按章节划分,按顺序每个章节对应Data-Juicer的一个sample,放到“text”关键字下。如下json格式: +```json +[ + {'text': '第一章内容'}, + {'text': '第二章内容'}, + {'text': '第三章内容'}, + ... +] +``` + +## 执行 +```shell +python tools/process_data.py --config ./demos/role_playing_system_prompt/role_playing_system_prompt_test.yaml +``` + +## 生成样例 + +```text +扮演李莲花与用户进行对话。 +# 角色身份 +原名李相夷,曾是武林盟主,创立四顾门。十年前因中碧茶之毒,隐姓埋名,成为莲花楼的老板,过着市井生活。 +# 角色经历 +李莲花原名李相夷,十五岁战胜西域天魔,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。在与金鸳盟盟主笛飞声的对决中,李相夷中毒重伤,沉入大海,十年后在莲花楼醒来,过起了市井生活。他帮助肉铺掌柜解决家庭矛盾,表现出敏锐的洞察力。李莲花与方多病合作,解决了灵山派掌门王青山的假死案,揭露了朴管家的罪行。随后,他与方多病和笛飞声一起调查了玉秋霜的死亡案,最终揭露了玉红烛的阴谋。在朴锄山,李莲花和方多病调查了七具无头尸事件,发现男童的真实身份是笛飞声。李莲花利用飞猿爪偷走男童手中的观音垂泪,导致笛飞声恢复内力,但李莲花巧妙逃脱。李莲花与方多病继续合作,调查了少师剑被盗案,揭露了静仁和尚的阴谋。在采莲庄,他解决了新娘溺水案,找到了狮魂的线索,并在南门园圃挖出单孤刀的药棺。在玉楼春的案件中,李莲花和方多病揭露了玉楼春的阴谋,救出了被拐的清儿。在石寿村,他们发现了柔肠玉酿的秘密,并救出了被控制的武林高手。李莲花与方多病在白水园设下机关,救出方多病的母亲何晓惠,并最终在云隐山找到了治疗碧茶之毒的方法。在天机山庄,他揭露了单孤刀的野心,救出了被控制的大臣。在皇宫,李莲花与方多病揭露了魔僧和单孤刀的阴谋,成功解救了皇帝。最终,李莲花在东海之滨与笛飞声的决斗中未出现,留下一封信,表示自己已无法赴约。一年后,方多病在东海畔的柯厝村找到了李莲花,此时的李莲花双目失明,右手残废,但心态平和,过着简单的生活。 +# 角色性格 +李莲花是一个机智、幽默、善于观察和推理的人物。他表面上看似随和、悠闲,甚至有些懒散,但实际上心思缜密,洞察力极强。他不仅具备敏锐的观察力和独特的思维方式,还拥有深厚的内功和高超的医术。他对朋友忠诚,愿意为了保护他们不惜一切代价,同时在面对敌人时毫不手软。尽管内心充满正义感和责任感,但他选择远离江湖纷争,追求宁静自在的生活。他对过去的自己(李相夷)有着深刻的反思,对乔婉娩的感情复杂,既有愧疚也有关怀。李莲花能够在复杂的环境中保持冷静,巧妙地利用智慧和技能解决问题,展现出非凡的勇气和决心。 +# 角色能力 +李莲花是一位智慧与武艺兼备的高手,拥有深厚的内力、高超的医术和敏锐的洞察力。他擅长使用轻功、剑术和特殊武器,如婆娑步和少师剑,能够在关键时刻化解危机。尽管身体状况不佳,他仍能通过内功恢复体力,运用智谋和技巧应对各种挑战。他在江湖中身份多变,既能以游医身份逍遥自在,也能以李相夷的身份化解武林危机。 +# 人际关系 +方多病 (称呼:方小宝、方大少爷)李莲花的徒弟。百川院刑探,单孤刀之子,李相夷的徒弟。方多病通过百川院的考核,成为刑探,并在百川院内展示了自己是李相夷的弟子,获得暂时的录用。他接到任务前往嘉州调查金鸳盟的余孽,期间与李莲花相识并合作破案。方多病在调查过程中逐渐了解到自己的身世,发现自己的生父是单孤刀。他与李莲花、笛飞声等人多次合作,共同对抗金鸳盟和单孤刀的阴谋。方多病在一系列案件中展现了出色的推理能力和武艺,逐渐成长为一名优秀的刑探。最终,方多病在天机山庄和皇宫的斗争中发挥了关键作用,帮助李莲花等人挫败了单孤刀的野心。在李莲花中毒后,方多病决心为他寻找解毒之法,展现了深厚的友情。 +笛飞声 (称呼:阿飞、笛大盟主)金鸳盟盟主,曾与李相夷激战并重伤李相夷,后因中毒失去内力,与李莲花有复杂恩怨。笛飞声是金鸳盟盟主,十年前因与李相夷一战成名。他利用单孤刀的弟子朴锄山引诱李相夷,最终重伤李相夷,但自己也被李相夷钉在桅杆上。十年后,笛飞声恢复内力,重新执掌金鸳盟,与角丽谯合作,试图利用罗摩天冰和业火痋控制武林。在与李莲花和方多病的多次交手中,笛飞声多次展现强大实力,但也多次被李莲花等人挫败。最终,笛飞声在与李莲花的对决中被制住,但并未被杀死。笛飞声与李莲花约定在东海再战,但李莲花因中毒未赴约。笛飞声在东海之战中并未出现,留下了许多未解之谜。 +乔婉娩 (称呼:乔姑娘)李莲花的前女友。四顾门前任门主李相夷的爱人,现任门主肖紫衿的妻子,江湖中知名侠女。乔婉娩是四顾门的重要人物,与李相夷有着复杂的情感纠葛。在李相夷失踪后,乔婉娩嫁给了肖紫衿,但内心始终未能忘记李相夷。在李莲花(即李相夷)重新出现后,乔婉娩通过种种线索确认了他的身份,但最终选择支持肖紫衿,维护四顾门的稳定。乔婉娩在四顾门的复兴过程中发挥了重要作用,尤其是在调查金鸳盟和南胤阴谋的过程中,她提供了关键的情报和支持。尽管内心充满矛盾,乔婉娩最终决定与肖紫衿共同面对江湖的挑战,展现了她的坚强和智慧。 +肖紫衿 (称呼:紫衿)李莲花的门主兼旧识。四顾门现任门主,曾与李相夷有深厚恩怨,后与乔婉娩成婚。肖紫衿是四顾门的重要人物,与李相夷和乔婉娩关系密切。他曾在李相夷的衣冠冢前与李莲花对峙,质问他为何归来,并坚持要与李莲花决斗。尽管李莲花展示了武功,但肖紫衿最终选择不与他继续争斗。肖紫衿在乔婉娩与李相夷的误会中扮演了关键角色,一度因嫉妒取消了与乔婉娩的婚事。后来,肖紫衿在乔婉娩的支持下担任四顾门的新门主,致力于复兴四顾门。在与单孤刀的对抗中,肖紫衿展现了坚定的决心和领导能力,最终带领四顾门取得了胜利。 +单孤刀 (称呼:师兄)李莲花的师兄兼敌人。单孤刀,李莲花的师兄,四顾门创始人之一,因不满李相夷与金鸳盟签订协定而独自行动,最终被金鸳盟杀害。单孤刀是李莲花的师兄,与李相夷一同创立四顾门。单孤刀性格争强好胜,难以容人,最终因不满李相夷与金鸳盟签订协定,决定独自行动。单孤刀被金鸳盟杀害,李相夷得知后悲愤交加,誓言与金鸳盟不死不休。单孤刀的死成为李相夷心中的一大阴影,多年后李莲花在调查中发现单孤刀并非真正死亡,而是诈死以实现自己的野心。最终,单孤刀在与李莲花和方多病的对决中失败,被轩辕箫的侍卫杀死。 +# 语言风格 +李莲花的语言风格幽默诙谐,充满智慧和机智,善于用轻松的语气化解紧张的气氛。他常用比喻、反讽和夸张来表达复杂的观点,同时在关键时刻能简洁明了地揭示真相。他的言语中带有调侃和自嘲,但又不失真诚和温情,展现出一种从容不迫的态度。无论是面对朋友还是敌人,李莲花都能以幽默和智慧赢得尊重。 +供参考语言风格的部分李莲花台词: +李莲花:你问我干吗?该启程了啊。 +李莲花:说起师门,你怎么也算云隐山一份子啊?不如趁今日叩拜了你师祖婆婆,再正儿八经给我这个师父磕头敬了茶,往后我守山中、你也尽心在跟前罢? +李莲花:恭贺肖大侠和乔姑娘,喜结连理。 +李莲花淡淡一笑:放心吧,该看到的,都看到了。 +李莲花:如果现在去百川院,你家旺福就白死了。 +``` + + diff --git a/demos/role_playing_system_prompt/role_playing_system_prompt.yaml b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml new file mode 100644 index 000000000..eadac45da --- /dev/null +++ b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml @@ -0,0 +1,57 @@ +# global parameters +project_name: 'role-play-demo-process' +dataset_path: 'path_to_the_lianhualou_novel_json_file' +np: 1 # number of subprocess to process your dataset + +export_path: 'path_to_output_jsonl_file' + +# process schedule +process: +# # chunk the novel if necessary +# - text_chunk_mapper: +# max_len: 8000 +# split_pattern: '\n\n' +# overlap_len: 400 +# tokenizer: 'qwen2.5-72b-instruct' +# trust_remote_code: True + # extract language_style, role_charactor and role_skill + - extract_entity_attribute_mapper: + api_model: 'qwen2.5-72b-instruct' + query_entities: ['李莲花'] + query_attributes: ["角色性格", "角色武艺和能力", "语言风格"] + # extract nickname + - extract_nickname_mapper: + api_model: 'qwen2.5-72b-instruct' + # extract events + - extract_event_mapper: + api_model: 'qwen2.5-72b-instruct' + index_key: 'chunk_id' # chunk_id for deduplicating attributes and nicknames + # group all events + - naive_grouper: + # role experiences summary from events + - entity_attribute_aggregator: + api_model: 'qwen2.5-72b-instruct' + entity: '李莲花' + attribute: '身份背景' + input_key: '__dj__event_description__' + output_key: '__dj__role_background__' + word_limit: 50 + - entity_attribute_aggregator: + api_model: 'qwen2.5-72b-instruct' + entity: '李莲花' + attribute: '主要经历' + input_key: '__dj__event_description__' + output_key: '__dj__role_experience__' + word_limit: 150 + # most relavant roles summary from events + - most_relavant_entities_aggregator: + api_model: 'qwen2.5-72b-instruct' + entity: '李莲花' + query_entity_type: '人物' + input_key: '__dj__event_description__' + output_key: '__dj__important_relavant_roles__' + # generate the system prompt + - python_file_mapper: + file_path: 'path_to_system_prompt_gereration_python_file' + function_name: 'get_system_prompt' + \ No newline at end of file diff --git a/demos/role_playing_system_prompt/system_prompt_generator.py b/demos/role_playing_system_prompt/system_prompt_generator.py new file mode 100644 index 000000000..dc2738900 --- /dev/null +++ b/demos/role_playing_system_prompt/system_prompt_generator.py @@ -0,0 +1,192 @@ +import random + +from itertools import chain +from loguru import logger +from collections import Counter + +from data_juicer.ops.aggregator import NestedAggregator +from data_juicer.ops.aggregator import EntityAttributeAggregator +from data_juicer.ops.mapper import RelationIdentityMapper +from data_juicer.utils.constant import Fields + +api_model = 'qwen2.5-72b-instruct' + +main_entity = "李莲花" +query_attributes = ["语言风格", "角色性格", "角色武艺和能力"] +system_prompt_key = '__dj__system_prompt__' +example_num_limit = 5 +max_relavant_roles_num = 5 + +role_info_template = "# {entity}\n## 身份背景\n{identity}\n## 人物经历\n{experience}" +relation_identity_text_template = """ +{source_entity}的信息: +{source_entity_info} +{target_entity}的信息: +{target_entity_info} +{source_entity}对{target_entity}的称呼:{nicknames} +""" + +nested_sum = NestedAggregator( + api_model=api_model, + try_num=3) + +def dedup_sort_val_by_chunk_id(sample, id_key, val_key): + chunk_ids = sample[id_key] + vals = sample[val_key] + id_to_val = {} + for id, val in zip(chunk_ids, vals): + id_to_val[id] = val + sorted_ids = list(id_to_val.keys()) + sorted_ids.sort() + sorted_vals = [id_to_val[id] for id in sorted_ids] + return list(chain(*sorted_vals)) + +def get_attributes(sample): + main_entities = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.main_entities) + attribute_names = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attributes) + attribute_descs = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attribute_descriptions) + attribute_support_texts = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attribute_support_texts) + attributes = {} + support_texts = {} + for attr in query_attributes: + attributes[attr] = [] + support_texts[attr] = [] + for entity, attr_name, attr_desc, sub_support_texts in \ + zip(main_entities, attribute_names, attribute_descs, attribute_support_texts): + if entity == main_entity and attr_name in query_attributes: + attributes[attr_name].append(attr_desc) + support_texts[attr_name].append(sub_support_texts) + return attributes, support_texts + +def get_nicknames(sample): + nicknames = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.nickname) + nickname_map = {} + for nr in nicknames: + if nr[Fields.source_entity] == main_entity: + role_name = nr[Fields.target_entity] + if role_name not in nickname_map: + nickname_map[role_name] = [] + nickname_map[role_name].append(nr[Fields.relation_description]) + + max_nums = 3 + for role_name, nickname_list in nickname_map.items(): + th = (len(nickname_list)+1) // 2 + count = Counter(nickname_list) + sorted_items = sorted(count.items(), key=lambda x: x[1], reverse=True) + most_common_nicknames = [] + idx = 0 + while th > 0 and idx < min(len(sorted_items), max_nums): + most_common_nicknames.append(sorted_items[idx][0]) + th -= sorted_items[idx][1] + idx += 1 + nickname_map[role_name] = most_common_nicknames + return nickname_map + + +def get_system_prompt(sample): + + main_role_identity = sample['__dj__role_background__'] + main_role_experience = sample['__dj__role_experience__'] + attributes, support_texts = get_attributes(sample) + main_role_character = nested_sum.recursive_summary(attributes['角色性格']) + main_role_skill = nested_sum.recursive_summary(attributes['角色武艺和能力']) + main_role_lang_style = nested_sum.recursive_summary(attributes['语言风格']) + lang_style_examples = list(chain(*support_texts['语言风格'])) + lang_style_example_num = min(example_num_limit, len(lang_style_examples)) + lang_style_examples = random.sample(lang_style_examples, lang_style_example_num) + + main_role_info = role_info_template.format( + entity=main_entity, + identity=main_role_identity, + experience=main_role_experience + ) + + nicknames = get_nicknames(sample) + + relation_detail = "" + relavant_roles = sample['__dj__important_relavant_roles__'] + for role_name in relavant_roles[:max_relavant_roles_num]: + if role_name == main_entity: + continue + + # get sub role identity + op = EntityAttributeAggregator( + api_model=api_model, + entity=role_name, + attribute='身份背景', + input_key='__dj__event_description__', + output_key='__dj__role_background__', + word_limit=30 + ) + sample = op.process_single(sample) + role_identity = sample['__dj__role_background__'].replace('\n', '') + + # get sub role experience + op = EntityAttributeAggregator( + api_model=api_model, + entity=role_name, + attribute='主要经历', + input_key='__dj__event_description__', + output_key='__dj__role_experience__', + word_limit=100 + ) + sample = op.process_single(sample) + role_experience = sample['__dj__role_experience__'].replace('\n', '') + + # get relation identity with main role + role_info = role_info_template.format( + entity=role_name, + identity=role_identity, + experience=role_experience + ) + op = RelationIdentityMapper( + api_model=api_model, + source_entity=main_entity, + target_entity=role_name, + output_key='__dj__relation_identity__' + ) + if role_name in nicknames: + cur_nicknames = '、'.join(nicknames[role_name]) + else: + cur_nicknames = role_name + text = relation_identity_text_template.format( + source_entity=main_entity, + source_entity_info=main_role_info, + target_entity=role_name, + target_entity_info=role_info, + nicknames = cur_nicknames + ) + tmp_sample = {'text': text} + tmp_sample = op.process_single(tmp_sample) + relation = tmp_sample['__dj__relation_identity__'] + + relation_detail += f"\n{role_name} (称呼:{cur_nicknames})" + if relation: + relation_detail += f"{main_entity}的{relation}。" + relation_detail += f"{role_identity}{role_experience}".replace('\n', '') + + full_system_prompt = f"""扮演{main_entity}与用户进行对话。\n""" + full_system_prompt += """# 角色身份\n""" + full_system_prompt += main_role_identity.replace('\n', '') + full_system_prompt += """\n# 角色经历\n""" + full_system_prompt += main_role_experience.replace('\n', '') + full_system_prompt += """\n# 角色性格\n""" + full_system_prompt += main_role_character.replace('\n', '') + full_system_prompt += """\n# 角色能力\n""" + full_system_prompt += main_role_skill.replace('\n', '') + + full_system_prompt += """\n# 人际关系""" + full_system_prompt += relation_detail + + full_system_prompt += """\n# 语言风格\n""" + full_system_prompt += main_role_lang_style.replace('\n', '') + full_system_prompt += f"""\n供参考语言风格的部分{main_entity}台词:\n""" + full_system_prompt += "\n````\n" + full_system_prompt += '\n'.join(lang_style_examples) + full_system_prompt += "\n````\n" + + logger.info(full_system_prompt) + + sample[system_prompt_key] = full_system_prompt + + return sample \ No newline at end of file diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md index e736b5ade..734f1201a 100644 --- a/docs/DeveloperGuide.md +++ b/docs/DeveloperGuide.md @@ -209,7 +209,9 @@ __all__ = [ ] ``` -4. Now you can use this new OP with custom arguments in your own config files! +4. When an operator has package dependencies listed in `environments/science_requires.txt`, you need to add the corresponding dependency packages to the `OPS_TO_PKG` dictionary in `data_juicer/utils/auto_install_mapping.py` to support dependency installation at the operator level. + +5. Now you can use this new OP with custom arguments in your own config files! ```yaml # other configs @@ -222,7 +224,7 @@ process: max_len: 1000 ``` -5. (Strongly Recommend) It's better to add corresponding tests for your own OPs. For `TextLengthFilter` above, you would like to add `test_text_length_filter.py` into `tests/ops/filter/` directory as below. +6. (Strongly Recommend) It's better to add corresponding tests for your own OPs. For `TextLengthFilter` above, you would like to add `test_text_length_filter.py` into `tests/ops/filter/` directory as below. ```python import unittest @@ -244,7 +246,7 @@ if __name__ == '__main__': unittest.main() ``` -6. (Strongly Recommend) In order to facilitate the use of other users, we also need to update this new OP information to +7. (Strongly Recommend) In order to facilitate the use of other users, we also need to update this new OP information to the corresponding documents, including the following docs: 1. `configs/config_all.yaml`: this complete config file contains a list of all OPs and their arguments, serving as an important document for users to refer to all available OPs. Therefore, after adding the new OP, we need to add it to the process diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md index e9d746d7c..fcc76aafe 100644 --- a/docs/DeveloperGuide_ZH.md +++ b/docs/DeveloperGuide_ZH.md @@ -202,7 +202,9 @@ __all__ = [ ] ``` -4. 全部完成!现在您可以在自己的配置文件中使用新添加的算子: +4. 算子有`environments/science_requires.txt`中列举的包依赖时,需要在`data_juicer/utils/auto_install_mapping.py`里的`OPS_TO_PKG`中添加对应的依赖包,以支持算子粒度的依赖安装。 + +5. 全部完成!现在您可以在自己的配置文件中使用新添加的算子: ```yaml # other configs @@ -215,7 +217,7 @@ process: max_len: 1000 ``` -5. (强烈推荐)最好为新添加的算子进行单元测试。对于上面的 `TextLengthFilter` 算子,建议在 `tests/ops/filter/` 中实现如 `test_text_length_filter.py` 的测试文件: +6. (强烈推荐)最好为新添加的算子进行单元测试。对于上面的 `TextLengthFilter` 算子,建议在 `tests/ops/filter/` 中实现如 `test_text_length_filter.py` 的测试文件: ```python import unittest @@ -238,7 +240,7 @@ if __name__ == '__main__': unittest.main() ``` -6. (强烈推荐)为了方便其他用户使用,我们还需要将新增的算子信息更新到相应的文档中,具体包括如下文档: +7. (强烈推荐)为了方便其他用户使用,我们还需要将新增的算子信息更新到相应的文档中,具体包括如下文档: 1. `configs/config_all.yaml`:该全集配置文件保存了所有算子及参数的一个列表,作为用户参考可用算子的一个重要文档。因此,在新增算子后,需要将其添加到该文档process列表里(按算子类型分组并按字母序排序): ```yaml diff --git a/docs/Operators.md b/docs/Operators.md index 0c8d708e6..0cb658af8 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,10 +11,12 @@ The operators in Data-Juicer are categorized into 5 types. | Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 58 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 63 | Edits and transforms samples | | [ Filter ]( #filter ) | 44 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | +| [ Grouper ]( #grouper ) | 2 | Group samples to batched samples | +| [ Aggregator ]( #aggregator ) | 3 | Aggregate for batched samples, such as summary or conclusion | All the specific operators are listed below, each featured with several capability tags. @@ -57,9 +59,9 @@ All the specific operators are listed below, each featured with several capabili | Operator | Tags | Description | Source code | Unit tests | |------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------|------------------------------------------------------------------------------------| | audio_ffmpeg_wrapped_mapper | ![Audio](https://img.shields.io/badge/Audio-0DA64F?style=plastic) | Simple wrapper to run a FFmpeg audio filter | [code](../data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py) | [tests](../tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py) | -| calibrate_qa_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Calibrate question-answer pairs based on reference text | [code](../data_juicer/ops/mapper/calibrate_qa_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_qa_mapper.py) | -| calibrate_query_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Calibrate query in question-answer pairs based on reference text | [code](../data_juicer/ops/mapper/calibrate_query_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_query_mapper.py) | -| calibrate_response_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Calibrate response in question-answer pairs based on reference text | [code](../data_juicer/ops/mapper/calibrate_response_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_response_mapper.py) | +| calibrate_qa_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Calibrate question-answer pairs based on reference text | [code](../data_juicer/ops/mapper/calibrate_qa_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_qa_mapper.py) | +| calibrate_query_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Calibrate query in question-answer pairs based on reference text | [code](../data_juicer/ops/mapper/calibrate_query_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_query_mapper.py) | +| calibrate_response_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Calibrate response in question-answer pairs based on reference text | [code](../data_juicer/ops/mapper/calibrate_response_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_response_mapper.py) | | chinese_convert_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Converts Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) | [code](../data_juicer/ops/mapper/chinese_convert_mapper.py) | [tests](../tests/ops/mapper/test_chinese_convert_mapper.py) | | clean_copyright_mapper | ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes copyright notice at the beginning of code files (must contain the word *copyright*) | [code](../data_juicer/ops/mapper/clean_copyright_mapper.py) | [tests](../tests/ops/mapper/test_clean_copyright_mapper.py) | | clean_email_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes email information | [code](../data_juicer/ops/mapper/clean_email_mapper.py) | [tests](../tests/ops/mapper/test_clean_email_mapper.py) | @@ -72,6 +74,7 @@ All the specific operators are listed below, each featured with several capabili | extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract events and relevant characters in the text. | [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) | | extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Generate keywords for the text. | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) | | extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract nickname relationship in the text. | [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) | +| extract_support_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract support sub text for a summary. | [code](../data_juicer/ops/mapper/extract_support_text_mapper.py) | [tests](../tests/ops/mapper/test_extract_support_text_mapper.py) | | fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) | | generate_qa_from_examples_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate question and answer pairs based on examples. | [code](../data_juicer/ops/mapper/generate_qa_from_examples_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_examples_mapper.py) | | generate_qa_from_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate question and answer pairs from text. | [code](../data_juicer/ops/mapper/generate_qa_from_text_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_text_mapper.py) | @@ -86,7 +89,11 @@ All the specific operators are listed below, each featured with several capabili | optimize_qa_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Optimize both the query and response in question-answering samples. | [code](../data_juicer/ops/mapper/optimize_qa_mapper.py) | [tests](../tests/ops/mapper/test_optimize_qa_mapper.py) | | optimize_query_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Optimize the query in question-answering samples. | [code](../data_juicer/ops/mapper/optimize_query_mapper.py) | [tests](../tests/ops/mapper/test_optimize_query_mapper.py) | | optimize_response_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Optimize the response in question-answering samples. | [code](../data_juicer/ops/mapper/optimize_response_mapper.py) | [tests](../tests/ops/mapper/test_optimize_response_mapper.py) | +| pair_preference_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Construct paired preference samples. | [code](../data_juicer/ops/mapper/pair_preference_mapper.py) | [tests](../tests/ops/mapper/test_pair_preference_mapper.py) | | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | +| python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python function defined in a file | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | +| python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python lambda function on data samples | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | +| relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Identify relation between two entity in the text. | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | | remove_header_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the running headers of TeX documents, e.g., titles, chapter or section numbers/names | [code](../data_juicer/ops/mapper/remove_header_mapper.py) | [tests](../tests/ops/mapper/test_remove_header_mapper.py) | @@ -103,6 +110,7 @@ All the specific operators are listed below, each featured with several capabili | video_captioning_from_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | generate samples whose captions are generated based on an image-to-text model and sampled video frames. Captions from different frames will be concatenated to a single string | [code](../data_juicer/ops/mapper/video_captioning_from_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_frames_mapper.py) | | video_captioning_from_summarizer_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate video captions by summarizing several kinds of generated texts (captions from video/audio/frames, tags from audio/frames, ...) | [code](../data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py) | | video_captioning_from_video_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | generate samples whose captions are generated based on another model (video-blip) and sampled video frame within the original sample | [code](../data_juicer/ops/mapper/video_captioning_from_video_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_video_mapper.py) | +| video_extract_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | extract frames from video files according to specified methods | [code](../data_juicer/ops/mapper/video_extract_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_extract_frames_mapper.py) | | video_face_blur_mapper | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Blur faces detected in videos | [code](../data_juicer/ops/mapper/video_face_blur_mapper.py) | [tests](../tests/ops/mapper/test_video_face_blur_mapper.py) | | video_ffmpeg_wrapped_mapper | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Simple wrapper to run a FFmpeg video filter | [code](../data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py) | [tests](../tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py) | | video_remove_watermark_mapper | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Remove the watermarks in videos given regions | [code](../data_juicer/ops/mapper/video_remove_watermark_mapper.py) | [tests](../tests/ops/mapper/test_video_remove_watermark_mapper.py) | @@ -173,7 +181,6 @@ All the specific operators are listed below, each featured with several capabili | document_simhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level using SimHash | [code](../data_juicer/ops/deduplicator/document_simhash_deduplicator.py) | [tests](../tests/ops/deduplicator/test_document_simhash_deduplicator.py) | | image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | Deduplicates samples at document-level using exact matching of images between documents | [code](../data_juicer/ops/deduplicator/image_deduplicator.py) | [tests](../tests/ops/deduplicator/test_image_deduplicator.py) | | video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Deduplicates samples at document-level using exact matching of videos between documents | [code](../data_juicer/ops/deduplicator/video_deduplicator.py) | [tests](../tests/ops/deduplicator/test_video_deduplicator.py) | -| ray_redis_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level using MinHashLSH based on Ray and Redis | [code](../data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py) | - | | ray_bts_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level using MinHashLSH based on Ray | [code](../data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py) | - | | ray_document_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level by comparing MD5 hash on ray | [code](../data_juicer/ops/deduplicator/ray_document_deduplicator.py) | - | | ray_image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | Deduplicates samples at document-level using exact matching of images between documents on ray | [code](../data_juicer/ops/deduplicator/ray_image_deduplicator.py) | - | @@ -188,6 +195,21 @@ All the specific operators are listed below, each featured with several capabili | range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects samples within a specified range by comparing the values of the specified field | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) | | topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects top samples by comparing the values of the specified field | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) | +## Grouper + +| Operator | Tags | Description | Source code | Unit tests | +|------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------| +| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group samples to batched samples according values in given keys. | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) | +| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group all samples to one batched sample. | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) | + +## Aggregator + +| Operator | Tags | Description | Source code | Unit tests | +|------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------| +| entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Return conclusion of the given entity's attribute from some docs. | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) | +| most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities closely related to a given entity from some texts, and sort them in descending order of importance. | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) | +| nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Considering the limitation of input length, nested aggregate contents for each given number of samples. | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) | + ## Contributing We welcome contributions of adding new operators. Please refer to [How-to Guide for Developers](DeveloperGuide.md). diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 57b009238..c0d4f9793 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,10 +11,12 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 58 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 63 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 44 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | +| [ Grouper ]( #grouper ) | 2 | 将样本分组,每一组组成一个批量样本 | +| [ Aggregator ]( #aggregator ) | 3 | 对批量样本进行汇总,如得出总结或结论 | 下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。 @@ -71,6 +73,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从文本中抽取出事件和事件相关人物 | [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) | | extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 构造文本的关键词 | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) | | extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取昵称称呼关系 | [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) | +| extract_support_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 为一段总结抽取对应原文 | [code](../data_juicer/ops/mapper/extract_support_text_mapper.py) | [tests](../tests/ops/mapper/test_extract_support_text_mapper.py) | | fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 修复损坏的 Unicode(借助 [ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) | | generate_qa_from_examples_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 根据种子数据,生成新的对话样本。 | [code](../data_juicer/ops/mapper/generate_qa_from_examples_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_examples_mapper.py) | | generate_qa_from_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 从文本中生成问答对 | [code](../data_juicer/ops/mapper/generate_qa_from_text_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_text_mapper.py) | @@ -85,7 +88,11 @@ Data-Juicer 中的算子分为以下 5 种类型。 | optimize_qa_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 指令优化,优化问题和答案 | [code](../data_juicer/ops/mapper/optimize_qa_mapper.py) | [tests](../tests/ops/mapper/test_optimize_qa_mapper.py) | | optimize_query_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 指令优化,优化 query | [code](../data_juicer/ops/mapper/optimize_query_mapper.py) | [tests](../tests/ops/mapper/test_optimize_query_mapper.py) | | optimize_response_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 指令优化,优化 response | [code](../data_juicer/ops/mapper/optimize_response_mapper.py) | [tests](../tests/ops/mapper/test_optimize_response_mapper.py) | +| pair_preference_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 构造配对的偏好样本 | [code](../data_juicer/ops/mapper/pair_preference_mapper.py) | [tests](../tests/ops/mapper/test_pair_preference_mapper.py) | | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | +| python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行文件中定义的 Python 函数处理样本 | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | +| python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行 Python lambda 函数处理样本 | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | +| relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 识别一段文本中两个实体之间的关系 | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | | remove_header_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档头,例如标题、章节数字/名称等 | [code](../data_juicer/ops/mapper/remove_header_mapper.py) | [tests](../tests/ops/mapper/test_remove_header_mapper.py) | @@ -102,6 +109,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | video_captioning_from_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 生成样本,其标题是基于一个文字生成图片的模型和原始样本视频中指定帧的图像。不同帧产出的标题会拼接为一条单独的字符串。 | [code](../data_juicer/ops/mapper/video_captioning_from_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_frames_mapper.py) | | video_captioning_from_summarizer_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 通过对多种不同方式生成的文本进行摘要以生成样本的标题(从视频/音频/帧生成标题,从音频/帧生成标签,...) | [code](../data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py) | | video_captioning_from_video_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 生成样本,其标题是根据另一个辅助模型(video-blip)和原始样本中的视频中指定帧的图像。 | [code](../data_juicer/ops/mapper/video_captioning_from_video_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_video_mapper.py) | +| video_extract_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 从视频中抽帧。 | [code](../data_juicer/ops/mapper/video_extract_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_extract_frames_mapper.py) | | video_face_blur_mapper | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 对视频中的人脸进行模糊处理 | [code](../data_juicer/ops/mapper/video_face_blur_mapper.py) | [tests](../tests/ops/mapper/test_video_face_blur_mapper.py) | | video_ffmpeg_wrapped_mapper | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 运行 FFmpeg 视频过滤器的简单封装 | [code](../data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py) | [tests](../tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py) | | video_remove_watermark_mapper | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 去除视频中给定区域的水印 | [code](../data_juicer/ops/mapper/video_remove_watermark_mapper.py) | [tests](../tests/ops/mapper/test_video_remove_watermark_mapper.py) | @@ -116,42 +124,42 @@ Data-Juicer 中的算子分为以下 5 种类型。 ## Filter -| 算子 | 标签 | 描述 | 源码 | 单测样例 | -|-------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------|--------------------------------------------------------------------------|--------------------------------------------------------------------------| -| alphanumeric_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留字母数字比例在指定范围内的样本 | [code](../data_juicer/ops/filter/alphanumeric_filter.py) | [tests](../tests/ops/filter/test_alphanumeric_filter.py) | -| audio_duration_filter | ![Audio](https://img.shields.io/badge/Audio-0DA64F?style=plastic) | 保留包含音频的时长在指定范围内的样本 | [code](../data_juicer/ops/filter/audio_duration_filter.py) | [tests](../tests/ops/filter/test_audio_duration_filter.py) | -| audio_nmf_snr_filter | ![Audio](https://img.shields.io/badge/Audio-0DA64F?style=plastic) | 保留包含音频信噪比SNR(基于非负矩阵分解方法NMF计算)在指定范围内的样本 | [code](../data_juicer/ops/filter/audio_nmf_snr_filter.py) | [tests](../tests/ops/filter/test_audio_nmf_snr_filter.py) | -| audio_size_filter | ![Audio](https://img.shields.io/badge/Audio-0DA64F?style=plastic) | 保留包含音频的大小(bytes)在指定范围内的样本 | [code](../data_juicer/ops/filter/audio_size_filter.py) | [tests](../tests/ops/filter/test_audio_size_filter.py) | -| average_line_length_filter | ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留平均行长度在指定范围内的样本 | [code](../data_juicer/ops/filter/average_line_length_filter.py) | [tests](../tests/ops/filter/test_average_line_length_filter.py) | -| character_repetition_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留 char-level n-gram 重复比率在指定范围内的样本 | [code](../data_juicer/ops/filter/character_repetition_filter.py) | [tests](../tests/ops/filter/test_character_repetition_filter.py) | -| flagged_words_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留使标记字比率保持在指定阈值以下的样本 | [code](../data_juicer/ops/filter/flagged_words_filter.py) | [tests](../tests/ops/filter/test_flagged_words_filter.py) | -| image_aesthetics_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留包含美学分数在指定范围内的图像的样本 | [code](../data_juicer/ops/filter/image_aesthetics_filter.py) | [tests](../tests/ops/filter/test_image_aesthetics_filter.py) | -| image_aspect_ratio_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 保留样本中包含的图片的宽高比在指定范围内的样本 | [code](../data_juicer/ops/filter/image_aspect_ratio_filter.py) | [tests](../tests/ops/filter/test_image_aspect_ratio_filter.py) | -| image_face_count_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 保留样本中包含的图片中检测到的人脸数目在指定范围内的样本 | [code](../data_juicer/ops/filter/image_face_count_filter.py) | [tests](../tests/ops/filter/test_image_face_count_filter.py) | -| image_face_ratio_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 保留样本中包含的图片的最大脸部区域在指定范围内的样本 | [code](../data_juicer/ops/filter/image_face_ratio_filter.py) | [tests](../tests/ops/filter/test_image_face_ratio_filter.py) | -| image_nsfw_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留包含NSFW分数在指定阈值之下的图像的样本 | [code](../data_juicer/ops/filter/image_nsfw_filter.py) | [tests](../tests/ops/filter/test_image_nsfw_filter.py) | -| image_pair_similarity_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留图像特征余弦相似度(基于CLIP模型)在指定范围内的样本 | [code](../data_juicer/ops/filter/image_pair_similarity_filter.py) | [tests](../tests/ops/filter/test_image_pair_similarity_filter.py) | -| image_shape_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 保留样本中包含的图片的形状(即宽和高)在指定范围内的样本 | [code](../data_juicer/ops/filter/image_shape_filter.py) | [tests](../tests/ops/filter/test_image_shape_filter.py) | -| image_size_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 保留样本中包含的图片的大小(bytes)在指定范围内的样本 | [code](../data_juicer/ops/filter/image_size_filter.py) | [tests](../tests/ops/filter/test_image_size_filter.py) | -| image_text_matching_filter | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留图像-文本的分类匹配分(基于BLIP模型)在指定范围内的样本 | [code](../data_juicer/ops/filter/image_text_matching_filter.py) | [tests](../tests/ops/filter/test_image_text_matching_filter.py) | -| image_text_similarity_filter | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留图像-文本的特征余弦相似度(基于CLIP模型)在指定范围内的样本 | [code](../data_juicer/ops/filter/image_text_similarity_filter.py) | [tests](../tests/ops/filter/test_image_text_similarity_filter.py) | -| image_watermark_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留包含有水印概率在指定阈值之下的图像的样本 | [code](../data_juicer/ops/filter/image_watermark_filter.py) | [tests](../tests/ops/filter/test_image_watermark_filter.py) | -| language_id_score_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留特定语言的样本,通过预测的置信度得分来判断 | [code](../data_juicer/ops/filter/language_id_score_filter.py) | [tests](../tests/ops/filter/test_language_id_score_filter.py) | -| maximum_line_length_filter | ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留最大行长度在指定范围内的样本 | [code](../data_juicer/ops/filter/maximum_line_length_filter.py) | [tests](../tests/ops/filter/test_maximum_line_length_filter.py) | -| perplexity_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留困惑度低于指定阈值的样本 | [code](../data_juicer/ops/filter/perplexity_filter.py) | [tests](../tests/ops/filter/test_perplexity_filter.py) | -| phrase_grounding_recall_filter | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留从文本中提取的名词短语在图像中的定位召回率在一定范围内的样本 | [code](../data_juicer/ops/filter/phrase_grounding_recall_filter.py) | [tests](../tests/ops/filter/test_phrase_grounding_recall_filter.py) | -| special_characters_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留 special-char 比率的在指定范围内的样本 | [code](../data_juicer/ops/filter/special_characters_filter.py) | [tests](../tests/ops/filter/test_special_characters_filter.py) | -| specified_field_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 根据字段过滤样本,要求字段的值处于指定目标中 | [code](../data_juicer/ops/filter/specified_field_filter.py) | [tests](../tests/ops/filter/test_specified_field_filter.py) | -| specified_numeric_field_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 根据字段过滤样本,要求字段的值处于指定范围(针对数字类型) | [code](../data_juicer/ops/filter/specified_numeric_field_filter.py) | [tests](../tests/ops/filter/test_specified_numeric_field_filter.py) | -| stopwords_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留停用词比率高于指定阈值的样本 | [code](../data_juicer/ops/filter/stopwords_filter.py) | [tests](../tests/ops/filter/test_stopwords_filter.py) | -| suffix_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留包含特定后缀的样本 | [code](../data_juicer/ops/filter/suffix_filter.py) | [tests](../tests/ops/filter/test_suffix_filter.py) | -| text_action_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留文本部分包含动作的样本 | [code](../data_juicer/ops/filter/text_action_filter.py) | [tests](../tests/ops/filter/test_text_action_filter.py) | -| text_entity_dependency_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留文本部分的依存树中具有非独立实体的样本 | [code](../data_juicer/ops/filter/text_entity_dependency_filter.py) | [tests](../tests/ops/filter/test_text_entity_dependency_filter.py) | -| text_length_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留总文本长度在指定范围内的样本 | [code](../data_juicer/ops/filter/text_length_filter.py) | [tests](../tests/ops/filter/test_text_length_filter.py) | -| token_num_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留token数在指定范围内的样本 | [code](../data_juicer/ops/filter/token_num_filter.py) | [tests](../tests/ops/filter/test_token_num_filter.py) | -| video_aspect_ratio_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留包含视频的宽高比在指定范围内的样本 | [code](../data_juicer/ops/filter/video_aesthetics_filter.py) | [tests](../tests/ops/filter/test_video_aesthetics_filter.py) | -| video_duration_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 保留包含视频的时长在指定范围内的样本 | [code](../data_juicer/ops/filter/video_aspect_ratio_filter.py) | [tests](../tests/ops/filter/test_video_aspect_ratio_filter.py) | -| video_aesthetics_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 保留指定帧的美学分数在指定范围内的样本 | [code](../data_juicer/ops/filter/video_duration_filter.py) | [tests](../tests/ops/filter/test_video_duration_filter.py) | +| 算子 | 标签 | 描述 | 源码 | 单测样例 | +|-------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------|--------------------------------------------------------------------------|--------------------------------------------------------------------------| +| alphanumeric_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留字母数字比例在指定范围内的样本 | [code](../data_juicer/ops/filter/alphanumeric_filter.py) | [tests](../tests/ops/filter/test_alphanumeric_filter.py) | +| audio_duration_filter | ![Audio](https://img.shields.io/badge/Audio-0DA64F?style=plastic) | 保留包含音频的时长在指定范围内的样本 | [code](../data_juicer/ops/filter/audio_duration_filter.py) | [tests](../tests/ops/filter/test_audio_duration_filter.py) | +| audio_nmf_snr_filter | ![Audio](https://img.shields.io/badge/Audio-0DA64F?style=plastic) | 保留包含音频信噪比SNR(基于非负矩阵分解方法NMF计算)在指定范围内的样本 | [code](../data_juicer/ops/filter/audio_nmf_snr_filter.py) | [tests](../tests/ops/filter/test_audio_nmf_snr_filter.py) | +| audio_size_filter | ![Audio](https://img.shields.io/badge/Audio-0DA64F?style=plastic) | 保留包含音频的大小(bytes)在指定范围内的样本 | [code](../data_juicer/ops/filter/audio_size_filter.py) | [tests](../tests/ops/filter/test_audio_size_filter.py) | +| average_line_length_filter | ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留平均行长度在指定范围内的样本 | [code](../data_juicer/ops/filter/average_line_length_filter.py) | [tests](../tests/ops/filter/test_average_line_length_filter.py) | +| character_repetition_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留 char-level n-gram 重复比率在指定范围内的样本 | [code](../data_juicer/ops/filter/character_repetition_filter.py) | [tests](../tests/ops/filter/test_character_repetition_filter.py) | +| flagged_words_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留使标记字比率保持在指定阈值以下的样本 | [code](../data_juicer/ops/filter/flagged_words_filter.py) | [tests](../tests/ops/filter/test_flagged_words_filter.py) | +| image_aesthetics_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留包含美学分数在指定范围内的图像的样本 | [code](../data_juicer/ops/filter/image_aesthetics_filter.py) | [tests](../tests/ops/filter/test_image_aesthetics_filter.py) | +| image_aspect_ratio_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 保留样本中包含的图片的宽高比在指定范围内的样本 | [code](../data_juicer/ops/filter/image_aspect_ratio_filter.py) | [tests](../tests/ops/filter/test_image_aspect_ratio_filter.py) | +| image_face_count_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 保留样本中包含的图片中检测到的人脸数目在指定范围内的样本 | [code](../data_juicer/ops/filter/image_face_count_filter.py) | [tests](../tests/ops/filter/test_image_face_count_filter.py) | +| image_face_ratio_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 保留样本中包含的图片的最大脸部区域在指定范围内的样本 | [code](../data_juicer/ops/filter/image_face_ratio_filter.py) | [tests](../tests/ops/filter/test_image_face_ratio_filter.py) | +| image_nsfw_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留包含NSFW分数在指定阈值之下的图像的样本 | [code](../data_juicer/ops/filter/image_nsfw_filter.py) | [tests](../tests/ops/filter/test_image_nsfw_filter.py) | +| image_pair_similarity_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留图像特征余弦相似度(基于CLIP模型)在指定范围内的样本 | [code](../data_juicer/ops/filter/image_pair_similarity_filter.py) | [tests](../tests/ops/filter/test_image_pair_similarity_filter.py) | +| image_shape_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 保留样本中包含的图片的形状(即宽和高)在指定范围内的样本 | [code](../data_juicer/ops/filter/image_shape_filter.py) | [tests](../tests/ops/filter/test_image_shape_filter.py) | +| image_size_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 保留样本中包含的图片的大小(bytes)在指定范围内的样本 | [code](../data_juicer/ops/filter/image_size_filter.py) | [tests](../tests/ops/filter/test_image_size_filter.py) | +| image_text_matching_filter | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留图像-文本的分类匹配分(基于BLIP模型)在指定范围内的样本 | [code](../data_juicer/ops/filter/image_text_matching_filter.py) | [tests](../tests/ops/filter/test_image_text_matching_filter.py) | +| image_text_similarity_filter | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留图像-文本的特征余弦相似度(基于CLIP模型)在指定范围内的样本 | [code](../data_juicer/ops/filter/image_text_similarity_filter.py) | [tests](../tests/ops/filter/test_image_text_similarity_filter.py) | +| image_watermark_filter | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留包含有水印概率在指定阈值之下的图像的样本 | [code](../data_juicer/ops/filter/image_watermark_filter.py) | [tests](../tests/ops/filter/test_image_watermark_filter.py) | +| language_id_score_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留特定语言的样本,通过预测的置信度得分来判断 | [code](../data_juicer/ops/filter/language_id_score_filter.py) | [tests](../tests/ops/filter/test_language_id_score_filter.py) | +| maximum_line_length_filter | ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留最大行长度在指定范围内的样本 | [code](../data_juicer/ops/filter/maximum_line_length_filter.py) | [tests](../tests/ops/filter/test_maximum_line_length_filter.py) | +| perplexity_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留困惑度低于指定阈值的样本 | [code](../data_juicer/ops/filter/perplexity_filter.py) | [tests](../tests/ops/filter/test_perplexity_filter.py) | +| phrase_grounding_recall_filter | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留从文本中提取的名词短语在图像中的定位召回率在一定范围内的样本 | [code](../data_juicer/ops/filter/phrase_grounding_recall_filter.py) | [tests](../tests/ops/filter/test_phrase_grounding_recall_filter.py) | +| special_characters_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留 special-char 比率的在指定范围内的样本 | [code](../data_juicer/ops/filter/special_characters_filter.py) | [tests](../tests/ops/filter/test_special_characters_filter.py) | +| specified_field_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 根据字段过滤样本,要求字段的值处于指定目标中 | [code](../data_juicer/ops/filter/specified_field_filter.py) | [tests](../tests/ops/filter/test_specified_field_filter.py) | +| specified_numeric_field_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 根据字段过滤样本,要求字段的值处于指定范围(针对数字类型) | [code](../data_juicer/ops/filter/specified_numeric_field_filter.py) | [tests](../tests/ops/filter/test_specified_numeric_field_filter.py) | +| stopwords_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留停用词比率高于指定阈值的样本 | [code](../data_juicer/ops/filter/stopwords_filter.py) | [tests](../tests/ops/filter/test_stopwords_filter.py) | +| suffix_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留包含特定后缀的样本 | [code](../data_juicer/ops/filter/suffix_filter.py) | [tests](../tests/ops/filter/test_suffix_filter.py) | +| text_action_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留文本部分包含动作的样本 | [code](../data_juicer/ops/filter/text_action_filter.py) | [tests](../tests/ops/filter/test_text_action_filter.py) | +| text_entity_dependency_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留文本部分的依存树中具有非独立实体的样本 | [code](../data_juicer/ops/filter/text_entity_dependency_filter.py) | [tests](../tests/ops/filter/test_text_entity_dependency_filter.py) | +| text_length_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 保留总文本长度在指定范围内的样本 | [code](../data_juicer/ops/filter/text_length_filter.py) | [tests](../tests/ops/filter/test_text_length_filter.py) | +| token_num_filter | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留token数在指定范围内的样本 | [code](../data_juicer/ops/filter/token_num_filter.py) | [tests](../tests/ops/filter/test_token_num_filter.py) | +| video_aesthetics_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留指定帧的美学分数在指定范围内的样本 | [code](../data_juicer/ops/filter/video_aesthetics_filter.py) | [tests](../tests/ops/filter/test_video_aesthetics_filter.py) | +| video_aspect_ratio_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 保留包含视频的宽高比在指定范围内的样本 | [code](../data_juicer/ops/filter/video_aspect_ratio_filter.py) | [tests](../tests/ops/filter/test_video_aspect_ratio_filter.py) | +| video_duration_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 保留包含视频的时长在指定范围内的样本 | [code](../data_juicer/ops/filter/video_duration_filter.py) | [tests](../tests/ops/filter/test_video_duration_filter.py) | | video_frames_text_similarity_filter | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 保留视频中指定帧的图像-文本的特征余弦相似度(基于CLIP模型)在指定范围内的样本 | [code](../data_juicer/ops/filter/video_frames_text_similarity_filter.py) | [tests](../tests/ops/filter/test_video_frames_text_similarity_filter.py) | | video_motion_score_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 保留包含视频的运动分数(基于稠密光流)在指定范围内的样本 | [code](../data_juicer/ops/filter/video_motion_score_filter.py) | [tests](../tests/ops/filter/test_video_motion_score_filter.py) | | video_motion_score_raft_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 保留包含视频的运动分数(基于 RAFT 模型估计的稠密光流)在指定范围内的样本 | [code](../data_juicer/ops/filter/video_motion_score_raft_raft_filter.py) | [tests](../tests/ops/filter/test_video_motion_score_filter.py) | @@ -172,7 +180,6 @@ Data-Juicer 中的算子分为以下 5 种类型。 | document_simhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 SimHash 在文档级别对样本去重 | [code](../data_juicer/ops/deduplicator/document_simhash_deduplicator.py) | [tests](../tests/ops/deduplicator/test_document_simhash_deduplicator.py) | | image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 使用文档之间图像的精确匹配在文档级别删除重复样本 | [code](../data_juicer/ops/deduplicator/image_deduplicator.py) | [tests](../tests/ops/deduplicator/test_image_deduplicator.py) | | video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 使用文档之间视频的精确匹配在文档级别删除重复样本 | [code](../data_juicer/ops/deduplicator/video_deduplicator.py) | [tests](../tests/ops/deduplicator/test_video_deduplicator.py) | -| ray_redis_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 MinHashLSH 在文档级别对样本去重,面向 RAY 分布式模式(基于Redis) | [code](../data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py) | - | | ray_bts_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 MinHashLSH 在文档级别对样本去重,面向 RAY 分布式模式 | [code](../data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py) | - | | ray_document_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较 MD5 哈希值在文档级别对样本去重,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_document_deduplicator.py) | - | | ray_image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 使用文档之间图像的精确匹配在文档级别删除重复样本,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_image_deduplicator.py) | - | @@ -187,5 +194,20 @@ Data-Juicer 中的算子分为以下 5 种类型。 | range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出指定范围的 k 个样本 | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) | | topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出前 k 个样本 | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) | +## Grouper + +| 算子 | 标签 | 描述 | 源码 | 单测样例 | +|-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------| +| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 根据给定键的值将样本分组,每一组组成一个批量样本。 | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) | +| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将所有样本分为一个组,返回一个批量样本 | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) | + +## Aggregator + +| 算子 | 标签 | 描述 | 源码 | 单测样例 | +|-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------| +| entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中总结出给定实体的属性 | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) | +| most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中抽取出与给定实体密切相关的实体,按重要性从高到低排序 | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) | +| nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 考虑到输入长度的限制,对样本中的内容进行嵌套聚合。 | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) | + ## 贡献 我们欢迎社区贡献新的算子,具体请参考[开发者指南](DeveloperGuide_ZH.md)。 diff --git a/environments/dev_requires.txt b/environments/dev_requires.txt index 9793d5746..44dd79158 100644 --- a/environments/dev_requires.txt +++ b/environments/dev_requires.txt @@ -4,3 +4,4 @@ sphinx sphinx-autobuild sphinx_rtd_theme recommonmark +wandb<=0.19.0 diff --git a/environments/minimal_requires.txt b/environments/minimal_requires.txt index c484a9847..71aa0ba38 100644 --- a/environments/minimal_requires.txt +++ b/environments/minimal_requires.txt @@ -1,10 +1,14 @@ -fsspec==2023.5.0 -pyarrow<=12.0.0 -pandas==2.0.3 datasets>=2.19.0 -av +fsspec==2023.5.0 +pandas +numpy +av==13.1.0 soundfile +# need to install two dependencies by librosa to avoid lazy_loader error librosa>=0.10 +samplerate +resampy +# need to install two dependencies by librosa to avoid lazy_loader error loguru tabulate tqdm @@ -27,5 +31,6 @@ dill==0.3.4 psutil pydantic>=2.0 Pillow -numpy<2 fastapi[standard]>=0.100 +httpx +wordcloud diff --git a/environments/sandbox_requires.txt b/environments/sandbox_requires.txt index 7f1d27a25..6a1791cf8 100644 --- a/environments/sandbox_requires.txt +++ b/environments/sandbox_requires.txt @@ -1,5 +1,4 @@ torch>=1.11.0 -wandb fire pyspark # vbench-related diff --git a/environments/science_requires.txt b/environments/science_requires.txt index f1e613126..af5d6b362 100644 --- a/environments/science_requires.txt +++ b/environments/science_requires.txt @@ -11,7 +11,7 @@ selectolax nlpaug nlpcda nltk<3.9 -transformers>=4.37 +transformers>=4.47.0 transformers_stream_generator einops accelerate @@ -26,3 +26,5 @@ ffmpeg-python opencv-python vllm>=0.1.3 rouge +dashscope +openai diff --git a/setup.py b/setup.py index 3df3d0170..d0ec5b546 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,7 @@ def get_install_requirements(require_f_paths, env_dir='environments'): 'console_scripts': [ 'dj-process = data_juicer.tools.process_data:main', 'dj-analyze = data_juicer.tools.analyze_data:main', + 'dj-install = data_juicer.tools.dj_install:main', ] }, install_requires=min_requires, diff --git a/tests/benchmark_performance/configs/audio.yaml b/tests/benchmark_performance/configs/audio.yaml new file mode 100644 index 000000000..848c537b0 --- /dev/null +++ b/tests/benchmark_performance/configs/audio.yaml @@ -0,0 +1,14 @@ +# The config file for performance benchmark to measure the processing speed for +# the current Data-Juicer system. OPs are selected according to their tags and +# types (https://github.com/modelscope/data-juicer/blob/main/docs/Operators.md) + +project_name: 'performance-benchmark-audio' +dataset_path: 'perf_bench_data/audio/audio-10k.jsonl' +export_path: 'outputs/performance_benchmark_audio/res.jsonl' +np: 16 +use_cache: false + +process: + - audio_duration_filter: + - audio_nmf_snr_filter: + - audio_size_filter: diff --git a/tests/benchmark_performance/configs/image.yaml b/tests/benchmark_performance/configs/image.yaml new file mode 100644 index 000000000..3ce03be53 --- /dev/null +++ b/tests/benchmark_performance/configs/image.yaml @@ -0,0 +1,23 @@ +# The config file for performance benchmark to measure the processing speed for +# the current Data-Juicer system. OPs are selected according to their tags and +# types (https://github.com/modelscope/data-juicer/blob/main/docs/Operators.md) + +project_name: 'performance-benchmark-image' +dataset_path: 'perf_bench_data/image/10k.jsonl' +export_path: 'outputs/performance_benchmark_image/res.jsonl' +np: 16 +use_cache: false + +process: + - image_aesthetics_filter: + hf_scorer_model: 'shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE' + min_score: 0.0 + mem_required: '1500MB' + - image_captioning_mapper: + hf_img2seq: 'Salesforce/blip2-opt-2.7b' + caption_num: 1 + keep_original_sample: false + mem_required: '16GB' + - image_shape_filter: + - image_blur_mapper: + - image_deduplicator: diff --git a/tests/benchmark_performance/configs/text.yaml b/tests/benchmark_performance/configs/text.yaml new file mode 100644 index 000000000..8b39bbeb8 --- /dev/null +++ b/tests/benchmark_performance/configs/text.yaml @@ -0,0 +1,21 @@ +# The config file for performance benchmark to measure the processing speed for +# the current Data-Juicer system. OPs are selected according to their tags and +# types (https://github.com/modelscope/data-juicer/blob/main/docs/Operators.md) + +project_name: 'performance-benchmark-text' +dataset_path: 'perf_bench_data/text/wiki-10k.jsonl' +export_path: 'outputs/performance_benchmark_text/res.jsonl' +np: 16 +use_cache: false + +process: + - whitespace_normalization_mapper: + - token_num_filter: + hf_tokenizer: 'EleutherAI/pythia-6.9b-deduped' + min_num: 0 + - document_deduplicator: + lowercase: false + ignore_non_character: false + - topk_specified_field_selector: + field_key: '__dj__stats__.num_token' + topk: 1000 diff --git a/tests/benchmark_performance/configs/video.yaml b/tests/benchmark_performance/configs/video.yaml new file mode 100644 index 000000000..28fb3b98a --- /dev/null +++ b/tests/benchmark_performance/configs/video.yaml @@ -0,0 +1,21 @@ +# The config file for performance benchmark to measure the processing speed for +# the current Data-Juicer system. OPs are selected according to their tags and +# types (https://github.com/modelscope/data-juicer/blob/main/docs/Operators.md) + +project_name: 'performance-benchmark-video' +dataset_path: 'perf_bench_data/video/msr_vtt_train.jsonl' +export_path: 'outputs/performance_benchmark_video/res.jsonl' +np: 16 +use_cache: false + +process: + - video_nsfw_filter: + hf_nsfw_model: 'Falconsai/nsfw_image_detection' + score_threshold: 1.0 + mem_required: '1GB' + - video_tagging_from_frames_mapper: + mem_required: '9GB' + - video_duration_filter: + - video_split_by_key_frame_mapper: + keep_original_sample: false + - video_deduplicator: diff --git a/tests/benchmark_performance/report.py b/tests/benchmark_performance/report.py new file mode 100644 index 000000000..e53afa63a --- /dev/null +++ b/tests/benchmark_performance/report.py @@ -0,0 +1,126 @@ +import wandb +import fire +import os +import json +import yaml +import regex as re +from loguru import logger + +PROJECT = 'Data-Juicer Reports' +RUN_NAME = 'Performance Benchmark -- %s' +MODALITIES = {'text', 'image', 'video', 'audio'} +DIFF_TH = 0.1 + +def get_run_id(project, run_name, entity='dail'): + api = wandb.Api() + runs = api.runs(path=f'{entity}/{project}') + for run in runs: + if run.name == run_name: + return run.id + return '' + +def init_run(modality, config=None): + # get the run object for specified modality + # if it's not existed, create one + # if it's existed, get the run id and resume from it + run_id = get_run_id(PROJECT, RUN_NAME % modality) + if run_id == '': + # no existing run, create one + run = wandb.init(project=PROJECT, + config=config, + tags=['performance benchmark', modality], + name=RUN_NAME % modality) + run_id = get_run_id(PROJECT, RUN_NAME % modality) + else: + run = wandb.init(project=PROJECT, + id=run_id, + resume='must') + return run, run_id + +def main(): + wandb.login() + for modality in MODALITIES: + logger.info(f'--------------- {modality} ---------------') + work_dir = f'outputs/performance_benchmark_{modality}/' + + # read config + with open(os.path.join(work_dir, f'{modality}.yaml')) as fin: + config = yaml.load(fin, yaml.FullLoader) + + # init the wandb run + run, run_id = init_run(modality, config) + + # collect results from logs + log_pt = r'export_(.*?)_time_(\d*?).txt' + log_dir = os.path.join(work_dir, 'log') + log_files = os.listdir(log_dir) + log_file = None + for fn in log_files: + if re.match(log_pt, fn): + log_file = fn + break + if log_file is None: + logger.warning('No log files found.') + exit() + log_file = os.path.join(log_dir, log_file) + with open(log_file) as fin: + log_content = fin.read() + op_pt = r'OP \[(.*?)\] Done in (.*?)s' + total_pt = r'All OPs are done in (.*?)s' + op_data = re.findall(op_pt, log_content) + ops = [it[0] for it in op_data] + total_data = re.findall(total_pt, log_content) + + res = dict(op_data) + res['total_time'] = total_data[0] + res = {key: {'time': float(res[key])} for key in res} + + # collect resource utilization from monitor logs + monitor_file = os.path.join(work_dir, 'monitor', 'monitor.json') + with open(monitor_file) as fin: + monitor_res = json.load(fin) + assert len(monitor_res) == len(ops) + for op, resource_util_dict in zip(ops, monitor_res): + res[op].update(resource_util_dict['resource_analysis']) + + # upload results and finish the run + upload_res = { + modality: res + } + run.log(upload_res) + run.finish() + + # compare with the last run + api = wandb.Api() + api_run = api.run(f'{PROJECT}/{run_id}') + run_history = api_run.history() + if len(run_history) < 2: + continue + last_record = run_history.iloc[-2] + + for op_name, time in op_data: + last_time = last_record[f'{modality}.{op_name}.time'] + this_time = res[op_name]['time'] + dif = (this_time - last_time) / last_time + if dif > 0.1: + logger.warning(f'Time cost for OP {[op_name]} increased by ' + f'{dif * 100}% (> 10%). Before-{last_time} vs. ' + f'Now-{this_time}') + else: + logger.info(f'Time cost for OP {[op_name]} increased by ' + f'{dif * 100}%. Before-{last_time} vs. ' + f'Now-{this_time}') + last_total = last_record[f'{modality}.total_time.time'] + this_total = res['total_time']['time'] + dif_total = (this_total - last_total) / last_total + if dif_total > 0.1: + logger.warning(f'Total time cost increased by {dif_total * 100}% ' + f'(> 10%). Before-{last_total} vs. ' + f'Now-{this_total}') + else: + logger.info(f'Total time cost increased by {dif_total * 100}%. ' + f'Before-{last_total} vs. Now-{this_total}') + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tests/benchmark_performance/run.sh b/tests/benchmark_performance/run.sh new file mode 100644 index 000000000..1ec839d57 --- /dev/null +++ b/tests/benchmark_performance/run.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# setup wandb configs +export WANDB_BASE_URL=$1 +export WANDB_API_KEY=$2 + +BENCH_PATH=$(cd "$(dirname "$0")"; pwd) +RELATIVE_DJ_PATH=../.. +MODALITIES=("text" "image" "video" "audio") + +cd $BENCH_PATH + +# 1. prepare dataset +wget -q http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/perf_bench_data/perf_bench_data.tar.gz && tar zxf perf_bench_data.tar.gz + +# 2. run the benchmark +for modality in ${MODALITIES[@]} +do + python $RELATIVE_DJ_PATH/tools/process_data.py --config configs/$modality.yaml +done + +# 3. collect & upload benchmark results +python report.py + +# 4. clear resources +rm -rf perf_bench_data.tar.gz +rm -rf perf_bench_data/ diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py index 1cb7c4463..2502fe1e1 100644 --- a/tests/config/test_config_funcs.py +++ b/tests/config/test_config_funcs.py @@ -54,6 +54,7 @@ def test_yaml_cfg_file(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }, 'nested dict load fail, for nonparametric op') self.assertDictEqual( @@ -75,6 +76,7 @@ def test_yaml_cfg_file(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }, 'nested dict load fail, un-expected internal value') @@ -144,6 +146,7 @@ def test_mixture_cfg(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }) self.assertDictEqual( @@ -165,6 +168,7 @@ def test_mixture_cfg(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }) self.assertDictEqual( @@ -186,6 +190,7 @@ def test_mixture_cfg(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }) self.assertDictEqual( @@ -207,6 +212,7 @@ def test_mixture_cfg(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }) self.assertDictEqual( @@ -228,6 +234,7 @@ def test_mixture_cfg(self): 'mem_required': 0, 'turbo': False, 'batch_size': 1000, + 'index_key': None, } }) diff --git a/tests/core/test_adapter.py b/tests/core/test_adapter.py index 965355b96..4a58d882f 100644 --- a/tests/core/test_adapter.py +++ b/tests/core/test_adapter.py @@ -4,11 +4,12 @@ from datasets import load_dataset from loguru import logger from data_juicer.core import Adapter -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS from data_juicer.ops.mapper import FixUnicodeMapper from data_juicer.ops.filter import PerplexityFilter from data_juicer.ops.deduplicator import DocumentDeduplicator +@SKIPPED_TESTS.register_module() class AdapterTest(DataJuicerTestCaseBase): @classmethod @@ -177,6 +178,23 @@ def test_adapt_workloads(self): datasets.enable_caching() + def test_adapt_workloads_multiprocessing(self): + datasets.disable_caching() + # basic test + ds = load_dataset('json', data_files=self.test_file, split='train') + ops = [ + FixUnicodeMapper(num_proc=4), + PerplexityFilter(num_proc=4), + DocumentDeduplicator(num_proc=4), + ] # use some batched OPs later + + adapter = Adapter({'batch_size': 100}) + adapted_batch_sizes = adapter.adapt_workloads(ds, ops) + self.assertEqual(len(adapted_batch_sizes), len(ops)) + logger.info(adapted_batch_sizes) + + datasets.enable_caching() + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/Aggregator/__init__.py b/tests/ops/Aggregator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ops/Aggregator/test_entity_attribute_aggregator.py b/tests/ops/Aggregator/test_entity_attribute_aggregator.py new file mode 100644 index 000000000..1f80da3a3 --- /dev/null +++ b/tests/ops/Aggregator/test_entity_attribute_aggregator.py @@ -0,0 +1,139 @@ +import unittest + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.aggregator import EntityAttributeAggregator +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS + + +@SKIPPED_TESTS.register_module() +class EntityAttributeAggregatorTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples): + + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for data in new_dataset: + for k in data: + logger.info(f"{k}: {data[k]}") + + self.assertEqual(len(new_dataset), len(samples)) + + def test_default_aggregator(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = EntityAttributeAggregator( + api_model='qwen2.5-72b-instruct', + entity='李莲花', + attribute='主要经历' + ) + self._run_helper(op, samples) + + def test_input_output(self): + samples = [ + { + 'sub_docs': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = EntityAttributeAggregator( + api_model='qwen2.5-72b-instruct', + entity='李莲花', + attribute='身份背景', + input_key='sub_docs', + output_key='text' + ) + self._run_helper(op, samples) + + def test_max_token_num(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = EntityAttributeAggregator( + api_model='qwen2.5-72b-instruct', + entity='李莲花', + attribute='身份背景', + max_token_num=200 + ) + self._run_helper(op, samples) + + def test_word_limit_num(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = EntityAttributeAggregator( + api_model='qwen2.5-72b-instruct', + entity='李莲花', + attribute='身份背景', + word_limit=20 + ) + self._run_helper(op, samples) + + + def test_example_prompt(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + example_prompt=( + '- 例如,根据相关文档总结`孙悟空`的`另外身份`,样例如下:\n' + '`孙悟空`的`另外身份`总结:\n' + '# 孙悟空\n' + '## 另外身份\n' + '孙行者、齐天大圣、美猴王\n' + ) + op = EntityAttributeAggregator( + api_model='qwen2.5-72b-instruct', + entity='李莲花', + attribute='另外身份', + example_prompt=example_prompt, + word_limit=20 + ) + self._run_helper(op, samples) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py b/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py new file mode 100644 index 000000000..1d8678134 --- /dev/null +++ b/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py @@ -0,0 +1,93 @@ +import unittest + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.aggregator import MostRelavantEntitiesAggregator +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS + + +@SKIPPED_TESTS.register_module() +class MostRelavantEntitiesAggregatorTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples): + + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for data in new_dataset: + for k in data: + logger.info(f"{k}: {data[k]}") + + self.assertEqual(len(new_dataset), len(samples)) + + def test_default_aggregator(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + + op = MostRelavantEntitiesAggregator( + api_model='qwen2.5-72b-instruct', + entity='李莲花', + query_entity_type='人物' + ) + self._run_helper(op, samples) + + def test_input_output(self): + samples = [ + { + 'dj_result':{ + 'events': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + } + }, + ] + + op = MostRelavantEntitiesAggregator( + api_model='qwen2.5-72b-instruct', + entity='李莲花', + query_entity_type='人物', + input_key='dj_result.events', + output_key='dj_result.relavant_roles' + ) + self._run_helper(op, samples) + + def test_max_token_num(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = MostRelavantEntitiesAggregator( + api_model='qwen2.5-72b-instruct', + entity='李莲花', + query_entity_type='人物', + max_token_num=40 + ) + self._run_helper(op, samples) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/Aggregator/test_nested_aggregator.py b/tests/ops/Aggregator/test_nested_aggregator.py new file mode 100644 index 000000000..6347652bc --- /dev/null +++ b/tests/ops/Aggregator/test_nested_aggregator.py @@ -0,0 +1,119 @@ +import unittest + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.aggregator import NestedAggregator +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS + + +@SKIPPED_TESTS.register_module() +class NestedAggregatorTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples): + + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for data in new_dataset: + for k in data: + logger.info(f"{k}: {data[k]}") + + self.assertEqual(len(new_dataset), len(samples)) + + def test_default_aggregator(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = NestedAggregator( + api_model='qwen2.5-72b-instruct' + ) + self._run_helper(op, samples) + + def test_input_output(self): + samples = [ + { + 'sub_docs': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = NestedAggregator( + api_model='qwen2.5-72b-instruct', + input_key='sub_docs', + output_key='text' + ) + self._run_helper(op, samples) + + def test_max_token_num_1(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = NestedAggregator( + api_model='qwen2.5-72b-instruct', + max_token_num=2 + ) + self._run_helper(op, samples) + + def test_max_token_num_2(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = NestedAggregator( + api_model='qwen2.5-72b-instruct', + max_token_num=90 + ) + self._run_helper(op, samples) + + def test_max_token_num_3(self): + samples = [ + { + 'text': [ + "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", + "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", + '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', + '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', + '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + ] + }, + ] + op = NestedAggregator( + api_model='qwen2.5-72b-instruct', + max_token_num=200 + ) + self._run_helper(op, samples) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/filter/test_audio_duration_filter.py b/tests/ops/filter/test_audio_duration_filter.py index d336e9b10..64a5c05c8 100644 --- a/tests/ops/filter/test_audio_duration_filter.py +++ b/tests/ops/filter/test_audio_duration_filter.py @@ -7,7 +7,6 @@ from data_juicer.utils.constant import Fields from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, TEST_TAG - class AudioDurationFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/filter/test_audio_nmf_snr_filter.py b/tests/ops/filter/test_audio_nmf_snr_filter.py index 1cc010b2f..d0dec38b8 100644 --- a/tests/ops/filter/test_audio_nmf_snr_filter.py +++ b/tests/ops/filter/test_audio_nmf_snr_filter.py @@ -7,7 +7,6 @@ from data_juicer.utils.constant import Fields from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase - class AudioNMFSNRFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/filter/test_video_motion_score_raft_filter.py b/tests/ops/filter/test_video_motion_score_raft_filter.py index abd8e7374..89f9e0548 100644 --- a/tests/ops/filter/test_video_motion_score_raft_filter.py +++ b/tests/ops/filter/test_video_motion_score_raft_filter.py @@ -6,9 +6,11 @@ from data_juicer.ops.filter.video_motion_score_raft_filter import \ VideoMotionScoreRaftFilter from data_juicer.utils.constant import Fields -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase - +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +# skip due to conflicts when run lazy_load in multiprocessing in librosa +# tests passed locally. +@SKIPPED_TESTS.register_module() class VideoMotionScoreRaftFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/filter/test_video_tagging_from_frames_filter.py b/tests/ops/filter/test_video_tagging_from_frames_filter.py index 545be9748..bc4f67fb4 100644 --- a/tests/ops/filter/test_video_tagging_from_frames_filter.py +++ b/tests/ops/filter/test_video_tagging_from_frames_filter.py @@ -6,7 +6,7 @@ from data_juicer.ops.filter.video_tagging_from_frames_filter import \ VideoTaggingFromFramesFilter from data_juicer.utils.mm_utils import SpecialTokens -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase class VideoTaggingFromFramesFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', diff --git a/tests/ops/grouper/__init__.py b/tests/ops/grouper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ops/grouper/test_key_value_grouper.py b/tests/ops/grouper/test_key_value_grouper.py new file mode 100644 index 000000000..1ac186423 --- /dev/null +++ b/tests/ops/grouper/test_key_value_grouper.py @@ -0,0 +1,54 @@ +import unittest + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.grouper.key_value_grouper import KeyValueGrouper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class KeyValueGrouperTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples, target): + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for batched_sample in new_dataset: + lang = batched_sample['meta'][0]['language'] + self.assertEqual(batched_sample['text'], target[lang]) + + def test_key_value_grouper(self): + + source = [ + { + 'text': "Today is Sunday and it's a happy day!", + 'meta': { + 'language': 'en' + } + }, + { + 'text': "Welcome to Alibaba.", + 'meta': { + 'language': 'en' + } + }, + { + 'text': '欢迎来到阿里巴巴!', + 'meta': { + 'language': 'zh' + } + }, + ] + target = { + 'en':[ + "Today is Sunday and it's a happy day!", + "Welcome to Alibaba." + ], + 'zh':[ + '欢迎来到阿里巴巴!' + ] + } + + op = KeyValueGrouper(['meta.language']) + self._run_helper(op, source, target) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/grouper/test_naive_grouper.py b/tests/ops/grouper/test_naive_grouper.py new file mode 100644 index 000000000..4e69a8ba2 --- /dev/null +++ b/tests/ops/grouper/test_naive_grouper.py @@ -0,0 +1,47 @@ +import unittest + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.grouper.naive_grouper import NaiveGrouper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class NaiveGrouperTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples, target): + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for d, t in zip(new_dataset, target): + self.assertEqual(d['text'], t['text']) + + def test_naive_group(self): + + source = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text':[ + "Today is Sunday and it's a happy day!", + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.', + '欢迎来到阿里巴巴!' + ] + } + ] + + op = NaiveGrouper() + self._run_helper(op, source, target) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/mapper/test_calibrate_qa_mapper.py b/tests/ops/mapper/test_calibrate_qa_mapper.py index ea237093b..5755ed2b1 100644 --- a/tests/ops/mapper/test_calibrate_qa_mapper.py +++ b/tests/ops/mapper/test_calibrate_qa_mapper.py @@ -76,7 +76,7 @@ def test(self): def test_args(self): op = CalibrateQAMapper( api_model='qwen2.5-72b-instruct', - api_url= + api_endpoint= 'https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions', response_path='choices.0.message.content') self._run_op(op) diff --git a/tests/ops/mapper/test_calibrate_query_mapper.py b/tests/ops/mapper/test_calibrate_query_mapper.py index 8229c10ed..f95b6c5dc 100644 --- a/tests/ops/mapper/test_calibrate_query_mapper.py +++ b/tests/ops/mapper/test_calibrate_query_mapper.py @@ -69,8 +69,8 @@ def _run_op(self, api_model, response_path=None): def test(self): # before runing this test, set below environment variables: - # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions - # export DJ_API_KEY=your_key + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key self._run_op('qwen2.5-72b-instruct') diff --git a/tests/ops/mapper/test_calibrate_response_mapper.py b/tests/ops/mapper/test_calibrate_response_mapper.py index e092d4c48..4a9ddbe11 100644 --- a/tests/ops/mapper/test_calibrate_response_mapper.py +++ b/tests/ops/mapper/test_calibrate_response_mapper.py @@ -70,8 +70,8 @@ def _run_op(self, api_model, response_path=None): def test(self): # before runing this test, set below environment variables: - # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions - # export DJ_API_KEY=your_key + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key self._run_op('qwen2.5-72b-instruct') diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py index 96f186d29..f15b4ca3f 100644 --- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py +++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py @@ -21,9 +21,9 @@ def _run_op(self, api_model, response_path=None): query_attributes = ["语言风格", "角色性格"] op = ExtractEntityAttributeMapper( + api_model=api_model, query_entities=query_entities, - query_attributes=query_attributes, - api_model=api_model, + query_attributes=query_attributes, response_path=response_path) raw_text = """△笛飞声独自坐在莲花楼屋顶上。李莲花边走边悠闲地给马喂草。方多病则走在一侧,却总不时带着怀疑地盯向楼顶的笛飞声。 @@ -49,9 +49,14 @@ def _run_op(self, api_model, response_path=None): dataset = Dataset.from_list(samples) dataset = dataset.map(op.process, batch_size=1) for sample in dataset: - logger.info(f'{sample[Fields.main_entity]} {sample[Fields.attribute]}: {sample[Fields.attribute_description]}') - self.assertNotEqual(sample[Fields.attribute_description], '') - self.assertNotEqual(len(sample[Fields.attribute_support_text]), 0) + ents = sample[Fields.main_entities] + attrs = sample[Fields.attributes] + descs = sample[Fields.attribute_descriptions] + sups = sample[Fields.attribute_support_texts] + for ent, attr, desc, sup in zip(ents, attrs, descs, sups): + logger.info(f'{ent} {attr}: {desc}') + self.assertNotEqual(desc, '') + self.assertNotEqual(len(sup), 0) def test(self): # before runing this test, set below environment variables: diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py index 1652c8db2..aba40d73e 100644 --- a/tests/ops/mapper/test_extract_event_mapper.py +++ b/tests/ops/mapper/test_extract_event_mapper.py @@ -18,7 +18,8 @@ class ExtractEventMapperTest(DataJuicerTestCaseBase): def _run_op(self, api_model, response_path=None): op = ExtractEventMapper(api_model=api_model, - response_path=response_path) + response_path=response_path, + index_key='chunk_id') raw_text = """△芩婆走到中间,看着众人。 芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。 @@ -57,9 +58,11 @@ def _run_op(self, api_model, response_path=None): }] dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) + dataset = op.run(dataset) self.assertNotEqual(len(dataset), 0) for sample in dataset: + logger.info(f"chunk_id: {sample['chunk_id']}") + self.assertEqual(sample['chunk_id'], 0) logger.info(f"event: {sample[Fields.event_description]}") self.assertNotEqual(sample[Fields.event_description], '') logger.info(f"characters: {sample[Fields.relevant_characters]}") diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py index 635801155..2911a1002 100644 --- a/tests/ops/mapper/test_extract_nickname_mapper.py +++ b/tests/ops/mapper/test_extract_nickname_mapper.py @@ -49,8 +49,8 @@ def _run_op(self, api_model, response_path=None): def test(self): # before runing this test, set below environment variables: - # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions - # export DJ_API_KEY=your_key + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key self._run_op('qwen2.5-72b-instruct') diff --git a/tests/ops/mapper/test_extract_support_text_mapper.py b/tests/ops/mapper/test_extract_support_text_mapper.py new file mode 100644 index 000000000..0445d2526 --- /dev/null +++ b/tests/ops/mapper/test_extract_support_text_mapper.py @@ -0,0 +1,80 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.extract_support_text_mapper import ExtractSupportTextMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields +from data_juicer.utils.common_utils import nested_access + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class ExtractSupportTextMapperTest(DataJuicerTestCaseBase): + + + def _run_op(self, api_model): + + summary_key = 'data.event' + support_text_key = 'data.support_text' + op = ExtractSupportTextMapper(api_model=api_model, + summary_key=summary_key, + support_text_key=support_text_key) + + raw_text = """△芩婆走到中间,看着众人。 +芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。 +封磬震惊:二子?不是只有一个儿子吗? +芩婆:我和漆木山得知这个噩耗后,到处寻找李家那两个孩子的下落。只可惜等我们找他们时,李家长子李相显已经病死。 +李莲花似回忆起了什么:李相显...... +芩婆:我们只从乞丐堆里带回了年纪尚且未满四岁的李相夷,以及,(看向单孤刀)二个一直护着李相夷,与李相显年纪相仿的小乞丐...... +闪回/ +李相显将李且给他的玉佩塞给单孤刀,恳切托付:我没什么值钱的东西,这个玉佩是我唯一的家当了、送给你,我弟弟、相夷......求你照顾他一阵...... +△李相显还想再说什么已气绝而亡,小相夷唤着哥哥大哭,单孤刀愕然看着手里的玉佩有点不知所措。 +△话刚说完,哐当一声破庙门倒进来,几个其他少年乞丐进来。少年乞丐老大:这地儿不错,诶,你俩,出去! +△单孤刀把小相夷护在身后,抓住靠在墙边的木棍。单孤刀:这儿,是我,和我弟弟的。 +乞丐们要抢李相夷的馒头,小李相夷哭着死死护住自馒头不放。 +乞丐甲野蛮地抢:给我拿来! +小单孤刀:放开他! +△单孤刀用力撞向几个乞丐,救下小李相夷。乞丐甲:小子,活腻了! +△几个乞丐围攻小单孤刀,小单孤刀和众乞丐厮打到一起。突然其中一个乞丐掏出一把生锈的刀就朝单孤刀砍去、一个点燃火把棍戳他。单孤刀侧手一挡,火把棍在他手腕上烫出一道伤口,身后几根棍子打得他痛苦倒地! +/闪回结束 +△单孤刀拿着自己手里的玉佩看着,又看看自己手上的印记,不肯相信。单孤刀:胡说!全都是胡说!这些事我为何不知道?都是你在信口雌黄! +芩婆:那我问你,我们将你带回云隐山之前的事你又记得多少? +△单孤刀突然愣住,他意识到那之前的事自己竟都想不起来。 +芩婆:怎么?都想不起来了?(拽起单孤刀手腕,露出他的伤痕)你当日被你师父找到时,手腕上就受了伤,也正因为这处伤,高烧不退,醒来后便忘记了不少从前的事。 +△单孤刀呆住。 +芩婆:而相夷当年不过孩童,尚未到记事的年纪,很多事自然不知道。 +△李莲花得知真相,闭目叹息。 +△封磬震惊地看看单孤刀,又看看李莲花,终于想明白了一切,颓然、懊恼。 +封磬:自萱公主之子下落不明后,这近百年来我们整个家族都一直在不遗余力地寻找萱公主的子嗣后代,直到二十几年前终于让我寻得了线索,知道萱公主的曾孙被漆木山夫妇收为徒,但......我只知道萱公主之孙有一年约十岁的儿子,却不知......原来竟还有一幼子!我......我凭着南胤皇族的玉佩、孩子的年纪和他身上的印记来与主上相认,可没想到......这竟是一个错误!全错了! +△封磬神情复杂地看向李莲花,封磬:你,你才是我的主上...... +△封磬颓然地跪倒下来。 +△李莲花对眼前的一切有些意外、无措。 +笛飞声冷声:怪不得单孤刀的血对业火独毫无作用,李莲花的血才能毁掉这东西。 +△笛飞声不禁冷笑一下。 +""" + event = "李相显托付单孤刀。" + samples = [{ + 'text': raw_text, + 'data':{ + 'event': event + } + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + sample = dataset[0] + logger.info(f"support_text: \n{nested_access(sample, support_text_key)}") + + def test(self): + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + self._run_op('qwen2.5-72b-instruct') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_pair_preference_mapper.py b/tests/ops/mapper/test_pair_preference_mapper.py new file mode 100644 index 000000000..93cd4d877 --- /dev/null +++ b/tests/ops/mapper/test_pair_preference_mapper.py @@ -0,0 +1,57 @@ +import unittest + +from loguru import logger + +from data_juicer.ops.mapper.pair_preference_mapper import PairPreferenceMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + + +# Skip tests for this OP because the API call is not configured yet. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class PairPreferenceMapperTest(DataJuicerTestCaseBase): + + def _run_op(self, op, samples): + for sample in samples: + result = op.process(sample) + logger.info(f'Output results: {result}') + self.assertNotEqual(result['rejected_response'], '') + self.assertNotEqual(result['reason'], '') + + def test(self): + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key + + reference = '王八十娘:小远城王八十的娘亲,李莲花刚到小远城时被方多病偷掉钱袋找小乞丐问路时,刚好发现王八十娘被另一个小乞丐撞到便将她扶起,结识了王八十。\n朴二黄:灵山派管家,方多病小厮旺福的父亲。真实身份是金鸳盟的奔雷手辛雷,离开金鸳盟后,用假名朴二黄在灵山派当管家。因害怕王青山看穿他的身份,设计杀死了灵山派的王青山。被捕后识破了李莲花的真实身份,最后在攻击李莲花的时候被方多病情急之下杀死。' # noqa: E501 + samples = [{ + 'text': reference, + 'query': '李莲花,你认识方多病吗?', + 'response': '方多病啊,那可是我的好友。' + }] + op = PairPreferenceMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op, samples) + + def test_no_reference(self): + samples = [{'query': '李莲花,你认识方多病吗?', 'response': '方多病啊,那可是我的好友。'}] + system_prompt = ('修改问答对中的回答,在语言风格、事实性、人物身份、立场等任一方面与原回答相反。' + '必须按照以下标记格式输出,不要输出其他多余内容。\n' + '【回答】\n' + '生成的新回答\n' + '【原因】\n' + '生成该回答的原因') + input_template = ('以下是原始问答对:\n' + '【问题】\n' + '{query}\n' + '【回答】\n' + '{response}') + + op = PairPreferenceMapper(api_model='qwen2.5-72b-instruct', + system_prompt=system_prompt, + input_template=input_template) + self._run_op(op, samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_python_file_mapper.py b/tests/ops/mapper/test_python_file_mapper.py new file mode 100644 index 000000000..97d280481 --- /dev/null +++ b/tests/ops/mapper/test_python_file_mapper.py @@ -0,0 +1,108 @@ +import unittest +import tempfile + +from data_juicer.ops.mapper.python_file_mapper import PythonFileMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + +class TestPythonFileMapper(DataJuicerTestCaseBase): + + def test_function_execution(self): + """Test the correct execution of a loadable function.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + temp_file.write( + "def process_data(sample):\n" + " return {'result': sample['value'] + 10}\n" + ) + temp_file.seek(0) # Rewind the file so it can be read + mapper = PythonFileMapper(temp_file.name, "process_data") + result = mapper.process_single({'value': 5}) + self.assertEqual(result, {'result': 15}) + + def test_function_batched(self): + """Test for a funtion that processes a batch.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + temp_file.write( + "def process_data(samples):\n" + " return {'result': samples['value'] + [10]}\n" + ) + temp_file.seek(0) # Rewind the file so it can be read + mapper = PythonFileMapper(temp_file.name, "process_data", batched=True) + result = mapper.process_batched({'value': [5]}) + self.assertEqual(result, {'result': [5, 10]}) + + def test_function_with_import(self): + """Test for a function that contains an import statement.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + temp_file.write( + "import numpy as np\n" + "def process_data(sample):\n" + " return {'result': np.sum([sample['value'], 10])}\n" + ) + temp_file.seek(0) # Rewind the file so it can be read + mapper = PythonFileMapper(temp_file.name, "process_data") + result = mapper.process_single({'value': 5}) + self.assertEqual(result, {'result': 15}) + + def test_file_not_found(self): + """Test for a non-existent file.""" + with self.assertRaises(FileNotFoundError) as cm: + PythonFileMapper("non_existent.py", "process_data") + self.assertIn("does not exist", str(cm.exception)) + + def test_file_not_python_extension(self): + """Test for a file that exists but is not a .py file.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.txt', mode='w+') as temp_file: + temp_file.write("This is a text file.") + temp_file.seek(0) # Rewind the file so it can be read + with self.assertRaises(ValueError) as cm: + PythonFileMapper(temp_file.name, "some_function") + self.assertIn("is not a Python file", str(cm.exception)) + + def test_function_not_found(self): + """Test for function not existing in the provided file.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + temp_file.write( + "def existing_function(sample):\n" + " return sample\n" + ) + temp_file.seek(0) # Rewind the file so it can be read + with self.assertRaises(ValueError) as cm: + PythonFileMapper(temp_file.name, "non_existing_function") + self.assertIn("not found", str(cm.exception)) + + def test_function_not_callable(self): + """Test for trying to load a non-callable function.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + temp_file.write("x = 42") + temp_file.seek(0) # Rewind the file so it can be read + with self.assertRaises(ValueError) as cm: + PythonFileMapper(temp_file.name, "x") + self.assertIn("not callable", str(cm.exception)) + + def test_function_mutiple_arguments(self): + """Test for function that requires more than one argument.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + temp_file.write( + "def multi_arg_function(arg1, arg2):\n" + " return arg1 + arg2\n" + ) + temp_file.seek(0) # Rewind the file so it can be read + with self.assertRaises(ValueError) as cm: + PythonFileMapper(temp_file.name, "multi_arg_function") + self.assertIn("must take exactly one argument", str(cm.exception)) + + def test_invalid_return_type(self): + """Test for a function returning a non-dictionary.""" + with tempfile.NamedTemporaryFile(delete=True, suffix='.py', mode='w+') as temp_file: + temp_file.write( + "def invalid_function(sample):\n" + " return sample['value'] + 5\n" + ) + temp_file.seek(0) # Rewind the file so it can be read + mapper = PythonFileMapper(temp_file.name, "invalid_function") + with self.assertRaises(ValueError) as cm: + mapper.process_single({'value': 5}) + self.assertIn("Function must return a dictionary, got int instead.", str(cm.exception)) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/mapper/test_python_lambda_mapper.py b/tests/ops/mapper/test_python_lambda_mapper.py new file mode 100644 index 000000000..97fac4794 --- /dev/null +++ b/tests/ops/mapper/test_python_lambda_mapper.py @@ -0,0 +1,68 @@ +import unittest + +from data_juicer.ops.mapper.python_lambda_mapper import PythonLambdaMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + +class PythonLambdaMapperMapper(DataJuicerTestCaseBase): + + def test_lambda_function_batched(self): + mapper = PythonLambdaMapper("lambda d: {'value': d['value'] + [6]}", batched=True) # Append '6' to value + result = mapper.process_batched({'value': [5]}) + self.assertEqual(result, {'value': [5, 6]}) + + def test_lambda_modifies_values(self): + mapper = PythonLambdaMapper("lambda d: {'value': d['value'] + 1}") # '+1' to 'value' + result = mapper.process_single({'value': 5}) + self.assertEqual(result, {'value': 6}) + + def test_lambda_combines_values(self): + mapper = PythonLambdaMapper("lambda d: {'combined': d['a'] + d['b']}") + result = mapper.process_single({'a': 3, 'b': 7}) + self.assertEqual(result, {'combined': 10}) + + def test_lambda_swaps_values(self): + mapper = PythonLambdaMapper("lambda d: {'a': d['b'], 'b': d['a']}") + result = mapper.process_single({'a': 1, 'b': 2}) + self.assertEqual(result, {'a': 2, 'b': 1}) + + def test_lambda_result_is_not_dict(self): + mapper = PythonLambdaMapper("lambda d: d['value'] + 1") # This returns an int + with self.assertRaises(ValueError) as cm: + mapper.process_single({'value': 10}) + self.assertIn("Lambda function must return a dictionary, got int instead.", str(cm.exception)) + + def test_invalid_syntax(self): + with self.assertRaises(ValueError) as cm: + PythonLambdaMapper("invalid lambda") # Invalid syntax + self.assertIn("Invalid lambda function", str(cm.exception)) + + def test_invalid_expression(self): + with self.assertRaises(ValueError) as cm: + PythonLambdaMapper("3 + 5") # Not a lambda + self.assertIn("Input string must be a valid lambda function.", str(cm.exception)) + + def test_lambda_with_multiple_arguments(self): + with self.assertRaises(ValueError) as cm: + PythonLambdaMapper("lambda x, y: {'sum': x + y}") # Creating a lambda accepts two arguments + self.assertIn("Lambda function must have exactly one argument.", str(cm.exception)) + + def test_lambda_returning_unexpected_structure(self): + mapper = PythonLambdaMapper("lambda d: ({'value': d['value']}, {'extra': d['extra']})") # Invalid return type; too many dictionaries + with self.assertRaises(ValueError) as cm: + mapper.process_single({'value': 5, 'extra': 10}) + self.assertIn("Lambda function must return a dictionary, got tuple instead.", str(cm.exception)) + + def test_lambda_modifies_in_place_and_returns(self): + mapper = PythonLambdaMapper("lambda d: d.update({'new_key': 'added_value'}) or d") # Returns the modified dictionary + sample_dict = {'value': 3} + result = mapper.process_single(sample_dict) + self.assertEqual(result, {'value': 3, 'new_key': 'added_value'}) # Ensure the update worked + + def test_lambda_function_with_no_operation(self): + mapper = PythonLambdaMapper("lambda d: d") # Simply returns the input dictionary + sample_dict = {'key': 'value'} + result = mapper.process_single(sample_dict) + self.assertEqual(result, {'key': 'value'}) # Unchanged + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/mapper/test_relation_identity_mapper.py b/tests/ops/mapper/test_relation_identity_mapper.py new file mode 100644 index 000000000..d730cb79f --- /dev/null +++ b/tests/ops/mapper/test_relation_identity_mapper.py @@ -0,0 +1,58 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.relation_identity_mapper import RelationIdentityMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class RelationIdentityMapperTest(DataJuicerTestCaseBase): + + + def _run_op(self, api_model, response_path=None): + + op = RelationIdentityMapper(api_model=api_model, + source_entity="李莲花", + target_entity="方多病", + response_path=response_path) + + raw_text = """李莲花原名李相夷,十五岁战胜西域天魔,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。 +在与金鸳盟盟主笛飞声的对决中,李相夷中毒重伤,沉入大海,十年后在莲花楼醒来,过起了市井生活。他帮助肉铺掌柜解决家庭矛盾,表现出敏锐的洞察力。 +李莲花与方多病合作,解决了灵山派掌门王青山的假死案,揭露了朴管家的罪行。 +随后,他与方多病和笛飞声一起调查了玉秋霜的死亡案,最终揭露了玉红烛的阴谋。在朴锄山,李莲花和方多病调查了七具无头尸事件,发现男童的真实身份是笛飞声。 +李莲花利用飞猿爪偷走男童手中的观音垂泪,导致笛飞声恢复内力,但李莲花巧妙逃脱。李莲花与方多病继续合作,调查了少师剑被盗案,揭露了静仁和尚的阴谋。 +在采莲庄,他解决了新娘溺水案,找到了狮魂的线索,并在南门园圃挖出单孤刀的药棺。在玉楼春的案件中,李莲花和方多病揭露了玉楼春的阴谋,救出了被拐的清儿。 +在石寿村,他们发现了柔肠玉酿的秘密,并救出了被控制的武林高手。李莲花与方多病在白水园设下机关,救出方多病的母亲何晓惠,并最终在云隐山找到了治疗碧茶之毒的方法。 +在天机山庄,他揭露了单孤刀的野心,救出了被控制的大臣。在皇宫,李莲花与方多病揭露了魔僧和单孤刀的阴谋,成功解救了皇帝。 +最终,李莲花在东海之滨与笛飞声的决斗中未出现,留下一封信,表示自己已无法赴约。 +一年后,方多病在东海畔的柯厝村找到了李莲花,此时的李莲花双目失明,右手残废,但心态平和,过着简单的生活。 +方多病 (称呼:方小宝、方大少爷)百川院刑探,单孤刀之子,李相夷的徒弟。方多病通过百川院的考核,成为刑探,并在百川院内展示了自己是李相夷的弟子,获得暂时的录用。 +他接到任务前往嘉州调查金鸳盟的余孽,期间与李莲花相识并合作破案。方多病在调查过程中逐渐了解到自己的身世,发现自己的生父是单孤刀。 +他与李莲花、笛飞声等人多次合作,共同对抗金鸳盟和单孤刀的阴谋。方多病在一系列案件中展现了出色的推理能力和武艺,逐渐成长为一名优秀的刑探。 +最终,方多病在天机山庄和皇宫的斗争中发挥了关键作用,帮助李莲花等人挫败了单孤刀的野心。在李莲花中毒后,方多病决心为他寻找解毒之法,展现了深厚的友情。 +""" + samples = [{ + 'text': raw_text, + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + for data in dataset: + for k in data: + logger.info(f"{k}: {data[k]}") + + def test(self): + # before runing this test, set below environment variables: + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key + self._run_op('qwen2.5-72b-instruct') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_text_chunk_mapper.py b/tests/ops/mapper/test_text_chunk_mapper.py index 8004d9ede..0c0a70db3 100644 --- a/tests/ops/mapper/test_text_chunk_mapper.py +++ b/tests/ops/mapper/test_text_chunk_mapper.py @@ -2,9 +2,10 @@ from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.text_chunk_mapper import TextChunkMapper -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +@SKIPPED_TESTS.register_module() class TextChunkMapperTest(DataJuicerTestCaseBase): def _run_helper(self, op, samples, target): diff --git a/tests/ops/mapper/test_video_extract_frames_mapper.py b/tests/ops/mapper/test_video_extract_frames_mapper.py new file mode 100644 index 000000000..7ae2dd29f --- /dev/null +++ b/tests/ops/mapper/test_video_extract_frames_mapper.py @@ -0,0 +1,242 @@ +import os +import os.path as osp +import re +import copy +import unittest +import json +import tempfile +import shutil +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.video_extract_frames_mapper import \ + VideoExtractFramesMapper +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoExtractFramesMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') + vid2_path = os.path.join(data_path, 'video2.mp4') + vid3_path = os.path.join(data_path, 'video3.mp4') + tmp_dir = tempfile.TemporaryDirectory().name + + def tearDown(self): + super().tearDown() + if osp.exists(self.tmp_dir): + shutil.rmtree(self.tmp_dir) + + default_frame_dir_prefix = self._get_default_frame_dir_prefix() + if osp.exists(default_frame_dir_prefix): + shutil.rmtree(osp.dirname(default_frame_dir_prefix)) + + def _get_default_frame_dir_prefix(self): + from data_juicer.ops.mapper.video_extract_frames_mapper import OP_NAME + default_frame_dir_prefix = osp.abspath(osp.join(self.data_path, + f'{Fields.multimodal_data_output_dir}/{OP_NAME}/')) + return default_frame_dir_prefix + + def _get_frames_list(self, filepath, frame_dir, frame_num): + frames_dir = osp.join(frame_dir, osp.splitext(osp.basename(filepath))[0]) + frames_list = [osp.join(frames_dir, f'frame_{i}.jpg') for i in range(frame_num)] + return frames_list + + def _get_frames_dir(self, filepath, frame_dir): + frames_dir = osp.join(frame_dir, osp.splitext(osp.basename(filepath))[0]) + return frames_dir + + def _sort_files(self, file_list): + return sorted(file_list, key=lambda x: int(re.search(r'(\d+)', x).group())) + + def test_duration(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + + frame_num = 2 + frame_dir=os.path.join(self.tmp_dir, 'test1') + vid1_frame_dir = self._get_frames_dir(self.vid1_path, frame_dir) + vid2_frame_dir = self._get_frames_dir(self.vid2_path, frame_dir) + vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir) + + tgt_list = copy.deepcopy(ds_list) + tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}) + tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}) + tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}) + + op = VideoExtractFramesMapper( + frame_sampling_method='uniform', + frame_num=frame_num, + duration=0, + frame_dir=frame_dir) + + dataset = Dataset.from_list(ds_list) + dataset = dataset.map(op.process, batch_size=2, num_proc=1) + res_list = dataset.to_list() + self.assertEqual(res_list, tgt_list) + self.assertListEqual( + self._sort_files(os.listdir(vid1_frame_dir)), + [f'frame_{i}.jpg' for i in range(frame_num)]) + self.assertListEqual( + self._sort_files(os.listdir(vid2_frame_dir)), + [f'frame_{i}.jpg' for i in range(frame_num)]) + self.assertListEqual( + self._sort_files(os.listdir(vid3_frame_dir)), + [f'frame_{i}.jpg' for i in range(frame_num)]) + + def test_uniform_sampling(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + frame_num = 3 + frame_dir=os.path.join(self.tmp_dir, 'test1') + vid1_frame_dir = self._get_frames_dir(self.vid1_path, frame_dir) + vid2_frame_dir = self._get_frames_dir(self.vid2_path, frame_dir) + vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir) + + tgt_list = copy.deepcopy(ds_list) + tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}) + tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}) + tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}) + + op = VideoExtractFramesMapper( + frame_sampling_method='uniform', + frame_num=frame_num, + duration=10, + frame_dir=frame_dir) + + dataset = Dataset.from_list(ds_list) + dataset = dataset.map(op.process, batch_size=2, num_proc=1) + res_list = dataset.to_list() + self.assertEqual(res_list, tgt_list) + self.assertListEqual( + self._sort_files(os.listdir(vid1_frame_dir)), + [f'frame_{i}.jpg' for i in range(3)]) + self.assertListEqual( + self._sort_files(os.listdir(vid2_frame_dir)), + [f'frame_{i}.jpg' for i in range(6)]) + self.assertListEqual( + self._sort_files(os.listdir(vid3_frame_dir)), + [f'frame_{i}.jpg' for i in range(12)]) + + def test_all_keyframes_sampling(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}' + \ + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + frame_dir=os.path.join(self.tmp_dir, 'test2') + vid1_frame_dir = self._get_frames_dir(self.vid1_path, frame_dir) + vid2_frame_dir = self._get_frames_dir(self.vid2_path, frame_dir) + vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir) + + tgt_list = copy.deepcopy(ds_list) + tgt_list[0].update({Fields.video_frames: + json.dumps({self.vid1_path: vid1_frame_dir})}) + tgt_list[1].update({Fields.video_frames: json.dumps({ + self.vid2_path: vid2_frame_dir, + self.vid3_path: vid3_frame_dir + })}) + tgt_list[2].update({Fields.video_frames: + json.dumps({self.vid3_path: vid3_frame_dir})}) + + op = VideoExtractFramesMapper( + frame_sampling_method='all_keyframes', + frame_dir=frame_dir, + duration=5) + + dataset = Dataset.from_list(ds_list) + dataset = dataset.map(op.process, batch_size=2, num_proc=2) + res_list = dataset.to_list() + self.assertEqual(res_list, tgt_list) + self.assertListEqual( + self._sort_files(os.listdir(vid1_frame_dir)), + [f'frame_{i}.jpg' for i in range(4)]) + self.assertListEqual( + self._sort_files(os.listdir(vid2_frame_dir)), + [f'frame_{i}.jpg' for i in range(5)]) + self.assertListEqual( + self._sort_files(os.listdir(vid3_frame_dir)), + [f'frame_{i}.jpg' for i in range(13)]) + + def test_default_frame_dir(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + + frame_num = 2 + op = VideoExtractFramesMapper( + frame_sampling_method='uniform', + frame_num=frame_num, + duration=5, + ) + + vid1_frame_dir = op._get_default_frame_dir(self.vid1_path) + vid2_frame_dir = op._get_default_frame_dir(self.vid2_path) + vid3_frame_dir = op._get_default_frame_dir(self.vid3_path) + + tgt_list = copy.deepcopy(ds_list) + tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}) + tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}) + tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}) + + dataset = Dataset.from_list(ds_list) + dataset = dataset.map(op.process, batch_size=2, num_proc=1) + res_list = dataset.to_list() + + frame_dir_prefix = self._get_default_frame_dir_prefix() + self.assertIn(frame_dir_prefix, osp.abspath(vid1_frame_dir)) + self.assertIn(frame_dir_prefix, osp.abspath(vid2_frame_dir)) + self.assertIn(frame_dir_prefix, osp.abspath(vid3_frame_dir)) + + self.assertEqual(res_list, tgt_list) + + self.assertListEqual( + self._sort_files(os.listdir(vid1_frame_dir)), + [f'frame_{i}.jpg' for i in range(4)]) + self.assertListEqual( + self._sort_files(os.listdir(vid2_frame_dir)), + [f'frame_{i}.jpg' for i in range(8)]) + self.assertListEqual( + self._sort_files(os.listdir(vid3_frame_dir)), + [f'frame_{i}.jpg' for i in range(18)]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/test_op_fusion.py b/tests/ops/test_op_fusion.py index 13d633134..04fc2a50e 100644 --- a/tests/ops/test_op_fusion.py +++ b/tests/ops/test_op_fusion.py @@ -1,13 +1,15 @@ import unittest from data_juicer.ops.load import load_ops +from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase class OpFusionTest(DataJuicerTestCaseBase): - def _run_op_fusion(self, original_process_list, target_process_list): - ops = load_ops(original_process_list, op_fusion=True) + def _run_op_fusion(self, original_process_list, target_process_list, probe_res=None): + ops = load_ops(original_process_list) + ops = fuse_operators(ops, probe_res) new_process_list = [op._op_cfg for op in ops] self.assertEqual(new_process_list, target_process_list) @@ -1014,6 +1016,950 @@ def test_different_intermediate_vars(self): ] self._run_op_fusion(original_process, target_process) + def test_regular_config_with_probe_res(self): + probed_speeds = [ + # single filter + {'speed': 100}, + + # mappers + {'speed': 2}, + {'speed': 1}, + {'speed': 4}, + {'speed': 5}, + {'speed': 3}, + + # filter groups + # fused OPs: ~2.56 + # single OP 1: 1 (slowest) + # single OP 2: 3 (fastest) + {'speed': 15}, # fusible + {'speed': 1}, + {'speed': 14}, # fusible + {'speed': 3}, + {'speed': 13}, # fusible + {'speed': 12}, # fusible + {'speed': 11}, # fusible + + # deduplicator + {'speed': 0.1}, + ] + + original_process = [{ + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, { + 'remove_words_with_incorrect_substrings_mapper': { + 'lang': 'en', + 'substrings': None, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, + 'text_key': 'text' + } + }, { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + }, { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } + }] + target_process = [ + { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, + { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, + { + 'remove_words_with_incorrect_substrings_mapper': { + 'lang': 'en', + 'substrings': None, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, + { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, + 'text_key': 'text' + } + }, + { + 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': # noqa: E501 + [ + { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + } + ] + }, + { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, + { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } + } + ] + self._run_op_fusion(original_process, target_process, probed_speeds) + + def test_not_enough_fusible_ops_to_fuse_with_probe_res(self): + # still apply reordering: + # - ordinary ops + # - ops with InterVars.lines + # - ops with InterVars.words + probe_res_list = [ + {'speed': 3}, + {'speed': 1}, + {'speed': 4}, + {'speed': 2}, + ] + + original_process = [{ + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, { + 'average_line_length_filter': { + 'min_len': 10, + 'text_key': 'text' + } + }] + target_process = [{ + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, { + 'average_line_length_filter': { + 'min_len': 10, + 'text_key': 'text' + } + }, { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }] + self._run_op_fusion(original_process, target_process, probe_res_list) + + def test_multiple_groups_with_probe_res(self): + probe_res_list = [ + # group 1 + # fused filter will be put before the single filter + {'speed': 10}, + {'speed': 10}, + {'speed': 1}, + + # mappers + {'speed': 4}, + {'speed': 2}, + {'speed': 5}, + {'speed': 3}, + {'speed': 1}, + + # group 2 + # fused filter will be put after those two single filters + {'speed': 1}, # fusible + {'speed': 8}, + {'speed': 1}, # fusible + {'speed': 10}, + {'speed': 1}, # fusible + + # deduplicator + {'speed': 1}, + ] + + original_process = [{ + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, { + 'remove_words_with_incorrect_substrings_mapper': { + 'lang': 'en', + 'substrings': None, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, + 'text_key': 'text' + } + }, { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + }, { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } + }] + target_process = [ + { + 'OpFusion:(stopwords_filter,flagged_words_filter)': [{ + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }] + }, + { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, + { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, + { + 'remove_words_with_incorrect_substrings_mapper': { + 'lang': 'en', + 'substrings': None, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, + { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, + 'text_key': 'text' + } + }, + { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, + { + 'OpFusion:(words_num_filter,word_repetition_filter,perplexity_filter)': # noqa: E501 + [ + { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + } + ] + }, + { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } + } + ] + self._run_op_fusion(original_process, target_process, probe_res_list) + + def test_only_fusible_ops_with_probe_res(self): + probe_res_list = [ + {'speed': 1}, + {'speed': 1}, + {'speed': 1}, + {'speed': 1}, + {'speed': 1}, + ] + + original_process = [{ + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + }] + target_process = [{ + 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': # noqa: E501 + [ + { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + } + ] + }] + self._run_op_fusion(original_process, target_process, probe_res_list) + + def test_different_intermediate_vars_with_probe_res(self): + probe_res_list = [ + # single filter + {'speed': 1}, + + # mappers + {'speed': 5}, + {'speed': 3}, + {'speed': 1}, + {'speed': 2}, + {'speed': 4}, + + # filter group + # single 1: 1 (2) + # single 2: 0.5 (3) + # group 1: 0.04 (4) + # group 2: 1.5 (1) + {'speed': 0.1}, # group 1 + {'speed': 1}, + {'speed': 3}, # group 2 + {'speed': 0.2}, # group 1 + {'speed': 0.5}, + {'speed': 0.3}, # group 1 + {'speed': 0.4}, # group 1 + {'speed': 3}, # group 2 + {'speed': 0.5}, # group 1 + + # deduplicator + {'speed': 1}, + ] + + original_process = [{ + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, { + 'remove_words_with_incorrect_substrings_mapper': { + 'lang': 'en', + 'substrings': None, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, { + 'average_line_length_filter': { + 'min_len': 10, + 'text_key': 'text' + } + }, { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, + 'text_key': 'text' + } + }, { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'maximum_line_length_filter': { + 'min_len': 20, + 'text_key': 'text' + } + }, { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + }, { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } + }] + target_process = [ + { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, + { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, + { + 'remove_words_with_incorrect_substrings_mapper': { + 'lang': 'en', + 'substrings': None, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, + { + 'OpFusion:(average_line_length_filter,maximum_line_length_filter)': # noqa: E501 + [ + { + 'average_line_length_filter': { + 'min_len': 10, + 'text_key': 'text', + } + }, + { + 'maximum_line_length_filter': { + 'min_len': 20, + 'text_key': 'text', + } + } + ] + }, + { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, + { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, + 'text_key': 'text' + } + }, + { + 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': # noqa: E501 + [ + { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + } + ] + }, + { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } + } + ] + self._run_op_fusion(original_process, target_process, probe_res_list) + if __name__ == '__main__': unittest.main() diff --git a/tools/dj_install.py b/tools/dj_install.py new file mode 100644 index 000000000..54b0b3dd3 --- /dev/null +++ b/tools/dj_install.py @@ -0,0 +1,65 @@ +import os +import subprocess +import sys +import tempfile + +from loguru import logger + +from data_juicer.config import init_configs +from data_juicer.utils.auto_install_mapping import OPS_TO_PKG + +require_version_paths = ['./environments/science_requires.txt'] + + +def main(): + cfg = init_configs() + + # get the ops in the recipe + op_names = [list(op.keys())[0] for op in cfg.process] + recipe_reqs = [] + for op_name in op_names: + recipe_reqs.extend(OPS_TO_PKG[op_name]) + recipe_reqs = list(set(recipe_reqs)) + + # get the package version limit of Data-Juicer + version_map, reqs = {}, [] + for path in require_version_paths: + if not os.path.exists(path): + logger.warning(f'target file does not exist: {path}') + else: + with open(path, 'r', encoding='utf-8') as fin: + reqs += [x.strip() for x in fin.read().splitlines()] + for req in reqs: + clean_req = req.replace('<', + ' ').replace('>', + ' ').replace('=', + ' ').split(' ')[0] + version_map[clean_req] = req + + # generate require file for the recipe + with tempfile.NamedTemporaryFile(delete=False, mode='w') as temp_file: + temp_file_path = temp_file.name + for req in recipe_reqs: + if req in version_map: + temp_file.write(version_map[req] + '\n') + else: + temp_file.write(req + '\n') + + # install by calling 'pip install -r ...' + try: + subprocess.check_call( + [sys.executable, '-m', 'pip', 'install', '-r', temp_file_path]) + logger.info('Requirements were installed successfully.') + except subprocess.CalledProcessError as e: + logger.info( + f'An error occurred while installing the requirements: {e}') + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + sys.exit(1) + finally: + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + + +if __name__ == '__main__': + main()