Skip to content

Commit

Permalink
add: Support Custom Dataset Processing for Joint Inference
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Fan <[email protected]>
  • Loading branch information
FuryMartin committed Oct 17, 2024
1 parent 8fee150 commit 1e6b6f3
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 348 deletions.
3 changes: 3 additions & 0 deletions core/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class ModuleType(Enum):
EDGEMODEL = "edgemodel"
CLOUDMODEL = "cloudmodel"

# Dataset Preprocessor
DATA_PROCESSOR = "dataset_processor"

# HEM
HARD_EXAMPLE_MINING = "hard_example_mining"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,14 @@ def set_config(self):
shot_nums = self.kwargs.get("shot_nums", 0)
)

dataset_processor = self.module_instances.get("dataset_processor", None)
if callable(dataset_processor):
self.inference_dataset = dataset_processor(self.inference_dataset)

# validate module instances
required_modules = {"edgemodel", "cloudmodel", "hard_example_mining"}
if self.module_instances.keys() != required_modules:

if not required_modules.issubset(set(self.module_instances.keys())):
raise ValueError(
f"Required modules: {required_modules}, "
f"but got: {self.module_instances.keys()}"
Expand Down Expand Up @@ -129,15 +134,15 @@ def _inference(self, job):
LOGGER.info("Inference Start")

pbar = tqdm(
zip(self.inference_dataset.x, self.inference_dataset.y),
self.inference_dataset.x,
total=len(self.inference_dataset.x),
ncols=100
)

for data in pbar:
# inference via sedna JointInference API
infer_res = job.inference(
{"messages": data[0], "gold": data[1]},
data,
mining_mode=self.hard_example_mining_mode
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np

from sedna.common.class_factory import ClassFactory, ClassType

# The user-turn question text is identical in the in-context-example template
# and the final prompt template, so it is defined once here.  Placeholders
# ({level_4_dim}, {query}, {response}) are filled via str.format at runtime.
_QUESTION_TEMPLATE = (
    "There is a single choice question about {level_4_dim}. "
    "Answer the question by replying A, B, C or D.\n{query}\nAnswer: "
)

# Prompt templates used by the multi-shot dataset processor:
#   system_prompt   - fixed system-role message prepended to every chat
#   ice_template    - one user/assistant exchange per in-context example
#   prompt_template - the final user question to be answered
PROMPTS = {
    "system_prompt": {
        "role": "system",
        "content": "You are a helpful assistant",
    },
    "ice_template": [
        {
            "role": "user",
            "content": _QUESTION_TEMPLATE,
        },
        {
            "role": "assistant",
            "content": "{response}\n",
        },
    ],
    "prompt_template": {
        "role": "user",
        "content": _QUESTION_TEMPLATE,
    },
}

@ClassFactory.register(ClassType.GENERAL, alias="MultiShotGenertor")
class Multi_Shot_Generator:
    """Dataset processor that turns plain Q/A items into few-shot chat prompts.

    For every item in the dataset it builds a ``messages`` list consisting of
    an optional system prompt, up to ``shot_nums`` in-context examples sampled
    (without replacement) from the *other* items, and the final user question.
    ``dataset.x`` is replaced in place with
    ``[{"messages": [...], "gold": <reference answer>}, ...]``.

    NOTE(review): the registered alias "MultiShotGenertor" contains a typo
    ("Genertor"); it is kept unchanged because external configuration files
    look the processor up by this exact string.
    """

    def __init__(self, **kwargs):
        # Number of in-context examples to prepend to each query (default 0,
        # i.e. zero-shot).
        self.shot_nums = kwargs.get("shot_nums", 0)

    def load_prompts(self):
        """Load system/ICE/question templates from the module-level PROMPTS."""
        self.system_prompt = PROMPTS.get("system_prompt", None)
        self.ice_template = PROMPTS.get("ice_template", None)
        self.prompt_template = PROMPTS.get("prompt_template", None)

    def multi_shot_generation(self, dataset, shot_nums=0):
        """Rewrite ``dataset.x`` into few-shot chat messages.

        Args:
            dataset: object with equal-length ``x`` (queries), ``y`` (gold
                answers) and ``level_4`` (subject dimension) sequences.
                Assumed attributes — TODO confirm against the caller's
                dataset type.
            shot_nums: requested number of in-context examples per item.
                Clamped to the number of *other* items so that sampling with
                ``replace=False`` can never raise ``ValueError``.

        Returns:
            The same ``dataset`` object, with ``x`` replaced.
        """
        data = [
            {"query": query, "response": response, "level_4_dim": level_4_dim}
            for query, response, level_4_dim
            in zip(dataset.x, dataset.y, dataset.level_4)
        ]

        def _format_chat(chat, item):
            # Fill every template field of one chat turn from `item`.
            return {key: value.format(**item) for key, value in chat.items()}

        data_array = np.array(data)
        data_index = np.arange(len(data))

        x = []
        for i, item in enumerate(data):
            messages = []
            if self.system_prompt:
                messages.append(self.system_prompt)
            if self.ice_template:
                # Sample examples only from the other items; clamp the sample
                # size so replace=False sampling is always valid (the original
                # code raised ValueError when shot_nums > len(data) - 1).
                pool = data_array[data_index != i]
                shots = np.random.choice(
                    pool, size=min(shot_nums, len(pool)), replace=False
                )
                for shot in shots:
                    messages.extend(
                        _format_chat(chat, shot) for chat in self.ice_template
                    )
            messages.append(_format_chat(self.prompt_template, item))

            x.append({"messages": messages, "gold": item["response"]})

        dataset.x = x

        return dataset

    def __call__(self, dataset):
        """Entry point used by the benchmark: load templates, then process."""
        self.load_prompts()
        return self.multi_shot_generation(dataset, self.shot_nums)

This file was deleted.

This file was deleted.

Loading

0 comments on commit 1e6b6f3

Please sign in to comment.