add: Oracle Router
Signed-off-by: Yu Fan <[email protected]>
FuryMartin committed Oct 5, 2024
1 parent 3a3e6ab commit 37f4737
Showing 7 changed files with 187 additions and 107 deletions.
@@ -18,6 +18,7 @@
import os
from tqdm import tqdm

from core.common.log import LOGGER
from core.common.constant import ParadigmType
from core.testcasecontroller.algorithm.paradigm.base import ParadigmBase

@@ -55,6 +56,36 @@ def __init__(self, workspace, **kwargs):
"mining-then-inference"
)

def set_config(self):
shot_nums = self.kwargs.get("shot_nums", 0)

inference_output_dir = os.path.join(os.path.dirname(self.workspace), f"{shot_nums}-shot/")
os.environ["RESULT_SAVED_URL"] = inference_output_dir
os.makedirs(inference_output_dir, exist_ok=True)
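# For example (values invented): with workspace ".../runs/exp1" and shot_nums=5, outputs land in ".../runs/5-shot/".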

LOGGER.info("Loading dataset")

self.inference_dataset = self.dataset.load_data(
self.dataset.test_data_info,
"inference",
shot_nums=shot_nums
)

# validate module instances
required_modules = {"edgemodel", "cloudmodel", "hard_example_mining"}
if self.module_instances.keys() != required_modules:
raise ValueError(
f"Required modules: {required_modules}, "
f"but got: {self.module_instances.keys()}"
)

# If the hard example mining module is OracleRouter, add the edge model and
# cloud model objects to its kwargs so that the router can invoke them.
mining = self.module_instances["hard_example_mining"]
param = mining.get("param")
if mining.get("method", None) == "OracleRouter":
param["edgemodel"] = self.module_instances["edgemodel"]
param["cloudmodel"] = self.module_instances["cloudmodel"]

def run(self):
"""
run the test flow of the joint inference paradigm.
@@ -66,39 +97,47 @@ def run(self):
information needed to compute system metrics.
"""
self.set_config()

job = self.build_paradigm_job(ParadigmType.JOINT_INFERENCE.value)

inference_result = self._inference(job)

self._cleanup(job)

return inference_result, self.system_metric_info

def _cleanup(self, job):
LOGGER.info("Release models")
# release module resources
for module in self.module_instances.values():
if hasattr(module, "cleanup"):
module.cleanup()

# Since the hard example mining module is instantiated within the job, special handling is required.
mining_instance = job.hard_example_mining_algorithm
if hasattr(mining_instance, "cleanup"):
mining_instance.cleanup()

del job


def _inference(self, job):
results = []

cloud_count, edge_count = 0, 0

LOGGER.info("Inference Start")

pbar = tqdm(
zip(self.inference_dataset.x, self.inference_dataset.y),
total=len(self.inference_dataset.x),
ncols=100
)

for data in pbar:
# inference via sedna JointInference API
infer_res = job.inference(
{"messages": data[0], "gold": data[1]},
mining_mode=self.hard_example_mining_mode
)

@@ -111,4 +150,6 @@ def _inference(self, job):

results.append(infer_res)

LOGGER.info("Inference Finished")

return results # (is_hard_example, res, edge_result, cloud_result)
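
A hypothetical post-processing sketch (not part of this commit) showing how the returned tuples might be tallied:

# Each entry of `results` is (is_hard_example, res, edge_result, cloud_result).
hard = sum(1 for is_hard, *_ in results if is_hard)
print(f"Routed to cloud: {hard}/{len(results)} samples")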
@@ -15,10 +15,10 @@
from __future__ import absolute_import, division, print_function

import os
import json
import tempfile
import time
import zipfile

import numpy as np
from sedna.common.config import Context
@@ -29,15 +29,16 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto

from core.common.log import LOGGER

os.environ['BACKEND_TYPE'] = 'TORCH'

__all__ = ["BaseModel"]

@ClassFactory.register(ClassType.GENERAL, alias="CloudModel")
class CloudModel:
def __init__(self, **kwargs):
LOGGER.info(kwargs)
# The API key and API URL are confidential data and should not be written in the YAML config.

self.model = APIBasedLLM(**kwargs)
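
Since the credentials come from the environment rather than YAML, a hypothetical setup could look like the following (the variable names are assumptions, not taken from this diff):

import os

os.environ["OPENAI_API_KEY"] = "<your-api-key>"        # hypothetical variable name
os.environ["OPENAI_BASE_URL"] = "<your-api-endpoint>"  # hypothetical variable name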
@@ -18,7 +18,6 @@
import tempfile
import time
import zipfile

import numpy as np
from sedna.common.config import Context
@@ -29,14 +28,9 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto

from core.common.log import LOGGER

os.environ['BACKEND_TYPE'] = 'TORCH'

__all__ = ["BaseModel"]

@@ -46,6 +40,8 @@ class EdgeModel:
This is actually the Edge Model.
"""
def __init__(self, **kwargs):
LOGGER.info(kwargs)

self.kwargs = kwargs
self.model_name = kwargs.get("model", None)
self.backend = kwargs.get("backend", "huggingface")
@@ -71,7 +67,6 @@ def load(self, **kwargs):
else:
raise Exception(f"Backend {self.backend} is not supported")

logger.info(f"Using Backend: {self.backend}")
self.model.load(model_url=self.model_name)

# TODO cloud service must be configured in JointInference
@@ -18,12 +18,16 @@
import random
from transformers import pipeline
from sedna.common.class_factory import ClassFactory, ClassType
from core.common.log import LOGGER

__all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter')

class BaseFilter(metaclass=abc.ABCMeta):
"""The base class to define unified interface."""

def __init__(self, **kwargs):
LOGGER.info(f"USING {self.__class__.__name__}")

def __call__(self, infer_result=None):
"""
predict function, judge the sample is hard or not.
@@ -49,7 +53,10 @@ def data_check(cls, data):
@ClassFactory.register(ClassType.HEM, alias="BERTRouter")
class BERTFilter(BaseFilter, abc.ABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)

self.kwargs = kwargs
LOGGER.info(kwargs)

self.model = kwargs.get("model", "routellm/bert")
self.task = kwargs.get("task", "text-classification")
@@ -77,9 +84,8 @@ def _predict(self, data):
return is_hard_example

def _preprocess(self, data):
if "question" in data:
data = data.get("question")
return data[:self.max_length]
messages = data.get("messages")
return messages[-1]["content"][:self.max_length]

def cleanup(self):
del self.classifier
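
To illustrate the new input contract (sample values invented), `_preprocess` now extracts the latest message content from the chat payload assembled in `_inference`:

data = {
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "gold": "Paris",
}
# _preprocess(data) returns "What is the capital of France?" truncated to max_length.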
@@ -91,23 +97,71 @@ def __call__(self, data=None) -> bool:
@ClassFactory.register(ClassType.HEM, alias="EdgeOnly")
class EdgeOnlyFilter(BaseFilter, abc.ABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def __call__(self, data=None) -> bool:
return False

@ClassFactory.register(ClassType.HEM, alias="CloudOnly")
class CloudOnlyFilter(BaseFilter, abc.ABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def __call__(self, data=None) -> bool:
return True

@ClassFactory.register(ClassType.HEM, alias="RandomRouter")
class RandomRouterFilter(BaseFilter, abc.ABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.threshold = kwargs.get("threshold", 0)

def __call__(self, data=None) -> bool:
# With probability `threshold`, treat the sample as easy (edge); otherwise hard (cloud).
return False if random.random() < self.threshold else True

@ClassFactory.register(ClassType.HEM, alias="OracleRouter")
class OracleRouterFilter(BaseFilter, abc.ABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.edge_better = 0
self.cloud_better = 0
self.both_right = 0
self.both_wrong = 0

self.edge_model = kwargs.get("edgemodel")
self.cloud_model = kwargs.get("cloudmodel")

def __call__(self, data=None) -> bool:
gold = data.get("gold", None)

edge_result = self.edge_model.predict(data).get("prediction")
cloud_result = self.cloud_model.inference(data).get("prediction")

both_right = edge_result == gold and cloud_result == gold
both_wrong = edge_result != gold and cloud_result != gold
edge_better = edge_result == gold and cloud_result != gold
cloud_better = edge_result != gold and cloud_result == gold

if both_right:
self.both_right += 1
elif both_wrong:
self.both_wrong += 1
elif edge_better:
self.edge_better += 1
elif cloud_better:
self.cloud_better += 1

if cloud_better:
# cloud is better than edge, hard sample
return True
else:
# both correct + both wrong + edge_better, easy sample
return False

def cleanup(self):
print(
f"Both Wrong: {self.both_wrong}\n"
f"Both Correct: {self.both_right}\n"
f"Edge Better: {self.edge_better}\n"
f"Cloud Better: {self.cloud_better}"
)
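
A minimal usage sketch with stub models (hypothetical; in Ianvs the real models are injected by set_config above). Because the oracle sees the gold label, it routes to the cloud only when doing so flips a wrong edge answer to a right cloud answer, so its accuracy, 1 - both_wrong/total, is an upper bound for any router over the same model pair:

from sedna.common.class_factory import ClassFactory, ClassType

class StubModel:  # hypothetical stand-in for EdgeModel / CloudModel
    def __init__(self, answer):
        self.answer = answer
    def predict(self, data):
        return {"prediction": self.answer}
    inference = predict  # CloudModel is called via .inference()

router_cls = ClassFactory.get_cls(ClassType.HEM, "OracleRouter")  # assumes this module has been imported
router = router_cls(edgemodel=StubModel("4"), cloudmodel=StubModel("5"))
print(router({"messages": [{"role": "user", "content": "2+2?"}], "gold": "4"}))  # False: edge already correct
router.cleanup()  # prints the four counters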
(Diff for the remaining changed files did not load.)
