-
Notifications
You must be signed in to change notification settings - Fork 191
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add DPO OP * fix args * refine * refine DPO op * add docs
- Loading branch information
Showing
18 changed files
with
243 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import re | ||
from typing import Dict, Optional | ||
|
||
from loguru import logger | ||
from pydantic import PositiveInt | ||
|
||
from data_juicer.ops.base_op import OPERATORS, Mapper | ||
from data_juicer.utils.model_utils import get_model, prepare_model | ||
|
||
OP_NAME = 'pair_preference_mapper' | ||
|
||
|
||
# TODO: Extend LLM-based OPs into API-based implementation. | ||
@OPERATORS.register_module(OP_NAME) | ||
class PairPreferenceMapper(Mapper): | ||
""" | ||
Mapper to construct paired preference samples. | ||
""" | ||
|
||
# avoid leading whitespace | ||
DEFAULT_SYSTEM_PROMPT = ( | ||
'你的任务是根据参考信息修改问答对中的回答,在语言风格、事实性、人物身份、立场等任一方面与原回答相反。' | ||
'必须按照以下标记格式输出,不要输出其他多余内容。\n' | ||
'【回答】\n' | ||
'生成的新回答\n' | ||
'【原因】\n' | ||
'生成该回答的原因') | ||
DEFAULT_INPUT_TEMPLATE = ('【参考信息】\n' | ||
'{reference}\n' | ||
'\n' | ||
'以下是原始问答对:\n' | ||
'【问题】\n' | ||
'{query}\n' | ||
'【回答】\n' | ||
'{response}') | ||
DEFAULT_OUTPUT_PATTERN = r'.*?【回答】\s*(.*?)\s*【原因】\s*(.*)' | ||
|
||
def __init__(self, | ||
api_model: str = 'gpt-4o', | ||
*, | ||
api_endpoint: Optional[str] = None, | ||
response_path: Optional[str] = None, | ||
system_prompt: Optional[str] = None, | ||
input_template: Optional[str] = None, | ||
output_pattern: Optional[str] = None, | ||
rejected_key: str = 'rejected_response', | ||
reason_key: str = 'reason', | ||
try_num: PositiveInt = 3, | ||
model_params: Dict = {}, | ||
sampling_params: Dict = {}, | ||
**kwargs): | ||
""" | ||
Initialization method. | ||
:param api_model: API model name. | ||
:param api_endpoint: URL endpoint for the API. | ||
:param response_path: Path to extract content from the API response. | ||
Defaults to 'choices.0.message.content'. | ||
:param system_prompt: System prompt for guiding the generation task. | ||
:param input_template: Template for building the model input. It must | ||
contain placeholders '{query}' and '{reponse}', and can optionally | ||
include '{reference}'. | ||
:param output_pattern: Regular expression for parsing model output. | ||
:param rejected_key: The field name in the sample to store the | ||
generated rejected response. Defaults to 'rejected_response'. | ||
:param reason_key: The field name in the sample to store the reason for | ||
generating the response. Defaults to 'reason'. | ||
:param try_num: The number of retries for the API call in case of | ||
response parsing failure. Defaults to 3. | ||
:param model_params: Parameters for initializing the API model. | ||
:param sampling_params: Extra parameters passed to the API call. | ||
e.g {'temperature': 0.9, 'top_p': 0.95} | ||
:param kwargs: Extra keyword arguments. | ||
""" | ||
super().__init__(**kwargs) | ||
|
||
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT | ||
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE | ||
self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN | ||
|
||
self.rejected_key = rejected_key | ||
self.reason_key = reason_key | ||
|
||
self.model_key = prepare_model(model_type='api', | ||
model=api_model, | ||
endpoint=api_endpoint, | ||
response_path=response_path, | ||
**model_params) | ||
self.try_num = try_num | ||
self.sampling_params = sampling_params | ||
|
||
def build_input(self, sample): | ||
mapping = { | ||
'query': sample[self.query_key], | ||
'response': sample[self.response_key], | ||
'reference': sample.get(self.text_key, '') | ||
} | ||
return self.input_template.format_map(mapping) | ||
|
||
def parse_output(self, raw_output): | ||
logger.debug(raw_output) | ||
match = re.match(self.output_pattern, raw_output, re.DOTALL) | ||
if match: | ||
return match.group(1).strip(), match.group(2).strip() | ||
else: | ||
return ('', '') | ||
|
||
def process_single(self, sample, rank=None): | ||
client = get_model(self.model_key, rank=rank) | ||
|
||
messages = [{ | ||
'role': 'system', | ||
'content': self.system_prompt | ||
}, { | ||
'role': 'user', | ||
'content': self.build_input(sample) | ||
}] | ||
|
||
parsed_rejected, parsed_reason = '', '' | ||
for _ in range(self.try_num): | ||
try: | ||
output = client(messages, **self.sampling_params) | ||
parsed_rejected, parsed_reason = self.parse_output(output) | ||
if parsed_rejected and parsed_reason: | ||
break | ||
except Exception as e: | ||
logger.warning(f'Exception: {e}') | ||
sample[self.rejected_key] = parsed_rejected | ||
sample[self.reason_key] = parsed_reason | ||
|
||
return sample |
Oops, something went wrong.