Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Two additional filters for image captions #115

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions configs/demo/process_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Process config example for dataset

# global parameters
project_name: 'demo-process'
dataset_path: './demos/data/demo-dataset.jsonl' # path to your dataset directory or file
np: 4                                                    # number of subprocesses to use when processing your dataset

export_path: './outputs/demo-process/demo-processed.jsonl'

# process schedule
# a list of several process operators with their arguments
process:
- text_entity_dependency_filter:
lang: 'zh'
1 change: 1 addition & 0 deletions data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
maximum_line_length_filter, perplexity_filter,
special_characters_filter, specified_field_filter,
specified_numeric_field_filter, stopwords_filter, suffix_filter,
text_action_filter, text_entity_dependency_filter,
text_length_filter, token_num_filter, word_num_filter,
word_repetition_filter)
78 changes: 78 additions & 0 deletions data_juicer/ops/filter/text_action_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import SpecialTokens
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter

OP_NAME = 'text_action_filter'


@OPERATORS.register_module(OP_NAME)
class TextActionFilter(Filter):
    """
    Filter to keep samples whose text contains a minimum number of actions.

    An action is counted as a token whose coarse part-of-speech is 'VERB'
    and whose fine-grained tag is one of the configured verb tags, as
    produced by a spacy model for the selected language.
    """

    def __init__(self,
                 lang: str = 'en',
                 min_action_num: int = 1,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param lang: language of the text in the samples. 'en' for detection
            of actions in English and 'zh' for detection of actions in
            Chinese.
        :param min_action_num: the minimum number of actions required to
            keep a sample. Samples will be filtered out if the number of
            actions found in their text is below this parameter.
        """
        super().__init__(*args, **kwargs)

        if lang not in ['en', 'zh']:
            raise ValueError(
                f'Language [{lang}] is not supported in action detection. '
                f'Can only be one of ["en", "zh"].')
        self.lang = lang
        self.model_key = prepare_model(model_type='spacy', lang=lang)
        # coarse POS and fine-grained tags that identify an action token
        # ('VV' is used by the Chinese model, 'VB*' by the English one)
        self.action_poss = ['VERB']
        self.action_tags = ['VV', 'VB', 'VBP', 'VBZ', 'VBD', 'VBG', 'VBN']
        self.min_action_num = min_action_num

    def compute_stats(self, sample, context=False):
        # check if it's computed already
        if StatsKeys.num_action in sample[Fields.stats]:
            return sample

        # collect all multimodal special tokens (e.g. image/eoc markers) so
        # they can be stripped before the text is fed to the spacy model
        special_token_dict = {
            key: value
            for key, value in SpecialTokens.__dict__.items()
            if not key.startswith('__')
        }

        def remove_special_token(text):
            for value in special_token_dict.values():
                text = text.replace(value, '')
            return text

        text = remove_special_token(sample[self.text_key])

        # process text via spacy and count the actions in the text
        model = get_model(self.model_key)
        doc = model(text)
        num_action = sum(
            1 for token in doc if token.pos_ in self.action_poss
            and token.tag_ in self.action_tags)
        sample[Fields.stats][StatsKeys.num_action] = num_action

        return sample

    def process(self, sample):
        # keep the sample only if it contains enough actions
        return sample[Fields.stats][StatsKeys.num_action] \
            >= self.min_action_num
115 changes: 115 additions & 0 deletions data_juicer/ops/filter/text_entity_dependency_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import numpy as np

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import SpecialTokens
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter

OP_NAME = 'text_entity_dependency_filter'


@OPERATORS.register_module(OP_NAME)
class TextEntityDependencyFilter(Filter):
    """
    Identify the entities in the text that are independent of other tokens,
    and filter samples accordingly. Samples whose text contains no entity
    at all will be omitted.
    """

    def __init__(self,
                 lang: str = 'en',
                 min_dependency_num: int = 1,
                 any_or_all: str = 'all',
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param lang: language of the text in the samples. 'en' for detection
            of entities in English and 'zh' for detection of entities in
            Chinese.
        :param min_dependency_num: the minimum number of dependency-tree
            edges required per entity. An entity is considered independent
            if its number of edges in the dependency tree is below this
            parameter.
        :param any_or_all: keep this sample with 'any' or 'all' strategy.
            'any': keep this sample if any entity is dependent. 'all': keep
            this sample only if all entities are dependent.
        """
        super().__init__(*args, **kwargs)

        if lang not in ['en', 'zh']:
            raise ValueError(
                f'Language [{lang}] is not supported in entity dependency '
                f'detection. Can only be one of ["en", "zh"].')
        self.lang = lang
        self.model_key = prepare_model(model_type='spacy', lang=lang)
        # coarse POS and fine-grained tags that identify an entity token
        self.entity_poss = ['NOUN', 'PROPN', 'PRON']
        self.entity_tags = ['NN', 'NR', 'PN', 'NNS', 'NNP', 'NNPS', 'PRP']
        self.min_dependency_num = min_dependency_num
        if any_or_all not in ['any', 'all']:
            raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                             f'Can only be one of ["any", "all"].')
        self.any = (any_or_all == 'any')

    def compute_stats(self, sample, context=False):
        # check if it's computed already
        if StatsKeys.num_dependency_edges in sample[Fields.stats]:
            return sample

        # collect all multimodal special tokens (e.g. image/eoc markers) so
        # they can be stripped before the text is fed to the spacy model
        special_token_dict = {
            key: value
            for key, value in SpecialTokens.__dict__.items()
            if not key.startswith('__')
        }

        def remove_special_token(text):
            for value in special_token_dict.values():
                text = text.replace(value, '')
            return text

        text = remove_special_token(sample[self.text_key])

        # identify entities
        model = get_model(self.model_key)
        doc = model(text)
        entity_to_dependency_nums = {}
        for token in doc:
            if token.pos_ in self.entity_poss \
                    and token.tag_ in self.entity_tags:
                entity_to_dependency_nums[token] = 0

        # count the edges of each entity in the dependency tree: one edge to
        # the entity's own head (unless it is the root), plus one for each
        # non-punctuation child that points at the entity
        for obj in entity_to_dependency_nums:
            if obj.dep_ != 'ROOT':
                entity_to_dependency_nums[obj] += 1
        for token in doc:
            # skip punctuation marks such as ',', '.'
            if token.pos_ == 'PUNCT':
                continue

            if token.head in entity_to_dependency_nums.keys(
            ) and token.dep_ != 'ROOT':
                entity_to_dependency_nums[token.head] += 1

        sample[Fields.stats][StatsKeys.num_dependency_edges] = [
            n for _, n in entity_to_dependency_nums.items()
        ]

        return sample

    def process(self, sample):
        num_dependency_edges = sample[Fields.stats][
            StatsKeys.num_dependency_edges]
        keep_bools = np.array([
            num_edge >= self.min_dependency_num
            for num_edge in num_dependency_edges
        ])
        # omit the samples without any entity
        if len(keep_bools) == 0:
            return False

        # different strategies
        return keep_bools.any() if self.any else keep_bools.all()
