[Feature] Update CHARM Memorization (#1230)
* update gemini api and add gemini models
* add openai models
* update CHARM evaluation
* add CHARM memorization tasks
* add CharmMemSummarizer (outputs eval details for memorization-independent reasoning analysis)
* update CHARM readme

---------

Co-authored-by: wujiang <[email protected]>
Showing 15 changed files with 759 additions and 38 deletions.
configs/datasets/CHARM/charm_memory_gen_bbbd53.py
@@ -0,0 +1,63 @@
import os
from mmengine.config import read_base

from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import CharmDataset, CharmMemoryEvaluator, LMEvaluator

with read_base():
    from .charm_memory_settings import charm_memory_tasks, judge_system_prompts, dataset_path

charm_memory_datasets = []

for _task in charm_memory_tasks:

    charm_memory_reader_cfg = dict(input_columns=['input'],
                                   output_column='target')

    charm_memory_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(round=[
                # Prompt (zh): 'Please answer the following question as briefly
                # as possible.\nQuestion: {input}\nAnswer:'
                dict(role='HUMAN', prompt='请尽可能简短地回答下述问题。\n问题:{input}\n答:')
            ]),
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer, max_out_len=512),
    )

    # The movie/music recommendation task is scored by a rule-based evaluator;
    # every other task is graded by an LLM judge with a task-specific prompt.
    if _task == 'Chinese_Movie_and_Music_Recommendation':
        charm_memory_eval_cfg = dict(
            evaluator=dict(type=CharmMemoryEvaluator),
            pred_role='BOT',
        )
    else:
        judge_system_prompt = judge_system_prompts[_task]
        charm_memory_eval_cfg = dict(
            evaluator=dict(
                type=LMEvaluator,
                prompt_template=dict(
                    type=PromptTemplate,
                    template=dict(round=[
                        dict(
                            role='HUMAN',
                            prompt=judge_system_prompt +
                            "\n\n[Question]\n{input}\n[The Start of Reference Answer]\n{target}\n[The End of Reference Answer]\n\n[The Start of Assistant's Answer]\n{prediction}\n[The End of Assistant's Answer]"  # noqa
                        ),
                    ]),
                ),
            ),
            pred_role='BOT',
        )

    charm_memory_datasets.append(
        dict(
            type=CharmDataset,
            path=dataset_path,
            name=_task,
            abbr='charm-memory-' + _task,
            reader_cfg=charm_memory_reader_cfg,
            infer_cfg=charm_memory_infer_cfg.copy(),
            eval_cfg=charm_memory_eval_cfg.copy(),
        ))
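For reference, a minimal standalone sketch of how the HUMAN-turn prompt above renders for a single dataset row. The sample question is hypothetical; only the template string is taken from the config:

# Hypothetical sample row; the template is copied from charm_memory_infer_cfg above.
infer_template = '请尽可能简短地回答下述问题。\n问题:{input}\n答:'
sample = {'input': '中国的春节是农历哪一天?'}  # hypothetical question
print(infer_template.format(**sample))
# 请尽可能简短地回答下述问题。
# 问题:中国的春节是农历哪一天?
# 答: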
configs/datasets/CHARM/charm_memory_settings.py
@@ -0,0 +1,31 @@
import os

charm_memory_tasks = [
    'Chinese_Anachronisms_Judgment',
    'Chinese_Movie_and_Music_Recommendation',
    'Chinese_Sport_Understanding',
    'Chinese_Time_Understanding',
]

dataset_path = 'data/CHARM/memorization'

system_prompt_template = """Please act as an impartial judge, comparing the responses of the AI assistants to the reference answer and determining if the answers are correct.
You will receive the reference answer provided by a human and the responses of the AI assistants.
Your task is to judge whether the AI assistant's answer is correct.
{task_specific_prompt}
After providing your explanation, strictly output your final judgment in the following format: “[正确]” if the AI assistant's response is correct, “[错误]” if the AI assistant's response is incorrect.
"""

task_specific_prompts = {
    'Chinese_Anachronisms_Judgment':
    "If the provided reference answer is a list, the model's prediction is considered correct if it matches any item in the list.",
    'Chinese_Time_Understanding':
    "When evaluating the AI assistant's response regarding Chinese solar terms, as long as the AI assistant's response falls within the time frame provided in the reference answer, consider it correct.",
    'Chinese_Sport_Understanding':
    "If the provided reference answer is a list, the model's prediction is considered correct if it matches any item in the list."
}

# Substitute each task-specific grading rule into the shared judge template.
judge_system_prompts = {
    k: system_prompt_template.format(task_specific_prompt=v)
    for k, v in task_specific_prompts.items()
}
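To see what a full judge prompt looks like end to end, here is a small sketch combining judge_system_prompts from this settings file with the bracketed comparison block from the dataset config above. In the actual run, LMEvaluator performs this substitution; the question, reference answer, and prediction below are hypothetical:

# Hypothetical sample; the prompt pieces are copied from the two configs above.
comparison_block = (
    "\n\n[Question]\n{input}"
    "\n[The Start of Reference Answer]\n{target}\n[The End of Reference Answer]"
    "\n\n[The Start of Assistant's Answer]\n{prediction}\n[The End of Assistant's Answer]")
judge_prompt = judge_system_prompts['Chinese_Sport_Understanding'] + comparison_block.format(
    input='姚明曾效力于NBA的哪支球队?',  # hypothetical question
    target='休斯顿火箭队',  # hypothetical reference answer
    prediction='休斯顿火箭队(Houston Rockets)',  # hypothetical model prediction
)
print(judge_prompt)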
@@ -0,0 +1,94 @@
from mmengine.config import read_base

from opencompass.models import OpenAI
from opencompass.runners import LocalRunner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import CharmMemSummarizer

with read_base():
    from .datasets.CHARM.charm_memory_gen_bbbd53 import charm_memory_datasets as datasets

    # ------>>>>>> models evaluated in https://arxiv.org/abs/2403.14112
    # from .models.openai.gpt_3_5_turbo_1106 import models as gpt_3_5_turbo_1106_model
    # from .models.openai.gpt_4_1106_preview import models as gpt_4_1106_preview_model
    # from .models.hf_llama.hf_llama2_7b_chat import models as llama2_7b_chat_model
    # from .models.hf_llama.hf_llama2_13b_chat import models as llama2_13b_chat_model
    # from .models.hf_llama.hf_llama2_70b_chat import models as llama2_70b_chat_model
    # from .models.vicuna.hf_vicuna_7b_v15_16k import models as vicuna_7b_v15_16k_model
    # from .models.vicuna.hf_vicuna_13b_v15_16k import models as vicuna_13b_v15_16k_model
    # from .models.chatglm.hf_chatglm3_6b_32k import models as chatglm3_6b_32k_model
    # from .models.baichuan.hf_baichuan2_7b_chat import models as baichuan2_7b_chat_model  # need torch 2.1
    # from .models.baichuan.hf_baichuan2_13b_chat import models as baichuan2_13b_chat_model  # need torch 2.1
    # from .models.hf_internlm.hf_internlm2_chat_7b import models as hf_internlm2_chat_7b_model
    # from .models.hf_internlm.hf_internlm2_chat_20b import models as hf_internlm2_chat_20b_model
    # from .models.yi.hf_yi_6b_chat import models as yi_6b_chat_model
    # from .models.yi.hf_yi_34b_chat import models as yi_34b_chat_model
    # from .models.deepseek.hf_deepseek_7b_chat import models as deepseek_7b_chat_model
    # from .models.deepseek.hf_deepseek_67b_chat import models as deepseek_67b_chat_model
    # from .models.qwen.hf_qwen_7b_chat import models as qwen_7b_chat_model
    # from .models.qwen.hf_qwen_14b_chat import models as qwen_14b_chat_model
    # from .models.qwen.hf_qwen_72b_chat import models as qwen_72b_chat_model
    # <<<<<<------ https://arxiv.org/abs/2403.14112

    # from .models.openai.gpt_3_5_turbo_0125 import models as gpt_3_5_turbo_0125_model
    # from .models.openai.gpt_4o_2024_05_13 import models as gpt_4o_2024_05_13_model
    # from .models.gemini.gemini_1_5_flash import models as gemini_1_5_flash_model
    # from .models.gemini.gemini_1_5_pro import models as gemini_1_5_pro_model

    # from .models.hf_llama.lmdeploy_llama3_8b_instruct import models as lmdeploy_llama3_8b_instruct_model
    # from .models.hf_llama.lmdeploy_llama3_70b_instruct import models as lmdeploy_llama3_70b_instruct_model

    # from .models.hf_internlm.lmdeploy_internlm2_chat_1_8b import models as lmdeploy_internlm2_chat_1_8b_model
    # from .models.hf_internlm.lmdeploy_internlm2_chat_7b import models as lmdeploy_internlm2_chat_7b_model
    # from .models.hf_internlm.lmdeploy_internlm2_chat_20b import models as lmdeploy_internlm2_chat_20b_model

    # from .models.yi.hf_yi_1_5_6b_chat import models as yi_1_5_6b_chat_model
    # from .models.yi.hf_yi_1_5_34b_chat import models as yi_1_5_34b_chat_model

    # from .models.deepseek.hf_deepseek_v2_chat import models as deepseek_v2_chat_model

    # from .models.qwen.hf_qwen1_5_1_8b_chat import models as qwen1_5_1_8b_chat_model
    # from .models.qwen.hf_qwen1_5_7b_chat import models as qwen1_5_7b_chat_model
    # from .models.qwen.hf_qwen1_5_14b_chat import models as qwen1_5_14b_chat_model
    # from .models.qwen.hf_qwen1_5_72b_chat import models as qwen1_5_72b_chat_model

# Collect every imported model list (any variable ending in '_model') into one flat list.
models = sum([v for k, v in locals().items() if k.endswith('_model')], [])

## ------------- JudgeLLM Configuration
api_meta_template = dict(round=[
    dict(role='HUMAN', api_role='HUMAN'),
    dict(role='BOT', api_role='BOT', generate=True),
])
judge_models = [
    dict(
        abbr='GPT-3.5-turbo-0125',
        type=OpenAI,
        path='gpt-3.5-turbo-0125',
        key='ENV',  # read the API key from the OPENAI_API_KEY environment variable
        meta_template=api_meta_template,
        query_per_second=16,
        max_out_len=2048,
        max_seq_len=2048,
        batch_size=8,
        temperature=0,
    )
]

## ------------- Evaluation Configuration
eval = dict(
    partitioner=dict(
        type=SubjectiveSizePartitioner,
        max_task_size=1000,
        mode='singlescore',
        models=models,
        judge_models=judge_models,
    ),
    runner=dict(type=LocalRunner,
                max_num_workers=2,
                task=dict(type=SubjectiveEvalTask)),
)

summarizer = dict(type=CharmMemSummarizer)

work_dir = './outputs/CHARM_mem/chat/'
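The path of this eval config file is not shown in the diff; assuming it is saved under configs/ (OpenCompass's usual location for eval entry points), the whole memorization evaluation can be launched with `python run.py <path-to-this-config>` from the repository root after uncommenting the desired model imports. Results and the CharmMemSummarizer report are then written under ./outputs/CHARM_mem/chat/.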