-
Notifications
You must be signed in to change notification settings - Fork 191
/
replace_content_mapper.py
66 lines (53 loc) · 2.37 KB
/
replace_content_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from typing import List, Union
import regex as re
from ..base_op import OPERATORS, Mapper
@OPERATORS.register_module('replace_content_mapper')
class ReplaceContentMapper(Mapper):
    """Mapper to replace all content in the text that matches
    a specific regular expression pattern with a designated
    replacement string.

    One pattern or a list of patterns may be given. Patterns are applied
    to each text in the order they were supplied, each one operating on
    the output of the previous substitution.
    """

    # NOTE(review): presumably signals to the framework that this operator
    # works on batches via ``process_batched`` -- confirm against base_op.
    _batched_op = True

    def __init__(self,
                 pattern: Union[str, List[str], None] = None,
                 repl: Union[str, List[str]] = '',
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param pattern: regular expression pattern(s) to search for within
            text. A single string, a list of strings, or None to leave the
            text untouched.
        :param repl: replacement string(s), default is empty string. A
            single string is used for every pattern; a list is matched to
            the patterns by position.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self.pattern = pattern
        self.repl = repl

        # Normalise the accepted input shapes to a list. Any other type
        # (including None) yields no patterns, so processing is a no-op.
        if isinstance(pattern, str):
            raw_patterns = [pattern]
        elif isinstance(pattern, list):
            raw_patterns = pattern
        else:
            raw_patterns = []

        # Compile once here so the per-sample loop only calls ``sub``.
        self.compiled_patterns = [
            self._prepare_pattern(p) for p in raw_patterns
        ]

    def _prepare_pattern(self, pattern: str) -> re.Pattern:
        """Prepare the regular expression pattern.

        A pattern written as the *text* of a Python raw-string literal
        (``r'...'`` or ``r"..."``) has the ``r`` prefix and the enclosing
        quotes removed before compilation.

        :param pattern: the regular expression source
        :return: the pattern compiled with ``re.DOTALL`` so that ``.``
            also matches newlines
        """
        if pattern is not None and len(pattern) > 2:
            single_quoted = (pattern.startswith("r'")
                             and pattern.endswith("'"))
            double_quoted = (pattern.startswith('r"')
                             and pattern.endswith('"'))
            if single_quoted or double_quoted:
                # Drop the leading ``r`` + quote and the trailing quote.
                pattern = pattern[2:-1]
        return re.compile(pattern, flags=re.DOTALL)

    def _resolve_replacements(self) -> List[str]:
        """Return one replacement string per compiled pattern.

        :return: list aligned by index with ``self.compiled_patterns``
        :raises ValueError: if ``repl`` is a list with fewer entries than
            there are patterns
        """
        num_patterns = len(self.compiled_patterns)
        if not isinstance(self.repl, list):
            # A single replacement is shared by every pattern.
            return [self.repl] * num_patterns
        if len(self.repl) < num_patterns:
            raise ValueError(f'pattern length: {num_patterns} '
                             f'must be equal to '
                             f'repl length: {len(self.repl)}')
        # Surplus replacements are ignored, as they always have been.
        return self.repl[:num_patterns]

    def process_batched(self, samples):
        """Apply every pattern substitution to each text in the batch.

        The texts stored under ``self.text_key`` are rewritten in place.
        (``text_key`` is not set in this class; it is expected to be
        provided by the ``Mapper`` base class.)

        :param samples: batch mapping field names to lists of values
        :return: the same ``samples`` object with its texts replaced
        :raises ValueError: if ``repl`` is a list shorter than the number
            of patterns; raised before any text is modified
        """
        if self.pattern is None:
            return samples

        # The pattern/replacement pairing does not depend on the sample,
        # so resolve and validate it once per batch rather than per text.
        replacements = self._resolve_replacements()

        for idx, text in enumerate(samples[self.text_key]):
            for compiled, replacement in zip(self.compiled_patterns,
                                             replacements):
                text = compiled.sub(replacement, text)
            samples[self.text_key][idx] = text
        return samples