Skip to content

Commit

Permalink
add: Joint Inference paradigm and cloud-edge collaborative inference example
Browse files Browse the repository at this point in the history

Signed-off-by: Yu Fan <[email protected]>

add: sedna 0.6.0.1 and move 0.4.1 to third_party-bk

Signed-off-by: Yu Fan <[email protected]>
  • Loading branch information
FuryMartin committed Oct 30, 2024
1 parent 9553051 commit 62b598c
Show file tree
Hide file tree
Showing 44 changed files with 2,577 additions and 33 deletions.
10 changes: 10 additions & 0 deletions core/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class DatasetFormat(Enum):
CSV = "csv"
TXT = "txt"
JSON = "json"
JSONL = "jsonl"
JSONFORLLM = "jsonforllm"


class ParadigmType(Enum):
Expand All @@ -39,6 +41,7 @@ class ParadigmType(Enum):
LIFELONG_LEARNING = "lifelonglearning"
FEDERATED_LEARNING = "federatedlearning"
FEDERATED_CLASS_INCREMENTAL_LEARNING = "federatedclassincrementallearning"
JOINT_INFERENCE = "jointinference"


class ModuleType(Enum):
Expand All @@ -48,6 +51,13 @@ class ModuleType(Enum):

BASEMODEL = "basemodel"

# JOINT INFERENCE
EDGEMODEL = "edgemodel"
CLOUDMODEL = "cloudmodel"

# Dataset Preprocessor
DATA_PROCESSOR = "dataset_processor"

# HEM
HARD_EXAMPLE_MINING = "hard_example_mining"

Expand Down
6 changes: 5 additions & 1 deletion core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ def is_local_dir(url):

def get_file_format(url):
    """Get the dataset file format of the url.

    Returns
    -------
    str
        ``"jsonforllm"`` when the file is an LLM benchmark metadata file
        (basename exactly ``metadata.json``); otherwise the file extension
        without the leading dot (empty string if the url has no extension).
    """
    # Special case: LLM metadata files are named "metadata.json" and must be
    # classified as the dedicated "jsonforllm" format, not plain "json".
    if os.path.basename(url) == "metadata.json":
        return "jsonforllm"

    # Default: derive the format from the file extension, dropping the ".".
    return os.path.splitext(url)[-1][1:]

def parse_kwargs(func, **kwargs):
"""Get valid parameters of the func in kwargs."""
Expand Down
4 changes: 0 additions & 4 deletions core/storymanager/rank/rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def _sort_all_df(self, all_df, all_metric_names):

if metric_name not in all_metric_names:
continue
print(metric_name)
sort_metric_list.append(metric_name)
is_ascend_list.append(ele.get(metric_name) == "ascend")

Expand Down Expand Up @@ -234,15 +233,12 @@ def _draw_pictures(self, test_cases, test_results):
out_put = test_case.output_dir
test_result = test_results[test_case.id][0]
matrix = test_result.get("Matrix")
# print(out_put)
for key in matrix.keys():
draw_heatmap_picture(out_put, key, matrix[key])

def _prepare(self, test_cases, test_results, output_dir):
all_metric_names = self._get_all_metric_names(test_results)
print(f"in_prepare all_metric_names: {all_metric_names}")
all_hps_names = self._get_all_hps_names(test_cases)
print(f"in_prepare all_hps_names: {all_hps_names}")
all_module_types = self._get_all_module_types(test_cases)
self.all_df_header = [
"algorithm", *all_metric_names,
Expand Down
3 changes: 1 addition & 2 deletions core/storymanager/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
def print_table(rank_file):
    """Print the rank of the test as a table.

    Parameters
    ----------
    rank_file : str
        Path to the CSV rank file produced by the rank manager.
    """
    with open(rank_file, "r", encoding="utf-8") as file:
        # Parse the CSV exactly once. An explicit delimiter avoids
        # prettytable's format sniffing mis-detecting the separator when
        # cell values contain other punctuation.
        table = from_csv(file, delimiter=",")
    print(table)

def draw_heatmap_picture(output, title, matrix):
Expand All @@ -40,7 +40,6 @@ def draw_heatmap_picture(output, title, matrix):
plt.title(title, fontsize=15)
plt.colorbar(format='%.2f')
output_dir = os.path.join(output, f"output/{title}-heatmap.png")
#print(output_dir)
plt.savefig(output_dir)
plt.show()

Expand Down
6 changes: 5 additions & 1 deletion core/testcasecontroller/algorithm/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
MultiedgeInference,
LifelongLearning,
FederatedLearning,
FederatedClassIncrementalLearning
FederatedClassIncrementalLearning,
JointInference
)
from core.testcasecontroller.generation_assistant import get_full_combinations

