Skip to content

Commit

Permalink
support extract QA operator (#333)
Browse files Browse the repository at this point in the history
* support extract QA operator

* update
  • Loading branch information
Cathy0908 authored Jul 3, 2024
1 parent c85e024 commit 1244d4f
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 1 deletion.
2 changes: 2 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ process:
- clean_links_mapper: # remove web links from text.
- clean_copyright_mapper: # remove copyright comments.
- expand_macro_mapper: # expand macro definitions in Latex text.
- extract_qa_mapper: # mapper to extract question and answer pairs from text.
hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa'
- fix_unicode_mapper: # fix unicode errors in text.
- image_blur_mapper: # mapper to blur images.
p: 0.2 # probability of the image being blurred
Expand Down
4 changes: 3 additions & 1 deletion data_juicer/ops/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from . import (audio_ffmpeg_wrapped_mapper, chinese_convert_mapper,
clean_copyright_mapper, clean_email_mapper, clean_html_mapper,
clean_ip_mapper, clean_links_mapper, expand_macro_mapper,
fix_unicode_mapper, image_blur_mapper,
extract_qa_mapper, fix_unicode_mapper, image_blur_mapper,
image_captioning_from_gpt4v_mapper, image_captioning_mapper,
image_diffusion_mapper, image_face_blur_mapper,
nlpaug_en_mapper, nlpcda_zh_mapper,
Expand Down Expand Up @@ -32,6 +32,7 @@
from .clean_ip_mapper import CleanIpMapper
from .clean_links_mapper import CleanLinksMapper
from .expand_macro_mapper import ExpandMacroMapper
from .extract_qa_mapper import ExtractQAMapper
from .fix_unicode_mapper import FixUnicodeMapper
from .image_blur_mapper import ImageBlurMapper
from .image_captioning_from_gpt4v_mapper import ImageCaptioningFromGPT4VMapper
Expand Down Expand Up @@ -102,6 +103,7 @@
'VideoTaggingFromFramesMapper',
'RemoveCommentsMapper',
'ExpandMacroMapper',
'ExtractQAMapper',
'ImageCaptioningMapper',
'RemoveWordsWithIncorrectSubstringsMapper',
'VideoCaptioningFromVideoMapper',
Expand Down
109 changes: 109 additions & 0 deletions data_juicer/ops/mapper/extract_qa_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import json
import logging
import re

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.model_utils import get_model, prepare_model


@OPERATORS.register_module('extract_qa_mapper')
class ExtractQAMapper(Mapper):
    """
    Mapper to extract question and answer pairs from text samples.

    Recommended model list: [
        'alibaba-pai/pai-llama3-8b-doc2qa',
        'alibaba-pai/pai-baichuan2-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-4b-doc2qa',
        'alibaba-pai/pai-qwen1_5-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-1b8-doc2qa',
        'alibaba-pai/pai-qwen1_5-0b5-doc2qa'
    ]
    These recommended models are all trained with Chinese data
    and are suitable for Chinese.
    """

    def __init__(self,
                 hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa',
                 pattern: str = None,
                 qa_format: str = 'chatml',
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_model: Hugging Face model id.
        :param pattern: regular expression pattern used to find the
            question/answer pairs in the model output. It must define
            exactly two capture groups (question, answer). When None, a
            default pattern matching "Human: ..." / "Assistant: ..."
            turns is used.
        :param qa_format: output format of the question and answer pairs.
            Only 'chatml' is handled by `process`.
        :param args: extra args
        :param kwargs: extra args

        The default data format parsed by this interface is shown below.
        The example is given in English translation; the recommended
        models are trained on Chinese data.

        Model Input:
            The capital of Mongolia is Ulaanbaatar
            The capital of Iceland is Reykjavik
        Model Output:
            The capital of Mongolia is Ulaanbaatar
            The capital of Iceland is Reykjavik
            Human: What is the capital of Mongolia?
            Assistant: Hello, according to the information provided,
                the capital of Mongolia is Ulaanbaatar.
            Human: And what is the capital of Iceland?
            Assistant: The capital of Iceland is Reykjavik.
            ...
        """
        super().__init__(*args, **kwargs)
        self._batched_op = True
        self._accelerator = 'cuda'

        if pattern is None:
            # Lazily capture each question, then each answer up to the next
            # "Human" turn or the end of the text.
            self.pattern = r'Human: (.*?)\nAssistant: (.*?)(?=\nHuman|$)'
        else:
            self.pattern = pattern

        self.qa_format = qa_format
        self.model_key = prepare_model(model_type='huggingface',
                                       pretrained_model_name_or_path=hf_model)

    def _extract_qa(self, output):
        """
        Extract question and answer pairs from the model output response.

        :param output: decoded text generated by the model.
        :return: list of (question, answer) tuples in order of appearance,
            each element stripped of surrounding whitespace.
        """
        # DOTALL lets a question or an answer span several lines.
        pat = re.compile(self.pattern, re.DOTALL)
        return [(user.strip(), assistant.strip())
                for user, assistant in pat.findall(output)]

    def process(self, sample, rank=None):
        """
        Generate question/answer pairs for a sample and store them as JSON.

        The text under `self.text_key` is fed to the model, the generated
        text is parsed with `self.pattern`, and the field is overwritten
        with a JSON string holding one 'messages' dialogue per pair.

        :param sample: sample to process; its `self.text_key` field is
            read and then replaced.
        :param rank: rank forwarded to `get_model`.
        :return: the same sample object with the text field replaced.
        :raises ValueError: if `self.qa_format` is not 'chatml'.
        """
        # Reject an unsupported format before running the (expensive)
        # generation instead of after it.
        if self.qa_format != 'chatml':
            raise ValueError(f'Not support {self.qa_format}!')

        model, processor = get_model(self.model_key, rank=rank)

        inputs = processor(sample[self.text_key],
                           return_tensors='pt').to(model.device)
        response = model.generate(**inputs)
        # NOTE(review): only the first generated sequence is decoded, while
        # __init__ sets `_batched_op = True` -- confirm this matches what the
        # framework passes in `sample`.
        output = processor.decode(response.cpu()[0], skip_special_tokens=True)
        qa_list = self._extract_qa(output)

        if not qa_list:
            logging.info(
                'No question and answer data was extracted from this sample!')

        dialogue_data = [{
            'messages': [{
                'role': 'user',
                'content': question
            }, {
                'role': 'assistant',
                'content': answer
            }]
        } for question, answer in qa_list]

        # ensure_ascii=False keeps non-ASCII text readable in the output.
        sample[self.text_key] = json.dumps(dialogue_data, ensure_ascii=False)

        return sample
36 changes: 36 additions & 0 deletions tests/ops/mapper/test_extract_qa_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest
import json
from data_juicer.ops.mapper.extract_qa_mapper import ExtractQAMapper
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
DataJuicerTestCaseBase)

# Skip tests for this OP in the GitHub actions due to disk space limitation.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractQAMapperTest(DataJuicerTestCaseBase):
    """Checks that ExtractQAMapper emits chatml-style QA dialogues."""

    text_key = 'text'

    def _run_extract_qa(self, samples):
        """Run the mapper on each sample and validate the first QA pair."""
        op = ExtractQAMapper(
            hf_model='alibaba-pai/pai-qwen1_5-7b-doc2qa',
            qa_format='chatml'
        )
        for sample in samples:
            result = op.process(sample)
            out_text = json.loads(result[self.text_key])

            # Fail with a readable message rather than an IndexError when
            # the model output contains no question/answer pair.
            self.assertTrue(out_text,
                            'no question and answer pair was extracted')

            # test one output qa sample: a user turn followed by an
            # assistant turn, each carrying its content.
            messages = out_text[0]['messages']
            self.assertEqual(len(messages), 2)
            for message, role in zip(messages, ('user', 'assistant')):
                self.assertIn('role', message)
                self.assertIn('content', message)
                self.assertEqual(message['role'], role)

    def test_extract_qa(self):
        """Extract QA pairs from a short two-sentence Chinese document."""
        samples = [
            {
                self.text_key: '蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n'
            }]
        self._run_extract_qa(samples)


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit 1244d4f

Please sign in to comment.