diff --git a/tools/distributed_deduplication/README.md b/tools/distributed_deduplication/README.md
index 012cde561..23464af32 100644
--- a/tools/distributed_deduplication/README.md
+++ b/tools/distributed_deduplication/README.md
@@ -26,7 +26,6 @@ python spark_dedup.py \
     <dataset_path> \
     <result_path> \
     [--tokenizer <tokenizer>] \
-    [--threshold <threshold>] \
     [--num_features <num_features>] \
     [--num_hashtables <num_hashtables>] \
     [--text_key <text_key>] \
@@ -39,7 +38,6 @@ python spark_dedup.py --help
 - `dataset_path`: the input dataset path. The suffix of the path should be one of the `[json, jsonl, parquet]`.
 - `result_path`: the path to store the dataset with prediction results. The suffix of the path should be one of the `[json, jsonl, parquet]`.
 - `tokenizer`: (Optional. Default: None) the tokenizer to tokenize texts to be classified. If it's None, the [standard Tokenizer](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer) of PySpark will be used. Besides, you can use one of the tokenizers we provide `[zh.sp.model, code.sp.model]`. Or you can set it to a path to your own [sentencepiece](https://github.com/google/sentencepiece) model.
-- `threshold`: (Optional. Default: 0.7) If the Jaccard similarity between two documents exceeds a predetermined threshold, they are considered duplicates. The accuracy of deduplication depends on the similarity threshold set. The lower the threshold, the more duplicates can be identified, but this may also increase the risk of false positives. You need to adjust the threshold based on your requirements for deduplication accuracy.
 - `num_features`: the number of features that HashingTF generates. Default with 1047576 as mentioned in megatron-turing-nlg paper.
 - `num_hashtables`: (Optional. Default: 10) the number of hashes used in MinHashLSH. Default with 10 hashes as mentioned in the GPT3 paper.
 - `text_key`: (Optional. Default: "text") the field name to store texts to be classified in the input dataset.
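For quick orientation, the snippet below is a hypothetical programmatic equivalent of the CLI usage documented above. It is only a sketch: it assumes the repository root is on `PYTHONPATH`, the file names are made up, and it uses just the parameters that remain after this patch removes `--threshold`.

```python
# Hypothetical usage sketch; the paths are illustrative, not part of the repo.
from tools.distributed_deduplication.spark_dedup import dedup_dataset

dedup_dataset(
    dataset_path='demo-input.jsonl',   # suffix must be json / jsonl / parquet
    result_path='demo-output.jsonl',   # where the deduplicated dataset is written
    tokenizer=None,                    # None -> PySpark's standard Tokenizer
    num_features=1047576,              # HashingTF feature dimension (default)
    num_hashtables=10,                 # number of MinHashLSH hash tables
    text_key='text')                   # field holding the texts to deduplicate
```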
diff --git a/tools/distributed_deduplication/README_ZH.md b/tools/distributed_deduplication/README_ZH.md
index 1af9469ab..6a74006e7 100644
--- a/tools/distributed_deduplication/README_ZH.md
+++ b/tools/distributed_deduplication/README_ZH.md
@@ -17,7 +17,6 @@ python spark_dedup.py \
     <dataset_path> \
     <result_path> \
     [--tokenizer <tokenizer>] \
-    [--threshold <threshold>] \
     [--num_features <num_features>] \
     [--num_hashtables <num_hashtables>] \
     [--text_key <text_key>] \
@@ -30,7 +29,6 @@ python spark_dedup.py --help
 - `dataset_path`:输入数据集路径。路径的后缀应该是`[json, jsonl, parquet]`中的一个。
 - `result_path`:存储带有预测结果数据集的路径。路径的后缀应该是`[json, jsonl, parquet]`中的一个。
 - `tokenizer`:(可选。默认值:None)用于对将要分类的文本进行分词的分词器。如果为None,将使用PySpark的[标准分词器](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer)。此外,你可以使用我们提供的分词器`[zh.sp.model, code.sp.model]`中的一个,或者你可以将其设置为你自己的[sentencepiece](https://github.com/google/sentencepiece)模型的路径。
-- `threshold`:(可选。默认值:0.7)如果两个文档之间的Jaccard相似度超出预设定的阈值,则它们会被认为是重复的。去重的准确度取决于相似度阈值的设定。阈值越低,能识别的重复项就越多,但这可能增加误报的风险。你需要根据你对去重准确度的需求来调整阈值。
 - `num_features`:HashingTF生成的特征数量。默认值为1047576,如megatron-turing-nlg论文中所述。
 - `num_hashtables`:(可选。默认值:10)MinHashLSH中使用的哈希数量。默认使用10个哈希,如GPT-3论文中所述。
 - `text_key`:(可选。默认值:"text")输入数据集中用于存储待分类文本的字段名称。
diff --git a/tools/distributed_deduplication/dedup_utils.py b/tools/distributed_deduplication/dedup_utils.py
index b04c3631a..25c46d4ed 100644
--- a/tools/distributed_deduplication/dedup_utils.py
+++ b/tools/distributed_deduplication/dedup_utils.py
@@ -2,7 +2,7 @@
 # https://github.com/bigcode-project/bigcode-dataset/blob/main/near_deduplication/minhash_deduplication_spark.py
 # --------------------------------------------------------
 
-from typing import Union
+from typing import List, Tuple, Union
 
 from loguru import logger
 from pyspark import SparkConf
@@ -34,6 +34,29 @@ def init_spark(master_url: Union[str, None] = None,
     return spark
 
 
+def generate_edges(nodes: List[int]) -> List[Tuple[int, int]]:
+    """
+    Generate edges from a cluster. Instead of generating N^2 edges,
+    we only need all nodes align to a single node,
+    since we will be running connected components on the edges later.
+
+    Parameters
+    ----------
+    nodes : List[int]
+        The list of nodes in the cluster.
+
+    Returns
+    -------
+    List[Tuple[int, int]]
+        The list of edges.
+    """
+    if len(nodes) <= 1:
+        return []
+
+    min_node = min(nodes)
+    return [(n, min_node) for n in nodes if n != min_node]
+
+
 # Connected Components in MapReduce and Beyond
 def large_star_map(edge):
     return [(edge[0], edge[1]), (edge[1], edge[0])]
diff --git a/tools/distributed_deduplication/spark_dedup.py b/tools/distributed_deduplication/spark_dedup.py
index 15119123f..871f1811b 100644
--- a/tools/distributed_deduplication/spark_dedup.py
+++ b/tools/distributed_deduplication/spark_dedup.py
@@ -6,10 +6,10 @@
 from loguru import logger
 from pyspark.ml.feature import HashingTF, MinHashLSH, Tokenizer
 from pyspark.sql import functions as F
-from pyspark.sql.functions import col
-from pyspark.sql.functions import min as mincol
+from pyspark.sql.functions import posexplode
 
 from tools.distributed_deduplication.dedup_utils import (find_components,
+                                                         generate_edges,
                                                          init_spark)
 from tools.quality_classifier.qc_utils import (export_result, load_dataset,
                                                tokenize_dataset)
@@ -19,7 +19,6 @@
 def dedup_dataset(dataset_path: str,
                   result_path: str,
                   tokenizer: Union[str, None] = None,
-                  threshold: float = 0.7,
                   num_features: int = 1047576,
                   num_hashtables: int = 10,
                   text_key: str = 'text',
@@ -33,13 +32,6 @@ def dedup_dataset(dataset_path: str,
         default, which means using the standard Tokenizer of PySpark.
         You can use one of ["zh.sp.model", "code.sp.model"] we provided,
         or you can set it to the path to your own sentencepiece model.
-    :param threshold: if the Jaccard similarity between two documents
-        exceeds a predetermined threshold, they are considered duplicates.
-        The accuracy of deduplication depends on the similarity threshold set.
-        The lower the threshold, the more duplicates can be identified,
-        but this may also increase the risk of false positives.
-        You need to adjust the threshold based on your requirements for
-        deduplication accuracy.
     :param num_features: the number of features that HashingTF generates.
         Default with 1047576 as mentioned in megatron-turing-nlg paper.
     :param num_hashtables: the number of hashes used in MinHashLSH.
@@ -73,16 +65,13 @@ def dedup_dataset(dataset_path: str,
 
     ds = model.transform(ds)
 
-    self_join = model.approxSimilarityJoin(
-        ds, ds, threshold=threshold,
-        distCol='JaccardDistance').filter('datasetA.id > datasetB.id').select(
-            col('datasetA.id').alias('idA'),
-            col('datasetB.id').alias('idB'), col('JaccardDistance'))
+    ds = ds.select('id', posexplode('hashes').alias('band_idx', 'hash_vector'))
 
-    self_dup_edge = self_join.groupBy('idA').agg(
-        mincol(col('idB')).alias('min_idB'))
+    record = ds.rdd.map(lambda x:
+                        (x['band_idx'], int(x['hash_vector'][0]), x['id']))
 
-    edges = (self_dup_edge.rdd.map(lambda row: (row.idA, row.min_idB)))
+    edges = (record.groupBy(lambda x: (x[0], x[1])).flatMap(
+        lambda x: generate_edges([i[2] for i in x[1]])).distinct().cache())
 
     results = find_components(edges)
     if len(results) == 0:
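The new code path replaces the threshold-based `approxSimilarityJoin` with LSH banding: each entry of the `hashes` column produced by MinHashLSH carries a single hash value per hash table (hence `hash_vector[0]`), `posexplode` turns every (document, band) pair into its own row, documents that collide in the same (band index, hash value) bucket are linked to the smallest id in that bucket via `generate_edges`, and connected components then mark every non-representative document as a duplicate. The snippet below is a minimal, self-contained sketch of that logic in plain Python; the toy signatures and the union-find are illustrative stand-ins for the real HashingTF + MinHashLSH pipeline and for the large-star/small-star `find_components` in `dedup_utils.py`.

```python
# Plain-Python sketch of the banding + connected-components deduplication.
# The signatures and the union-find are stand-ins for the Spark-side steps.
from collections import defaultdict
from typing import Dict, List, Tuple


def generate_edges(nodes: List[int]) -> List[Tuple[int, int]]:
    # Same star-shaped edge generation as dedup_utils.generate_edges.
    if len(nodes) <= 1:
        return []
    min_node = min(nodes)
    return [(n, min_node) for n in nodes if n != min_node]


# doc id -> one MinHash value per band (i.e. per hash table in MinHashLSH).
signatures: Dict[int, List[int]] = {
    0: [11, 42, 7],
    1: [11, 99, 8],   # shares band 0 with doc 0 -> duplicate candidate
    2: [23, 42, 9],   # shares band 1 with doc 0 -> duplicate candidate
    3: [55, 66, 77],  # collides with nobody -> kept
}

# Bucket documents by (band index, hash value), mirroring the groupBy on
# (band_idx, hash_vector[0]) in spark_dedup.py.
buckets = defaultdict(list)
for doc_id, bands in signatures.items():
    for band_idx, value in enumerate(bands):
        buckets[(band_idx, value)].append(doc_id)

edges = set()
for bucket in buckets.values():
    edges.update(generate_edges(bucket))

# Minimal union-find as a stand-in for find_components().
parent = {doc_id: doc_id for doc_id in signatures}


def find(x: int) -> int:
    while parent[x] != x:
        parent[x] = parent[parent[x]]  # path halving
        x = parent[x]
    return x


for a, b in edges:
    parent[find(a)] = find(b)

# Every document whose component root is not itself is a duplicate to drop.
duplicates = sorted(doc for doc in signatures if find(doc) != doc)
print(duplicates)  # [1, 2]
```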