Add DPO data OP (#491)
* add DPO OP

* fix args

* refine

* refine DPO op

* add docs
drcege authored Nov 28, 2024
1 parent 8ade9b5 commit 4f0f16c
Showing 18 changed files with 243 additions and 47 deletions.
12 changes: 12 additions & 0 deletions configs/config_all.yaml
@@ -244,6 +244,18 @@ process:
sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95}
- optimize_query_mapper: # optimize query in question-answer pairs.
- optimize_response_mapper: # optimize response in question-answer pairs.
- pair_preference_mapper: # construct paired preference samples.
api_model: 'gpt-4o' # API model name.
api_endpoint: null # URL endpoint for the API.
response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
system_prompt: null # System prompt for guiding the generation task.
input_template: null # Template for building the model input.
output_pattern: null # Regular expression for parsing model output.
rejected_key: 'rejected_response' # The field name in the sample to store the generated rejected response.
reason_key: 'reason' # The field name in the sample to store the reason for generating the response.
try_num: 3 # The number of retries for the API call in case of response parsing failure.
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call.
- punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations.
- remove_bibliography_mapper: # remove bibliography from Latex text.
- remove_comments_mapper: # remove comments from Latex text, code, etc.
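
For reference, the pair_preference_mapper entry above can also be instantiated directly in Python. The following is a minimal sketch, assuming the module layout introduced in this commit and an OpenAI-compatible API key configured in the environment; the sampling_params values are illustrative, not defaults from this commit.

from data_juicer.ops.mapper import PairPreferenceMapper

# Mirrors the YAML entry above; unset fields fall back to the OP's built-in
# defaults (Chinese system prompt, 【回答】/【原因】 output pattern).
op = PairPreferenceMapper(
    api_model='gpt-4o',                # API model name
    rejected_key='rejected_response',  # field for the generated rejected answer
    reason_key='reason',               # field for the generation rationale
    try_num=3,                         # retries when the reply cannot be parsed
    sampling_params={'temperature': 0.9, 'top_p': 0.95},  # illustrative values
)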
28 changes: 15 additions & 13 deletions data_juicer/ops/mapper/__init__.py
@@ -28,6 +28,7 @@
from .optimize_qa_mapper import OptimizeQAMapper
from .optimize_query_mapper import OptimizeQueryMapper
from .optimize_response_mapper import OptimizeResponseMapper
from .pair_preference_mapper import PairPreferenceMapper
from .punctuation_normalization_mapper import PunctuationNormalizationMapper
from .remove_bibliography_mapper import RemoveBibliographyMapper
from .remove_comments_mapper import RemoveCommentsMapper
@@ -73,17 +74,18 @@
'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper',
'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper',
'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper',
'PunctuationNormalizationMapper', 'RemoveBibliographyMapper',
'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper',
'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper',
'RemoveSpecificCharsMapper', 'RemoveTableTextMapper',
'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper',
'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper',
'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper',
'VideoCaptioningFromVideoMapper', 'VideoFFmpegWrappedMapper',
'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper',
'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper',
'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper',
'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper',
'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper'
'PairPreferenceMapper', 'PunctuationNormalizationMapper',
'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper',
'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper',
'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper',
'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper',
'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper',
'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper',
'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper',
'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper',
'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper',
'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper',
'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper',
'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper',
'WhitespaceNormalizationMapper'
]
5 changes: 2 additions & 3 deletions data_juicer/ops/mapper/calibrate_qa_mapper.py
@@ -4,14 +4,13 @@
from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'calibrate_qa_mapper'


# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class CalibrateQAMapper(Mapper):
"""
@@ -107,7 +106,7 @@ def process_single(self, sample, rank=None):
'content': self.build_input(sample)
}]
parsed_q, parsed_a = None, None
for i in range(self.try_num):
for _ in range(self.try_num):
try:
output = client(messages, **self.sampling_params)
parsed_q, parsed_a = self.parse_output(output)
3 changes: 1 addition & 2 deletions data_juicer/ops/mapper/calibrate_query_mapper.py
@@ -1,11 +1,10 @@
from data_juicer.ops.base_op import OPERATORS, UNFORKABLE
from data_juicer.ops.base_op import OPERATORS
from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper

OP_NAME = 'calibrate_query_mapper'


# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class CalibrateQueryMapper(CalibrateQAMapper):
"""
3 changes: 1 addition & 2 deletions data_juicer/ops/mapper/calibrate_response_mapper.py
@@ -1,11 +1,10 @@
from data_juicer.ops.base_op import OPERATORS, UNFORKABLE
from data_juicer.ops.base_op import OPERATORS
from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper

OP_NAME = 'calibrate_response_mapper'


# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class CalibrateResponseMapper(CalibrateQAMapper):
"""
5 changes: 2 additions & 3 deletions data_juicer/ops/mapper/extract_entity_attribute_mapper.py
@@ -5,15 +5,14 @@
from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'extract_entity_attribute_mapper'


# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class ExtractEntityAttributeMapper(Mapper):
"""
@@ -154,7 +153,7 @@ def _process_single_sample(self, text='', rank=None):
}]

desc, demos = '', []
for i in range(self.try_num):
for _ in range(self.try_num):
try:
output = client(messages, **self.sampling_params)
desc, demos = self.parse_output(output, attribute)
5 changes: 2 additions & 3 deletions data_juicer/ops/mapper/extract_entity_relation_mapper.py
@@ -9,7 +9,7 @@
from loguru import logger
from pydantic import NonNegativeInt, PositiveInt

from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.common_utils import is_float
from data_juicer.utils.constant import Fields
from data_juicer.utils.model_utils import get_model, prepare_model
@@ -20,7 +20,6 @@


# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class ExtractEntityRelationMapper(Mapper):
"""
@@ -319,7 +318,7 @@ def process_single(self, sample, rank=None):
messages = [{'role': 'user', 'content': input_prompt}]

entities, relations = [], []
for i in range(self.try_num):
for _ in range(self.try_num):
try:
result = self.light_rag_extraction(messages, rank=rank)
entities, relations = self.parse_output(result)
5 changes: 2 additions & 3 deletions data_juicer/ops/mapper/extract_event_mapper.py
@@ -5,7 +5,7 @@
from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.model_utils import get_model, prepare_model

@@ -15,7 +15,6 @@


# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class ExtractEventMapper(Mapper):
"""
@@ -134,7 +133,7 @@ def _process_single_sample(self, text='', rank=None):
}]

event_list, character_list = [], []
for i in range(self.try_num):
for _ in range(self.try_num):
try:
output = client(messages, **self.sampling_params)
event_list, character_list = self.parse_output(output)
5 changes: 2 additions & 3 deletions data_juicer/ops/mapper/extract_keyword_mapper.py
@@ -6,7 +6,7 @@
from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.model_utils import get_model, prepare_model

@@ -16,7 +16,6 @@


# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class ExtractKeywordMapper(Mapper):
"""
@@ -173,7 +172,7 @@ def process_single(self, sample, rank=None):
messages = [{'role': 'user', 'content': input_prompt}]

keywords = []
for i in range(self.try_num):
for _ in range(self.try_num):
try:
result = client(messages, **self.sampling_params)
keywords = self.parse_output(result)
5 changes: 2 additions & 3 deletions data_juicer/ops/mapper/extract_nickname_mapper.py
@@ -4,15 +4,14 @@
from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'extract_nickname_mapper'


# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class ExtractNicknameMapper(Mapper):
"""
@@ -143,7 +142,7 @@ def process_single(self, sample, rank=None):
'content': input_prompt
}]
nickname_relations = []
for i in range(self.try_num):
for _ in range(self.try_num):
try:
output = client(messages, **self.sampling_params)
nickname_relations = self.parse_output(output)
131 changes: 131 additions & 0 deletions data_juicer/ops/mapper/pair_preference_mapper.py
@@ -0,0 +1,131 @@
import re
from typing import Dict, Optional

from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'pair_preference_mapper'


# TODO: Extend LLM-based OPs into API-based implementation.
@OPERATORS.register_module(OP_NAME)
class PairPreferenceMapper(Mapper):
"""
Mapper to construct paired preference samples.
"""

    # Default prompts are in Chinese; implicit string concatenation avoids leading
    # whitespace. The system prompt asks the model to rewrite the answer so that it
    # contradicts the original (in style, facts, persona, or stance) and to output
    # the result under the 【回答】 (answer) and 【原因】 (reason) markers.
DEFAULT_SYSTEM_PROMPT = (
'你的任务是根据参考信息修改问答对中的回答,在语言风格、事实性、人物身份、立场等任一方面与原回答相反。'
'必须按照以下标记格式输出,不要输出其他多余内容。\n'
'【回答】\n'
'生成的新回答\n'
'【原因】\n'
'生成该回答的原因')
DEFAULT_INPUT_TEMPLATE = ('【参考信息】\n'
'{reference}\n'
'\n'
'以下是原始问答对:\n'
'【问题】\n'
'{query}\n'
'【回答】\n'
'{response}')
DEFAULT_OUTPUT_PATTERN = r'.*?【回答】\s*(.*?)\s*【原因】\s*(.*)'

def __init__(self,
api_model: str = 'gpt-4o',
*,
api_endpoint: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt: Optional[str] = None,
input_template: Optional[str] = None,
output_pattern: Optional[str] = None,
rejected_key: str = 'rejected_response',
reason_key: str = 'reason',
try_num: PositiveInt = 3,
model_params: Dict = {},
sampling_params: Dict = {},
**kwargs):
"""
Initialization method.
:param api_model: API model name.
:param api_endpoint: URL endpoint for the API.
:param response_path: Path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:param system_prompt: System prompt for guiding the generation task.
:param input_template: Template for building the model input. It must
contain placeholders '{query}' and '{response}', and can optionally
include '{reference}'.
:param output_pattern: Regular expression for parsing model output.
:param rejected_key: The field name in the sample to store the
generated rejected response. Defaults to 'rejected_response'.
:param reason_key: The field name in the sample to store the reason for
generating the response. Defaults to 'reason'.
:param try_num: The number of retries for the API call in case of
response parsing failure. Defaults to 3.
:param model_params: Parameters for initializing the API model.
:param sampling_params: Extra parameters passed to the API call.
e.g. {'temperature': 0.9, 'top_p': 0.95}
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)

self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN

self.rejected_key = rejected_key
self.reason_key = reason_key

self.model_key = prepare_model(model_type='api',
model=api_model,
endpoint=api_endpoint,
response_path=response_path,
**model_params)
self.try_num = try_num
self.sampling_params = sampling_params

def build_input(self, sample):
mapping = {
'query': sample[self.query_key],
'response': sample[self.response_key],
'reference': sample.get(self.text_key, '')
}
return self.input_template.format_map(mapping)

def parse_output(self, raw_output):
logger.debug(raw_output)
match = re.match(self.output_pattern, raw_output, re.DOTALL)
if match:
return match.group(1).strip(), match.group(2).strip()
else:
return ('', '')

def process_single(self, sample, rank=None):
client = get_model(self.model_key, rank=rank)

messages = [{
'role': 'system',
'content': self.system_prompt
}, {
'role': 'user',
'content': self.build_input(sample)
}]

parsed_rejected, parsed_reason = '', ''
for _ in range(self.try_num):
try:
output = client(messages, **self.sampling_params)
parsed_rejected, parsed_reason = self.parse_output(output)
if parsed_rejected and parsed_reason:
break
except Exception as e:
logger.warning(f'Exception: {e}')
sample[self.rejected_key] = parsed_rejected
sample[self.reason_key] = parsed_reason

return sample
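
To illustrate how the default output pattern is consumed, the following is a small, self-contained sketch of the parsing step. The raw reply is a fabricated example of the 【回答】/【原因】 format requested by the default system prompt, and the field names follow the defaults above ('rejected_response', 'reason').

import re

# Default pattern from PairPreferenceMapper: the text after 【回答】 up to 【原因】
# becomes the rejected answer; the remainder becomes the reason.
DEFAULT_OUTPUT_PATTERN = r'.*?【回答】\s*(.*?)\s*【原因】\s*(.*)'

# Fabricated model reply, for illustration only.
raw_output = ('【回答】\n'
              'Paris has never been the capital of France.\n'
              '【原因】\n'
              'The rewritten answer contradicts the factual content of the original response.')

match = re.match(DEFAULT_OUTPUT_PATTERN, raw_output, re.DOTALL)
rejected, reason = (match.group(1).strip(), match.group(2).strip()) if match else ('', '')

# process_single stores the parsed pieces under the configured keys on success,
# or empty strings after try_num failed attempts.
sample = {'query': 'original question', 'response': 'original answer'}
sample['rejected_response'] = rejected
sample['reason'] = reason
print(sample['rejected_response'], '|', sample['reason'])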