-
Notifications
You must be signed in to change notification settings - Fork 191
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* support extract QA operator * update
- Loading branch information
Showing
4 changed files
with
150 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import json | ||
import logging | ||
import re | ||
|
||
from data_juicer.ops.base_op import OPERATORS, Mapper | ||
from data_juicer.utils.model_utils import get_model, prepare_model | ||
|
||
|
||
@OPERATORS.register_module('extract_qa_mapper')
class ExtractQAMapper(Mapper):
    """
    Mapper to extract question and answer pairs from text samples.

    The sample text is fed to a doc2qa language model, the generated
    dialogue is parsed with a regular expression, and the sample text is
    replaced by a JSON string holding the extracted pairs.

    Recommended model list: [
        'alibaba-pai/pai-llama3-8b-doc2qa',
        'alibaba-pai/pai-baichuan2-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-4b-doc2qa',
        'alibaba-pai/pai-qwen1_5-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-1b8-doc2qa',
        'alibaba-pai/pai-qwen1_5-0b5-doc2qa'
    ]
    These recommended models are all trained with Chinese data
    and are suitable for Chinese.
    """

    def __init__(self,
                 hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa',
                 pattern: str = None,
                 qa_format: str = 'chatml',
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_model: Hugging Face model id.
        :param pattern: regular expression pattern to search for within
            the model output. It must define exactly two capture groups
            (question first, answer second), because every match is
            unpacked into a (question, answer) pair. When ``None``, a
            default pattern matching ``Human: ...`` / ``Assistant: ...``
            turns is used.
        :param qa_format: Output format of question and answer pair.
            Only ``'chatml'`` is handled by :meth:`process`.
        :param args: extra args
        :param kwargs: extra args

        The default data format parsed by this interface is as follows
        (the sample text states the capitals of Mongolia and Iceland;
        the model echoes it and appends Human/Assistant turns):

        Model Input:
            蒙古国的首都是乌兰巴托(Ulaanbaatar)
            冰岛的首都是雷克雅未克(Reykjavik)
        Model Output:
            蒙古国的首都是乌兰巴托(Ulaanbaatar)
            冰岛的首都是雷克雅未克(Reykjavik)
            Human: 请问蒙古国的首都是哪里?
            Assistant: 你好,根据提供的信息,蒙古国的首都是乌兰巴托(Ulaanbaatar)。
            Human: 冰岛的首都是哪里呢?
            Assistant: 冰岛的首都是雷克雅未克(Reykjavik)。
            ...
        """

        super().__init__(*args, **kwargs)
        # NOTE(review): these two flags are presumably read by the Mapper
        # base class / executor; their exact semantics are not visible here.
        self._batched_op = True
        self._accelerator = 'cuda'

        if pattern is None:
            # Non-greedy groups; the lookahead ends an answer at the next
            # "Human" turn or at the end of the text.
            self.pattern = r'Human: (.*?)\nAssistant: (.*?)(?=\nHuman|$)'
        else:
            self.pattern = pattern

        self.qa_format = qa_format
        self.model_key = prepare_model(model_type='huggingface',
                                       pretrained_model_name_or_path=hf_model)

    def _extract_qa(self, output):
        """
        Extract question and answer pairs from the model output response.

        :param output: decoded text generated by the model.
        :return: list of ``(question, answer)`` tuples with surrounding
            whitespace stripped; empty when nothing matches.
        """
        # DOTALL lets a question or an answer span several lines.
        pat = re.compile(self.pattern, re.DOTALL)
        return [(user.strip(), assistant.strip())
                for user, assistant in pat.findall(output)]

    def process(self, sample, rank=None):
        """
        Generate QA pairs for one sample and store them as a JSON string.

        :param sample: sample whose ``self.text_key`` entry holds the
            source text; that entry is overwritten with the JSON-encoded
            list of extracted dialogues.
        :param rank: forwarded to ``get_model`` when fetching the model.
        :return: the updated sample.
        :raises ValueError: if ``self.qa_format`` is not supported.
        """
        # Fail fast: reject an unsupported format before paying for
        # model loading and generation.
        if self.qa_format != 'chatml':
            raise ValueError(f'Not support {self.qa_format}!')

        model, processor = get_model(self.model_key, rank=rank)

        inputs = processor(sample[self.text_key],
                           return_tensors='pt').to(model.device)
        response = model.generate(**inputs)
        # Only the first generated sequence is decoded.
        output = processor.decode(response.cpu()[0], skip_special_tokens=True)
        qa_list = self._extract_qa(output)

        if not qa_list:
            logging.info(
                'No question and answer data was extracted from this sample!')

        dialogue_data = [{
            'messages': [{
                'role': 'user',
                'content': question
            }, {
                'role': 'assistant',
                'content': answer
            }]
        } for question, answer in qa_list]

        # ensure_ascii=False keeps non-ASCII text readable in the output.
        sample[self.text_key] = json.dumps(dialogue_data, ensure_ascii=False)

        return sample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import unittest | ||
import json | ||
from data_juicer.ops.mapper.extract_qa_mapper import ExtractQAMapper | ||
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, | ||
DataJuicerTestCaseBase) | ||
|
||
# Skip tests for this OP in the GitHub actions due to disk space limitation. | ||
# These tests have been tested locally. | ||
@SKIPPED_TESTS.register_module()
class ExtractQAMapperTest(DataJuicerTestCaseBase):
    """Smoke test for ``ExtractQAMapper`` using the chatml output format."""
    text_key = 'text'

    def _run_extract_qa(self, samples):
        # Build the mapper once and reuse it for every sample.
        mapper = ExtractQAMapper(
            qa_format='chatml',
            hf_model='alibaba-pai/pai-qwen1_5-7b-doc2qa',
        )
        for item in samples:
            processed = mapper.process(item)
            dialogues = json.loads(processed[self.text_key])

            # Only the first message of the first extracted dialogue is
            # inspected: it must expose both chat fields.
            first_message = dialogues[0]['messages'][0]
            for field in ('role', 'content'):
                self.assertIn(field, first_message)

    def test_extract_qa(self):
        # Two short Chinese sentences stating national capitals.
        document = '蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n'
        self._run_extract_qa([{self.text_key: document}])
|
||
|
||
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()