2 changes: 2 additions & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class StatsKeys(object):
num_token = 'num_token'
num_words = 'num_words'
word_rep_ratio = 'word_rep_ratio'
num_action = 'num_action'
num_dependency_edges = 'num_dependency_edges'

# image
aspect_ratios = 'aspect_ratios'
Expand Down
1 change: 1 addition & 0 deletions environments/minimal_requires.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ plotly
python-docx
streamlit
spacy==3.5.0
spacy-pkuseg==0.0.32
multiprocess==0.70.12
dill==0.3.4
114 changes: 114 additions & 0 deletions tests/ops/filter/test_text_action_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import unittest
import os

from datasets import Dataset

from data_juicer.ops.filter.text_action_filter import TextActionFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens


class TextActionFilterTest(unittest.TestCase):
    """Unit tests for TextActionFilter on English, Chinese and
    image-caption samples."""

    # directory holding the shared test data files
    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             '..', 'data')

    cat_path = os.path.join(data_path, 'cat.jpg')
    img3_path = os.path.join(data_path, 'img3.jpg')

    def _run_text_action_filter(self, dataset: Dataset, target_list, op,
                                column_names):
        # run the op end-to-end and compare the surviving rows with the
        # expected target list
        if Fields.stats not in dataset.features:
            dataset = dataset.add_column(name=Fields.stats,
                                         column=[{}] * dataset.num_rows)
        dataset = dataset.map(op.compute_stats)
        dataset = dataset.filter(op.process)
        dataset = dataset.select_columns(column_names=column_names)
        res_list = dataset.to_list()
        self.assertEqual(res_list, target_list)

    def test_en_text_case(self):

        ds_list = [{
            'text': 'Tom is playing piano.'
        }, {
            'text': 'Tom plays piano.'
        }, {
            'text': 'Tom played piano.'
        }, {
            'text': 'I play piano.'
        }, {
            'text': 'to play piano.'
        }, {
            'text': 'Tom 在打篮球'
        }, {
            'text': 'a v s e c s f e f g a a a '
        }, {
            'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►'
        }, {
            'text': 'that is a green tree'
        }]
        tgt_list = [{
            'text': 'Tom is playing piano.'
        }, {
            'text': 'Tom plays piano.'
        }, {
            'text': 'Tom played piano.'
        }, {
            'text': 'I play piano.'
        }, {
            'text': 'to play piano.'
        }]
        dataset = Dataset.from_list(ds_list)
        op = TextActionFilter(lang='en')
        self._run_text_action_filter(dataset, tgt_list, op, ['text'])

    def test_zh_text_case(self):

        ds_list = [{
            'text': '小明在 弹奏钢琴'
        }, {
            'text': 'Tom is playing 篮球'
        }, {
            'text': '上上下下左左右右'
        }, {
            'text': 'Tom在打篮球'
        }, {
            'text': '我有一只猫,它是一只猫'
        }]
        tgt_list = [{
            'text': '小明在 弹奏钢琴'
        }, {
            'text': 'Tom在打篮球'
        }]
        dataset = Dataset.from_list(ds_list)
        op = TextActionFilter(lang='zh')
        self._run_text_action_filter(dataset, tgt_list, op, ['text'])

    def test_image_text_case(self):
        ds_list = [{
            'text': f'{SpecialTokens.image}小猫咪正在睡觉。{SpecialTokens.eoc}',
            'images': [self.cat_path]
        }, {
            'text': f'{SpecialTokens.image}小猫咪',
            'images': [self.cat_path]
        }, {
            'text': f'{SpecialTokens.image}背影{SpecialTokens.eoc}',
            'images': [self.img3_path]
        }, {
            'text': '雨中行走的女人背影',
            'images': [self.img3_path]
        }]
        tgt_list = [{
            'text': f'{SpecialTokens.image}小猫咪正在睡觉。{SpecialTokens.eoc}',
            'images': [self.cat_path]
        }, {
            'text': '雨中行走的女人背影',
            'images': [self.img3_path]
        }]

        dataset = Dataset.from_list(ds_list)
        op = TextActionFilter(lang='zh')
        self._run_text_action_filter(dataset, tgt_list, op, ['text', 'images'])


if __name__ == '__main__':
    unittest.main()
Loading