From 37f47371d262e2081090b3e7aafaeef7e01a8fd6 Mon Sep 17 00:00:00 2001 From: Yu Fan Date: Sun, 6 Oct 2024 00:10:37 +0800 Subject: [PATCH] add: Oracle Router Signed-off-by: Yu Fan --- .../joint_inference/joint_inference.py | 71 +++++++++++--- .../query-routing/cloud_model.py | 7 +- .../query-routing/edge_model.py | 13 +-- .../query-routing/hard_sample_mining.py | 66 +++++++++++-- .../query-routing/models/base_llm.py | 95 ++++++++++++------- .../query-routing/models/huggingface_llm.py | 3 - .../query-routing/test_queryrouting.yaml | 39 +------- 7 files changed, 187 insertions(+), 107 deletions(-) diff --git a/core/testcasecontroller/algorithm/paradigm/joint_inference/joint_inference.py b/core/testcasecontroller/algorithm/paradigm/joint_inference/joint_inference.py index b5a8755c..7886bb96 100644 --- a/core/testcasecontroller/algorithm/paradigm/joint_inference/joint_inference.py +++ b/core/testcasecontroller/algorithm/paradigm/joint_inference/joint_inference.py @@ -18,6 +18,7 @@ import os from tqdm import tqdm +from core.common.log import LOGGER from core.common.constant import ParadigmType from core.testcasecontroller.algorithm.paradigm.base import ParadigmBase @@ -55,6 +56,36 @@ def __init__(self, workspace, **kwargs): "mining-then-inference" ) + def set_config(self): + shot_nums = self.kwargs.get("shot_nums", 0) + + inference_output_dir = os.path.join(os.path.dirname(self.workspace), f"{shot_nums}-shot/") + os.environ["RESULT_SAVED_URL"] = inference_output_dir + os.makedirs(inference_output_dir, exist_ok=True) + + LOGGER.info("Loading dataset") + + self.inference_dataset = self.dataset.load_data( + self.dataset.test_data_info, + "inference", + shot_nums = self.kwargs.get("shot_nums", 0) + ) + + # validate module instances + required_modules = {"edgemodel", "cloudmodel", "hard_example_mining"} + if self.module_instances.keys() != required_modules: + raise ValueError( + f"Required modules: {required_modules}, " + f"but got: {self.module_instances.keys()}" + ) + + # if hard example mining is OracleRouter, add the edgemodel and cloudmodel object to its kwargs so that it can use them. + mining = self.module_instances["hard_example_mining"] + param = mining.get("param") + if mining.get("method", None) == "OracleRouter": + param["edgemodel"] = self.module_instances["edgemodel"] + param["cloudmodel"] = self.module_instances["cloudmodel"] + def run(self): """ run the test flow of multi-edge inference paradigm. @@ -66,39 +97,47 @@ def run(self): information needed to compute system metrics. """ + self.set_config() job = self.build_paradigm_job(ParadigmType.JOINT_INFERENCE.value) inference_result = self._inference(job) + self._cleanup(job) + + return inference_result, self.system_metric_info + + def _cleanup(self, job): + LOGGER.info("Release models") + # release module resources for module in self.module_instances.values(): if hasattr(module, "cleanup"): module.cleanup() + + # Since the hard example mining module is instantiated within the job, special handling is required. 
+ mining_instance = job.hard_example_mining_algorithm + if hasattr(mining_instance, "cleanup"): + mining_instance.cleanup() + del job - return inference_result, self.system_metric_info - def _inference(self, job): - shot_nums = self.kwargs.get("shot_nums", 0) - # Ianvs API - inference_dataset = self.dataset.load_data( - self.dataset.test_data_info, - "inference", - shot_nums = shot_nums - ) - inference_output_dir = os.path.join(os.path.dirname(self.workspace), f"{shot_nums}-shot/") - os.environ["RESULT_SAVED_URL"] = inference_output_dir - os.makedirs(inference_output_dir, exist_ok=True) - results = [] cloud_count, edge_count = 0,0 - pbar = tqdm(inference_dataset.x, ncols=100) + + LOGGER.info("Inference Start") + + pbar = tqdm( + zip(self.inference_dataset.x, self.inference_dataset.y), + total=len(self.inference_dataset.x), + ncols=100 + ) for data in pbar: # inference via sedna JointInference API infer_res = job.inference( - data, + {"messages": data[0], "gold": data[1]}, mining_mode=self.hard_example_mining_mode ) @@ -111,4 +150,6 @@ def _inference(self, job): results.append(infer_res) + LOGGER.info("Inference Finished") + return results # (is_hard_example, res, edge_result, cloud_result) diff --git a/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/cloud_model.py b/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/cloud_model.py index 7d84ee70..bb903ee2 100644 --- a/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/cloud_model.py +++ b/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/cloud_model.py @@ -15,10 +15,10 @@ from __future__ import absolute_import, division, print_function import os +import json import tempfile import time import zipfile -import logging import numpy as np from sedna.common.config import Context @@ -29,15 +29,16 @@ from transformers import AutoModelForCausalLM, AutoTokenizer device = "cuda" # the device to load the model onto -os.environ['BACKEND_TYPE'] = 'TORCH' +from core.common.log import LOGGER -logging.disable(logging.WARNING) +os.environ['BACKEND_TYPE'] = 'TORCH' __all__ = ["BaseModel"] @ClassFactory.register(ClassType.GENERAL, alias="CloudModel") class CloudModel: def __init__(self, **kwargs): + LOGGER.info(kwargs) # The API KEY and API URL are confidential data and should not be written in yaml. self.model = APIBasedLLM(**kwargs) diff --git a/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/edge_model.py b/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/edge_model.py index 31e072f7..7ae6a52d 100644 --- a/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/edge_model.py +++ b/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/edge_model.py @@ -18,7 +18,6 @@ import tempfile import time import zipfile -import logging import numpy as np from sedna.common.config import Context @@ -29,14 +28,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer device = "cuda" # the device to load the model onto -os.environ['BACKEND_TYPE'] = 'TORCH' +from core.common.log import LOGGER -logging.disable(logging.WARNING) -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger("EdgeModel") +os.environ['BACKEND_TYPE'] = 'TORCH' __all__ = ["BaseModel"] @@ -46,6 +40,8 @@ class EdgeModel: This is actually the Edge Model. 
""" def __init__(self, **kwargs): + LOGGER.info(kwargs) + self.kwargs = kwargs self.model_name = kwargs.get("model", None) self.backend = kwargs.get("backend", "huggingface") @@ -71,7 +67,6 @@ def load(self, **kwargs): else: raise Exception(f"Backend {self.backend} is not supported") - logger.info(f"Using Backend: {self.backend}") self.model.load(model_url=self.model_name) # TODO cloud service must be configured in JointInference diff --git a/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/hard_sample_mining.py b/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/hard_sample_mining.py index a053641e..3b7dcb98 100644 --- a/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/hard_sample_mining.py +++ b/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/hard_sample_mining.py @@ -18,12 +18,16 @@ import random from transformers import pipeline from sedna.common.class_factory import ClassFactory, ClassType +from core.common.log import LOGGER __all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter') class BaseFilter(metaclass=abc.ABCMeta): """The base class to define unified interface.""" + def __init__(self, **kwargs): + LOGGER.info(f"USING {self.__class__.__name__}") + def __call__(self, infer_result=None): """ predict function, judge the sample is hard or not. @@ -49,7 +53,10 @@ def data_check(cls, data): @ClassFactory.register(ClassType.HEM, alias="BERTRouter") class BERTFilter(BaseFilter, abc.ABC): def __init__(self, **kwargs): + super().__init__(**kwargs) + self.kwargs = kwargs + LOGGER.info(kwargs) self.model = kwargs.get("model", "routellm/bert") self.task = kwargs.get("task", "text-classification") @@ -77,9 +84,8 @@ def _predict(self, data): return is_hard_example def _preprocess(self, data): - if "question" in data: - data = data.get("question") - return data[:self.max_length] + messages = data.get("messages") + return messages[-1]["content"][:self.max_length] def cleanup(self): del self.classifier @@ -91,7 +97,7 @@ def __call__(self, data=None) -> bool: @ClassFactory.register(ClassType.HEM, alias="EdgeOnly") class EdgeOnlyFilter(BaseFilter, abc.ABC): def __init__(self, **kwargs): - pass + super().__init__(**kwargs) def __call__(self, data=None) -> bool: return False @@ -99,7 +105,7 @@ def __call__(self, data=None) -> bool: @ClassFactory.register(ClassType.HEM, alias="CloudOnly") class CloudOnlyFilter(BaseFilter, abc.ABC): def __init__(self, **kwargs): - pass + super().__init__(**kwargs) def __call__(self, data=None) -> bool: return True @@ -107,7 +113,55 @@ def __call__(self, data=None) -> bool: @ClassFactory.register(ClassType.HEM, alias="RandomRouter") class RandomRouterFilter(BaseFilter, abc.ABC): def __init__(self, **kwargs): + super().__init__(**kwargs) self.threshold = kwargs.get("threshold", 0) def __call__(self, data=None) -> bool: - return False if random.random() < self.threshold else True \ No newline at end of file + return False if random.random() < self.threshold else True + +@ClassFactory.register(ClassType.HEM, alias="OracleRouter") +class OracleRouterFilter(BaseFilter, abc.ABC): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.edge_better = 0 + self.cloud_better = 0 + self.both_right = 0 + self.both_wrong = 0 + + self.edge_model = kwargs.get("edgemodel") + self.cloud_model = kwargs.get("cloudmodel") + + def __call__(self, data=None) -> bool: + gold = data.get("gold", None) + + edge_result = 
self.edge_model.predict(data).get("prediction")
+        cloud_result = self.cloud_model.inference(data).get("prediction")
+
+        both_right = edge_result == gold and cloud_result == gold
+        both_wrong = edge_result != gold and cloud_result != gold
+        edge_better = edge_result == gold and cloud_result != gold
+        cloud_better = edge_result != gold and cloud_result == gold
+
+        if both_right:
+            self.both_right += 1
+        elif both_wrong:
+            self.both_wrong += 1
+        elif edge_better:
+            self.edge_better += 1
+        elif cloud_better:
+            self.cloud_better += 1
+
+        if cloud_better:
+            # cloud is better than edge, hard sample
+            return True
+        else:
+            # both correct + both wrong + edge_better, easy sample
+            return False
+
+    def cleanup(self):
+        print(
+            f"Both Wrong: {self.both_wrong}\n",
+            f"Both Correct: {self.both_right}\n",
+            f"Edge Better: {self.edge_better}\n",
+            f"Cloud Better: {self.cloud_better}"
+        )
\ No newline at end of file
diff --git a/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/models/base_llm.py b/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/models/base_llm.py
index 502667cb..122bc34d 100644
--- a/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/models/base_llm.py
+++ b/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/models/base_llm.py
@@ -1,57 +1,78 @@
 import os
 import json
 
+# from evals import extract_prediction
+
+def extract_prediction(input_string):
+    # Check whether the input is empty or contains no alphabetic characters
+    if not input_string or not any(char.isalpha() for char in input_string):
+        return None
+    # Walk the string in reverse to find the last option letter (A-D)
+    for char in reversed(input_string):
+        if 'A' <= char <= 'D':
+            return char
+    # Return None if no option letter is found
+    return None
+
 class BaseLLM:
     def __init__(self, **kwargs) -> None:
         self.config = kwargs
-        self.parse_kwargs(**kwargs)
+        self._parse_kwargs(**kwargs)
         self.is_cache_loaded = False
-    
+
     def load(self):
         raise NotImplementedError
-    
-    def parse_kwargs(self, **kwargs):
+
+    def _parse_kwargs(self, **kwargs):
         self.quantization = kwargs.get("quantization", "full")
         self.temperature = kwargs.get("temperature", 0.8)
         self.top_p = kwargs.get("top_p", 0.8)
         self.repetition_penalty = kwargs.get("repetition_penalty", 1.05)
         self.max_tokens = kwargs.get("max_tokens", 512)
         self.use_cache = kwargs.get("use_cache", True)
-    
+
     def inference(self, data):
-        
-        if isinstance(data, dict):
+
+        if isinstance(data, list):
             return [self._infer(line) for line in data]
-        
+
         elif isinstance(data, str):
             return self._infer(data)
-        
-        elif isinstance(data, list):
-            # from viztracer import VizTracer
-            # import sys
-            # with VizTracer(output_file="optional.json") as tracer:
-            # question, system_prompt = self.parse_input(data)
-            messages = data
-            system_prompt = messages[0]["content"]
+
+        elif isinstance(data, dict):
+
+            gold = data.get("gold", None)
+            messages = data.get("messages")
+
+            if messages[0]['role'] == "system":
+                system_prompt = messages[0]["content"]
+            else:
+                system_prompt = ""
+
             question = messages[-1]["content"]
 
             if self.use_cache:
-                response = self.try_cache(question, system_prompt)
+                response = self._try_cache(question, system_prompt)
                 if response is not None:
                     return response
 
             response = self._infer(messages)
+
+            prediction = extract_prediction(response.get("completion"))
+
+            response["prediction"] = prediction
+
             if self.use_cache:
-                self._update_cache(question, system_prompt, response)
+                self._update_cache(question, system_prompt, response, prediction, gold)
             # sys.exit(0)
 
             return response
-        
+
         else:
             raise ValueError(f"DataType 
{type(data)} is not supported, it must be `list` or `str` or `dict`") - + def get_message_chain(self, question, system = None): - if system: + if system: messages = [ {"role": "system", "content": system}, {"role": "user", "content": question} @@ -60,9 +81,9 @@ def get_message_chain(self, question, system = None): messages = [ {"role": "user", "content": question} ] - + return messages - + def validate_input(self, data): expected_format = """{'question':'Lorem', "prompts": {infer_system_prompt:"Lorem"}}""" @@ -73,7 +94,7 @@ def validate_input(self, data): def parse_input(self,data): self.validate_input(data) - # data should have format like: + # data should have format like: # {"question":"Lorem", "prompt": {infer_system_prompt:"Lorem"}} question = data.get("question") prompt_dict = data.get("prompts") @@ -83,7 +104,7 @@ def parse_input(self,data): def _infer(self, messages): raise NotImplementedError - + def _format_response(self, text, prompt_tokens, completion_tokens, time_to_first_token, internal_token_latency, throughput): total_tokens = prompt_tokens + completion_tokens @@ -100,13 +121,13 @@ def _format_response(self, text, prompt_tokens, completion_tokens, time_to_first } resposne = { - "completion": text, + "completion": text, "usage":usage, "perf":perf } return resposne - + def _load_cache(self): self.cache = None self.cache_hash = {} @@ -115,29 +136,31 @@ def _load_cache(self): cache_file = os.path.join(os.environ["RESULT_SAVED_URL"], "cache.json") if os.path.exists(cache_file): with open(cache_file, "r", encoding="utf-8") as f: - self.cache_models = json.load(f) + self.cache_models = json.load(f) for cache in self.cache_models: if cache["config"] == self.config: self.cache = cache self.cache_hash = {(item["question"], item["system_prompt"]):item['response'] for item in cache["result"]} self.is_cache_loaded = True - def try_cache(self, question, system_prompt): - + def _try_cache(self, question, system_prompt): + if not self.is_cache_loaded: self._load_cache() return self.cache_hash.get((question, system_prompt), None) - - def _update_cache(self, question, system_prompt, response): - + + def _update_cache(self, question, system_prompt, response, prediction, gold): + if not self.is_cache_loaded: self._load_cache() new_item = { "question": question, "system_prompt": system_prompt, - "response": response + "response": response, + "prediction": prediction, + "gold": gold } self.cache_hash[(question, system_prompt)] = response @@ -147,9 +170,9 @@ def _update_cache(self, question, system_prompt, response): else: self.cache = {"config": self.config, "result": [new_item]} self.cache_models.append(self.cache) - + def save_cache(self): - + cache_file = os.path.join(os.environ["RESULT_SAVED_URL"], "cache.json") if self.is_cache_loaded: diff --git a/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/models/huggingface_llm.py b/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/models/huggingface_llm.py index 4bd51d10..914361cd 100644 --- a/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/models/huggingface_llm.py +++ b/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/models/huggingface_llm.py @@ -7,9 +7,6 @@ device = "cuda" os.environ["TOKENIZERS_PARALLELISM"] = "true" -from logging import getLogger -logger = getLogger("HuggingfaceLLM") - class HuggingfaceLLM(BaseLLM): def __init__(self, **kwargs) -> None: BaseLLM.__init__(self, **kwargs) diff --git 
a/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/test_queryrouting.yaml b/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/test_queryrouting.yaml index fb41f008..32c53254 100644 --- a/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/test_queryrouting.yaml +++ b/examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/test_queryrouting.yaml @@ -22,20 +22,9 @@ algorithm: # name of the hyperparameter; string type; - model: values: - # - "openbmb/MiniCPM3-4B" - # - "openbmb/MiniCPM3-4B-GPTQ-Int4" - # - "Qwen/Qwen2.5-1.5B-Instruct" - # - "Qwen/Qwen2.5-3B-Instruct" - # - "Qwen/Qwen2.5-7B-Instruct" - # - "Qwen/Qwen2-1.5B-Instruct" - # - "Qwen/Qwen2-7B-Instruct" - # - "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4" - - "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8" - # - "Qwen/Qwen2.5-3B-Instruct-AWQ" - # - "Qwen/Qwen2-7B-Instruct-GPTQ-Int4" + - "Qwen/Qwen2.5-0.5B-Instruct" - backend: values: - # - "huggingface" - "vllm" - temperature: values: @@ -72,8 +61,7 @@ algorithm: # name of the hyperparameter; string type; - model: values: - # - "gpt-4o-mini" - - "deepseek-chat" + - "gpt-4o-mini" - temperature: values: - 0 @@ -93,25 +81,6 @@ algorithm: - type: "hard_example_mining" # name of python module; string type; # BERTRouter, EdgeOnly, CloudOnly, RandomRouter - name: "BERTRouter" + name: "OracleRouter" # the url address of python module; string type; - url: "./examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/hard_sample_mining.py" - hyperparameters: - - model: - values: - - routellm/bert - - threshold: - values: - - 0.65 - - 0.60 - - 0.55 - - 0.50 - - 0.45 - - 0.40 - - 0.35 - - 0.30 - - 0.25 - - 0.20 - - 0.15 - - 0.10 - - 0.05 \ No newline at end of file + url: "./examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/hard_sample_mining.py" \ No newline at end of file
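
Reviewer note, not part of the patch: the decision rule implemented by OracleRouterFilter above reduces to "route to the cloud only when the cloud model answers correctly and the edge model does not". The sketch below restates that rule as a standalone Python function for illustration; the name oracle_route_to_cloud and the example option letters are placeholders, not Ianvs or Sedna APIs.

def oracle_route_to_cloud(gold, edge_pred, cloud_pred):
    """Hard example iff only the cloud model matches the gold answer."""
    return cloud_pred == gold and edge_pred != gold

# Edge wrong, cloud right -> hard example, send to the cloud model.
assert oracle_route_to_cloud("B", "C", "B") is True
# Both right (likewise both wrong, or edge better) -> easy example, keep on edge.
assert oracle_route_to_cloud("A", "A", "A") is False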