Skip to content

Commit

Permalink
fix: fix the pylint check error
Browse files Browse the repository at this point in the history
Signed-off-by: Marchons <[email protected]>
  • Loading branch information
Yoda-wu committed Sep 18, 2024
1 parent ee5af6b commit 20e07b4
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 34 deletions.
2 changes: 1 addition & 1 deletion core/testcasecontroller/algorithm/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(self, name, config):
self._parse_config(config)
self._load_third_party_packages()

# pylint: disable=R0911
def paradigm(self, workspace: str, **kwargs):
"""
get test process of AI algorithm paradigm.
Expand All @@ -98,7 +99,6 @@ def paradigm(self, workspace: str, **kwargs):
"""

config = kwargs

# pylint: disable=C0103
for k, v in self.__dict__.items():
config.update({k: v})
Expand Down
40 changes: 27 additions & 13 deletions core/testcasecontroller/algorithm/paradigm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from core.common.constant import ModuleType, ParadigmType
from .sedna_federated_learning import FederatedLearning


class ParadigmBase:
"""
Paradigm Base
Expand Down Expand Up @@ -96,36 +97,49 @@ def build_paradigm_job(self, paradigm_type):
return IncrementalLearning(
estimator=self.module_instances.get(ModuleType.BASEMODEL.value),
hard_example_mining=self.module_instances.get(
ModuleType.HARD_EXAMPLE_MINING.value))
ModuleType.HARD_EXAMPLE_MINING.value
),
)

if paradigm_type == ParadigmType.LIFELONG_LEARNING.value:
return LifelongLearning(
estimator=self.module_instances.get(
ModuleType.BASEMODEL.value),
estimator=self.module_instances.get(ModuleType.BASEMODEL.value),
task_definition=self.module_instances.get(
ModuleType.TASK_DEFINITION.value),
ModuleType.TASK_DEFINITION.value
),
task_relationship_discovery=self.module_instances.get(
ModuleType.TASK_RELATIONSHIP_DISCOVERY.value),
ModuleType.TASK_RELATIONSHIP_DISCOVERY.value
),
task_allocation=self.module_instances.get(
ModuleType.TASK_ALLOCATION.value),
ModuleType.TASK_ALLOCATION.value
),
task_remodeling=self.module_instances.get(
ModuleType.TASK_REMODELING.value),
ModuleType.TASK_REMODELING.value
),
inference_integrate=self.module_instances.get(
ModuleType.INFERENCE_INTEGRATE.value),
ModuleType.INFERENCE_INTEGRATE.value
),
task_update_decision=self.module_instances.get(
ModuleType.TASK_UPDATE_DECISION.value),
ModuleType.TASK_UPDATE_DECISION.value
),
unseen_task_allocation=self.module_instances.get(
ModuleType.UNSEEN_TASK_ALLOCATION.value),
ModuleType.UNSEEN_TASK_ALLOCATION.value
),
unseen_sample_recognition=self.module_instances.get(
ModuleType.UNSEEN_SAMPLE_RECOGNITION.value),
ModuleType.UNSEEN_SAMPLE_RECOGNITION.value
),
unseen_sample_re_recognition=self.module_instances.get(
ModuleType.UNSEEN_SAMPLE_RE_RECOGNITION.value)
ModuleType.UNSEEN_SAMPLE_RE_RECOGNITION.value
),
)
# pylint: disable=E1101
if paradigm_type == ParadigmType.MULTIEDGE_INFERENCE.value:
return self.modules_funcs.get(ModuleType.BASEMODEL.value)()

if paradigm_type == ParadigmType.FEDERATED_LEARNING.value or paradigm_type == ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value:
if paradigm_type in [
ParadigmType.FEDERATED_LEARNING.value,
ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value,
]:
return FederatedLearning(
estimator=self.module_instances.get(ModuleType.BASEMODEL.value)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# pylint: disable=C0412
# pylint: disable=W1203
import numpy as np
from sedna.algorithms.aggregation import AggClient
from core.common.log import LOGGER
from core.common.constant import ParadigmType, SystemMetricType
from core.testcasecontroller.metrics.metrics import get_metric_func
Expand Down Expand Up @@ -136,6 +135,10 @@ def init_client(self):
for _ in range(self.clients_number)
]

# pylint: disable=C0103
# pylint: disable=C0206
# pylint: disable=C0201
# pylint: disable=W1203
def run(self):
"""run the Federated Class-Incremental Learning paradigm
This function will run the Federated Class-Incremental Learning paradigm.
Expand All @@ -150,6 +153,7 @@ def run(self):
list: prediction result
dict: system metric information
"""

self.init_client()
dataset_files = self._split_dataset(self.incremental_rounds)
test_dataset_files = self._split_test_dataset(self.incremental_rounds)
Expand All @@ -158,10 +162,10 @@ def run(self):
for task_id in range(self.incremental_rounds):
train_datasets, task_size = self.task_definition(dataset_files, task_id)
testdatasets = test_dataset_files[: task_id + 1]
for round in range(self.rounds):
LOGGER.info(f"Round {round} task id: {task_id}")
for r in range(self.rounds):
LOGGER.info(f"Round {r} task id: {task_id}")
self.train(
train_datasets, task_id=task_id, round=round, task_size=task_size
train_datasets, task_id=task_id, round=r, task_size=task_size
)
global_weights = self.aggregator.aggregate(self.aggregate_clients)
if hasattr(self.aggregator, "helper_function"):
Expand Down Expand Up @@ -200,9 +204,6 @@ def _split_test_dataset(self, split_time):
index += 1
return test_datasets_files

def train_data_partition(self, train_dataset_file):
return super().train_data_partition(train_dataset_file)

def client_train(self, client_idx, train_datasets, validation_datasets, **kwargs):
"""client train function that will be called by the thread
Expand All @@ -216,16 +217,6 @@ def client_train(self, client_idx, train_datasets, validation_datasets, **kwargs
)
with self.lock:
self.train_infos.append(train_info)
# train_info = self.clients[client_idx].train(
# train_datasets[client_idx], validation_datasets, **kwargs
# )
# train_info["client_id"] = client_idx
# agg_client = AggClient()
# agg_client.num_samples = train_info["num_samples"]
# agg_client.weights = self.clients[client_idx].get_weights()
# with self.lock:
# self.aggregate_clients.append(agg_client)
# self.train_infos.append(train_info)

