-
Notifications
You must be signed in to change notification settings - Fork 192
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Document deduplication tools using pyspark. (#290)
* init spark dedup * add doc strings and Readme.md * remove commented debug line * fix typo * fix comments * add readme_zh, add more config * fix typos * fix typo * fix for efficiency --------- Tested on larger scale
- Loading branch information
1 parent
ab5f0f6
commit c1a8aa8
Showing
5 changed files
with
276 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Distributed Fuzzy Deduplication Tools | ||
|
||
Help you reproduce and apply fuzzy deduplication to your web datasets similar to GPT-3 paper. | ||
|
||
**The General Description about Fuzzy Deduplication**: | ||
|
||
The fuzzy deduplication method here mainly refer to the fuzzy deduplication method mentioned in the Appendix A of [GPT-3 paper](https://arxiv.org/pdf/2005.14165.pdf). | ||
|
||
> To further improve model quality and prevent overfitting (which becomes increasingly important as model capacity increases), we fuzzily deduplicated documents (i.e. removed documents with high overlap with other documents) within each dataset using Spark’s MinHashLSH implementation with 10 hashes, using **the same features as were used for classification above**. We also fuzzily removed WebText from Common Crawl. Overall this decreased dataset size by an average of 10%. | ||
As the paper mentioned, the features used are the same as were used for quality classification, as described in [quality_classifier tools](../quality_classifier/README.md). | ||
|
||
The whole toolkit is based on PySpark. | ||
|
||
- tokenizer: Since the standard tokenizer of pyspark have trouble tokenizing text in languages such as Chinese, the [standard Tokenizer](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer) of PySpark or [sentencepiece](https://github.com/google/sentencepiece) model are used. | ||
- feature extractor: [HashingTF](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.HashingTF.html) | ||
- minhashLSH: [MinHashLSH](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.MinHashLSH.html) | ||
|
||
|
||
## Usage | ||
|
||
Use `spark_dedup.py` to fuzzily deduplicate documents. | ||
|
||
```shell | ||
python spark_dedup.py \ | ||
<dataset_path> \ | ||
<result_path> \ | ||
[--tokenizer <tokenizer_type>] \ | ||
[--num_features <num_features>] \ | ||
[--num_hashtables <num_hashtables>] \ | ||
[--text_key <text_key>] \ | ||
[--master_url <master_url>] | ||
|
||
# print the usage message | ||
python spark_dedup.py --help | ||
``` | ||
|
||
- `dataset_path`: the input dataset path. The suffix of the path should be one of the `[json, jsonl, parquet]`. | ||
- `result_path`: the path to store the dataset with prediction results. The suffix of the path should be one of the `[json, jsonl, parquet]`. | ||
- `tokenizer`: (Optional. Default: None) the tokenizer to tokenize texts to be classified. If it's None, the [standard Tokenizer](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer) of PySpark will be used. Besides, you can use one of the tokenizers we provide `[zh.sp.model, code.sp.model]`. Or you can set it to a path to your own [sentencepiece](https://github.com/google/sentencepiece) model. | ||
- `num_features`: the number of features that HashingTF generates. Default with 1047576 as mentioned in megatron-turing-nlg paper. | ||
- `num_hashtables`: (Optional. Default: 10) the number of hashes used in MinHashLSH. Default with 10 hashes as mentioned in the GPT3 paper. | ||
- `text_key`: (Optional. Default: "text") the field name to store texts to be classified in the input dataset. | ||
- `master_url`: (Optional. Default: None) the master url for spark config. If None, then run with "local[*]" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# 分布式模糊去重工具 | ||
复现与GPT-3论文相似的模糊去重方法并应用到您的Web数据集。 | ||
|
||
**模糊去重的一般描述**: | ||
这里的模糊去重方法主要指的是 [GPT-3论文](https://arxiv.org/pdf/2005.14165.pdf)附录A中提到的模糊去重方法。 | ||
> 为了进一步提高模型质量并防止过拟合(随着模型容量的增加越来越重要),我们使用Spark的MinHashLSH实现对每个数据集中的文档进行了模糊去重(即移除了与其他文档高度重合的文档),使用了10个哈希,使用的**特征与上面用于分类的特征相同**。我们还从Common Crawl中模糊移除了WebText。总体而言,这使数据集的大小平均减少了10%。 | ||
正如论文中提到的,使用的特征与前文描述的质量分类器([quality_classifier tools](../quality_classifier/README.md))中所用的一致。 | ||
整个工具包基于PySpark。 | ||
- 分词器:由于pyspark的标准分词器无法很好地处理中文等语言的文本,所以使用了PySpark的[标准分词器](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer)或[sentencepiece](https://github.com/google/sentencepiece)模型。 | ||
- 特征提取器:[HashingTF](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.HashingTF.html) | ||
- minhashLSH:[MinHashLSH](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.MinHashLSH.html) | ||
|
||
## 使用方法 | ||
使用`spark_dedup.py`对文档进行模糊去重。 | ||
```shell | ||
python spark_dedup.py \ | ||
<dataset_path> \ | ||
<result_path> \ | ||
[--tokenizer <tokenizer_type>] \ | ||
[--num_features <num_features>] \ | ||
[--num_hashtables <num_hashtables>] \ | ||
[--text_key <text_key>] \ | ||
[--master_url <master_url>] | ||
# 打印使用信息 | ||
python spark_dedup.py --help | ||
|
||
``` | ||
|
||
- `dataset_path`:输入数据集路径。路径的后缀应该是`[json, jsonl, parquet]`中的一个。 | ||
- `result_path`:存储带有预测结果数据集的路径。路径的后缀应该是`[json, jsonl, parquet]`中的一个。 | ||
- `tokenizer`:(可选。默认值:None)用于对将要分类的文本进行分词的分词器。如果为None,将使用PySpark的[标准分词器](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer)。此外,你可以使用我们提供的分词器`[zh.sp.model, code.sp.model]`中的一个,或者你可以将其设置为你自己的[sentencepiece](https://github.com/google/sentencepiece)模型的路径。 | ||
- `num_features`:HashingTF生成的特征数量。默认值为1047576,如megatron-turing-nlg论文中所述。 | ||
- `num_hashtables`:(可选。默认值:10)MinHashLSH中使用的哈希数量。默认使用10个哈希,如GPT-3论文中所述。 | ||
- `text_key`:(可选。默认值:"text")输入数据集中用于存储待分类文本的字段名称。 | ||
- `master_url`:(可选。默认值:None)用于Spark配置的master URL。如果为空,则默认运行在"local[*]"模式下。 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# The Star-Graph-Connected-Components (SGCC) algorithm here referenced from: | ||
# https://github.com/bigcode-project/bigcode-dataset/blob/main/near_deduplication/minhash_deduplication_spark.py | ||
# -------------------------------------------------------- | ||
|
||
from typing import List, Tuple, Union | ||
|
||
from loguru import logger | ||
from pyspark import SparkConf | ||
from pyspark.sql import SparkSession | ||
|
||
|
||
def init_spark(master_url: Union[str, None] = None, | ||
spark_executor_memory=None, | ||
spark_driver_memory=None, | ||
spark_executor_memoryOverhead=None): | ||
if not spark_executor_memory: | ||
spark_executor_memory = '64g' | ||
if not spark_driver_memory: | ||
spark_driver_memory = '64g' | ||
if not spark_executor_memoryOverhead: | ||
spark_executor_memoryOverhead = '20000' | ||
if not master_url: | ||
master_url = 'local[*]' | ||
conf = SparkConf() | ||
conf.set('spark.app.name', 'MinHashLSH') | ||
conf.set('spark.debug.maxToStringFields', '100') | ||
conf.set('spark.master', master_url) | ||
conf.set('spark.executor.memory', spark_executor_memory) | ||
conf.set('spark.driver.memory', spark_driver_memory) | ||
conf.set('spark.sql.execution.arrow.pyspark.enabled', 'true') | ||
conf.set('spark.executor.memoryOverhead', spark_executor_memoryOverhead) | ||
spark = SparkSession.builder.config(conf=conf).getOrCreate() | ||
logger.info('Spark initialization done.') | ||
return spark | ||
|
||
|
||
def generate_edges(nodes: List[int]) -> List[Tuple[int, int]]: | ||
""" | ||
Generate edges from a cluster. Instead of generating N^2 edges, | ||
we only need all nodes align to a single node, | ||
since we will be running connected components on the edges later. | ||
Parameters | ||
---------- | ||
nodes : List[int] | ||
The list of nodes in the cluster. | ||
Returns | ||
------- | ||
List[Tuple[int, int]] | ||
The list of edges. | ||
""" | ||
if len(nodes) <= 1: | ||
return [] | ||
|
||
min_node = min(nodes) | ||
return [(n, min_node) for n in nodes if n != min_node] | ||
|
||
|
||
# Connected Components in MapReduce and Beyond | ||
def large_star_map(edge): | ||
return [(edge[0], edge[1]), (edge[1], edge[0])] | ||
|
||
|
||
def large_star_reduce(group): | ||
x, neighbors = group | ||
nodes = [x] + list(neighbors) | ||
minimum = min(nodes) | ||
return [(n, minimum) for n in nodes if n > x] | ||
|
||
|
||
def small_star_map(edge): | ||
x, y = edge | ||
if y <= x: | ||
return (x, y) | ||
else: | ||
return (y, x) | ||
|
||
|
||
def small_star_reduce(group): | ||
x, neighbors = group | ||
nodes = [x] + list(neighbors) | ||
minimum = min(nodes) | ||
return [(n, minimum) for n in nodes if n != minimum] | ||
|
||
|
||
def find_components(edges): | ||
""" | ||
Star-Graph-Connected-Components (SGCC) algorithm | ||
""" | ||
|
||
a = edges | ||
while True: | ||
b = a.flatMap(large_star_map).groupByKey().flatMap( | ||
large_star_reduce).distinct().cache() | ||
a = b.map(small_star_map).groupByKey().flatMap( | ||
small_star_reduce).distinct().cache() | ||
changes = a.subtract(b).union(b.subtract(a)).collect() | ||
if len(changes) == 0: | ||
break | ||
|
||
results = a.collect() | ||
return results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import sys | ||
import time | ||
from typing import Union | ||
|
||
import fire | ||
from loguru import logger | ||
from pyspark.ml.feature import HashingTF, MinHashLSH, Tokenizer | ||
from pyspark.sql import functions as F | ||
from pyspark.sql.functions import posexplode | ||
|
||
from tools.distributed_deduplication.dedup_utils import (find_components, | ||
generate_edges, | ||
init_spark) | ||
from tools.quality_classifier.qc_utils import (export_result, load_dataset, | ||
tokenize_dataset) | ||
|
||
|
||
@logger.catch | ||
def dedup_dataset(dataset_path: str, | ||
result_path: str, | ||
tokenizer: Union[str, None] = None, | ||
num_features: int = 1047576, | ||
num_hashtables: int = 10, | ||
text_key: str = 'text', | ||
master_url: Union[str, None] = None): | ||
""" | ||
Perform fuzzy text deduplication on the given dataset. | ||
:param dataset_path: the path to the dataset to perform deduplication, | ||
The suffix of the path should be one of the json, jsonl, parquet. | ||
:param result_path: the path to store the predicted result dataset. | ||
:param tokenizer: what tokenizer to use to tokenize texts. It's None in | ||
default, which means using the standard Tokenizer of PySpark. You can | ||
use one of ["zh.sp.model", "code.sp.model"] we provided, or you can set | ||
it to the path to your own sentencepiece model. | ||
:param num_features: the number of features that HashingTF generates. | ||
Default with 1047576 as mentioned in megatron-turing-nlg paper. | ||
:param num_hashtables: the number of hashes used in MinHashLSH. | ||
Default with 10 hashes as mentioned in the GPT3 paper. | ||
:param text_key: the field key name to hold texts to be classified. It's | ||
"text" in default. | ||
:param master_url: the master url for spark config. Default is None. | ||
If None, then run with local[*]. | ||
""" | ||
# for inited cluster, | ||
# provide master url such as "spark://master:7077" | ||
spark = init_spark(master_url=master_url) | ||
ds = load_dataset(spark, dataset_path, text_key=text_key) | ||
ds = ds.withColumn('id', F.monotonically_increasing_id()).cache() | ||
df = ds | ||
|
||
if tokenizer: | ||
ds = tokenize_dataset(ds, tokenizer) | ||
else: | ||
ds = Tokenizer(inputCol='text', outputCol='words').transform(ds) | ||
|
||
hashingTF = HashingTF(inputCol='words', | ||
outputCol='features', | ||
numFeatures=num_features) | ||
ds = hashingTF.transform(ds) | ||
|
||
minHash = MinHashLSH(inputCol='features', | ||
outputCol='hashes', | ||
numHashTables=num_hashtables) | ||
model = minHash.fit(ds) | ||
|
||
ds = model.transform(ds) | ||
|
||
ds = ds.select('id', posexplode('hashes').alias('band_idx', 'hash_vector')) | ||
|
||
record = ds.rdd.map(lambda x: | ||
(x['band_idx'], int(x['hash_vector'][0]), x['id'])) | ||
|
||
edges = (record.groupBy(lambda x: (x[0], x[1])).flatMap( | ||
lambda x: generate_edges([i[2] for i in x[1]])).distinct().cache()) | ||
|
||
results = find_components(edges) | ||
if len(results) == 0: | ||
logger.info('No components found.') | ||
sys.exit(0) | ||
|
||
components = spark.createDataFrame(results, | ||
schema=['id', 'component' | ||
]).sort(['component', 'id']) | ||
components.show() | ||
df = df.join(components, on='id', how='left') | ||
df = df.filter(F.col('component').isNull()).drop('id', 'component').cache() | ||
export_result(df, result_path) | ||
|
||
|
||
if __name__ == '__main__': | ||
stime = time.time() | ||
fire.Fire(dedup_dataset) | ||
etime = time.time() | ||
logger.info(f'Execution Done, Total time {etime - stime}') |
c1a8aa8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@garyzhang99 could I have your email and arrange a talk with you.
c1a8aa8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@shewang-rh Sure. My email is [email protected].