Expand Down Expand Up @@ -120,6 +121,9 @@ def paradigm(self, workspace: str, **kwargs):
if self.paradigm_type == ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value:
return FederatedClassIncrementalLearning(workspace, **config)

if self.paradigm_type == ParadigmType.JOINT_INFERENCE.value:
return JointInference(workspace, **config)

return None

def _check_fields(self):
Expand Down
17 changes: 11 additions & 6 deletions core/testcasecontroller/algorithm/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def _check_fields(self):
if not isinstance(self.url, str):
raise ValueError(f"module url({self.url}) must be string type.")

#pylint: disable=too-many-branches
def get_module_instance(self, module_type):
"""
get function of algorithm module by using module type
Expand All @@ -86,7 +87,6 @@ def get_module_instance(self, module_type):
function
"""
print(f'hyperparameters_list: {self.hyperparameters_list}')
class_factory_type = ClassType.GENERAL
if module_type in [ModuleType.HARD_EXAMPLE_MINING.value]:
class_factory_type = ClassType.HEM
Expand All @@ -110,13 +110,11 @@ def get_module_instance(self, module_type):
elif module_type in [ModuleType.AGGREGATION.value]:
class_factory_type = ClassType.FL_AGG
agg = None
print(self.url)
if self.url :
try:
utils.load_module(self.url)
agg = ClassFactory.get_cls(
type_name=class_factory_type, t_cls_name=self.name)(**self.hyperparameters)
print(agg)
except Exception as err:
raise RuntimeError(f"module(type={module_type} loads class(name={self.name}) "
f"failed, error: {err}.") from err
Expand All @@ -125,10 +123,17 @@ def get_module_instance(self, module_type):
if self.url:
try:
utils.load_module(self.url)
# pylint: disable=E1134
func = ClassFactory.get_cls(
type_name=class_factory_type, t_cls_name=self.name)(**self.hyperparameters)

if class_factory_type == ClassType.HEM:
func = {"method": self.name, "param":self.hyperparameters}
else:
func = ClassFactory.get_cls(
type_name=class_factory_type,
t_cls_name=self.name
)(**self.hyperparameters)

return func

except Exception as err:
raise RuntimeError(f"module(type={module_type} loads class(name={self.name}) "
f"failed, error: {err}.") from err
Expand Down
1 change: 1 addition & 0 deletions core/testcasecontroller/algorithm/paradigm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from .multiedge_inference import MultiedgeInference
from .lifelong_learning import LifelongLearning
from .federated_learning import FederatedLearning, FederatedClassIncrementalLearning
from .joint_inference import JointInference
18 changes: 17 additions & 1 deletion core/testcasecontroller/algorithm/paradigm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from sedna.core.incremental_learning import IncrementalLearning
from sedna.core.lifelong_learning import LifelongLearning
from sedna.core.joint_inference import JointInference
from core.common.constant import ModuleType, ParadigmType
from .sedna_federated_learning import FederatedLearning

Expand Down Expand Up @@ -76,6 +77,7 @@ def _get_module_instances(self):
module_instances.update({module_type: func})
return module_instances

# pylint: disable=too-many-return-statements
def build_paradigm_job(self, paradigm_type):
"""
build paradigm job instance according to paradigm type.
Expand Down Expand Up @@ -103,7 +105,10 @@ def build_paradigm_job(self, paradigm_type):

if paradigm_type == ParadigmType.LIFELONG_LEARNING.value:
return LifelongLearning(
estimator=self.module_instances.get(ModuleType.BASEMODEL.value),
seen_estimator=self.module_instances.get(
ModuleType.BASEMODEL.value
),
unseen_estimator=None,
task_definition=self.module_instances.get(
ModuleType.TASK_DEFINITION.value
),
Expand Down Expand Up @@ -144,4 +149,15 @@ def build_paradigm_job(self, paradigm_type):
estimator=self.module_instances.get(ModuleType.BASEMODEL.value)
)

if paradigm_type == ParadigmType.JOINT_INFERENCE.value:
return JointInference(
estimator=self.module_instances.get(
ModuleType.EDGEMODEL.value),
cloud=self.module_instances.get(
ModuleType.CLOUDMODEL.value),
hard_example_mining=self.module_instances.get(
ModuleType.HARD_EXAMPLE_MINING.value),
LCReporter_enable=False
)

return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2022 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=missing-module-docstring
from .joint_inference import JointInference
Loading

0 comments on commit 62b598c

Please sign in to comment.