def helper_function(self, train_infos):
"""helper function for FCI Method
Expand All @@ -241,6 +232,7 @@ def helper_function(self, train_infos):
LOGGER.info("finish helper function")

# pylint: disable=too-many-locals
# pylint: disable=C0200
def evaluation(self, testdataset_files, incremental_round):
"""evaluate the model performance on old classes
Expand All @@ -254,7 +246,7 @@ def evaluation(self, testdataset_files, incremental_round):
"""
if self.accuracy_func is None:
raise ValueError("accuracy function is not defined")
LOGGER.info("*" * 20 + "start evaluation" + "*" * 20)
LOGGER.info("********start evaluation********")
if isinstance(testdataset_files, str):
testdataset_files = [testdataset_files]
job = self.get_global_model()
Expand Down Expand Up @@ -285,7 +277,8 @@ def evaluation(self, testdataset_files, incremental_round):
acc_per_round = self.accuracy_per_round[j]
if i < len(acc_per_round):
LOGGER.info(
f"acc_per_round: {acc_per_round[i]} and diff is {acc_per_round[i] - old_class_acc_list[i]}"
f"acc_per_round: {acc_per_round[i]}"
+ f" and diff is {acc_per_round[i] - old_class_acc_list[i]}"
)
max_acc_diff = max(
max_acc_diff, acc_per_round[i] - old_class_acc_list[i]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ def train(self, train_datasets, **kwargs):
LOGGER.info("finish training")

def send_weights_to_clients(self, global_weights):
"""send weights to clients
Args:
global_weights (list): aggregated weights
"""
for client in self.clients:
client.set_weights(global_weights)
LOGGER.info("finish send weights to clients")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ class FederatedLearning(JobBase):

# pylint: disable=too-many-locals
def __init__(self, estimator):
super(FederatedLearning, self).__init__(estimator)
super().__init__(estimator)
self.estimator = estimator

# pylint: disable=W0221
def train(self, train_data, vald_data, **kwargs):
"""Local training function
Expand Down
1 change: 1 addition & 0 deletions core/testenvmanager/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

""" Dataset utils to read data from file and partition data """
# pylint: disable=W1203
import random
import numpy as np
Expand Down

0 comments on commit 20e07b4

Please sign in to comment.