Skip to content

Commit

Permalink
fix: fix the pylint check error
Browse files Browse the repository at this point in the history
Signed-off-by: Marchons <[email protected]>
  • Loading branch information
Yoda-wu committed Sep 18, 2024
1 parent ee5af6b commit db2c4e5
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 20 deletions.
2 changes: 1 addition & 1 deletion core/testcasecontroller/algorithm/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(self, name, config):
self._parse_config(config)
self._load_third_party_packages()

# pylint: disable=R0911
def paradigm(self, workspace: str, **kwargs):
"""
get test process of AI algorithm paradigm.
Expand All @@ -98,7 +99,6 @@ def paradigm(self, workspace: str, **kwargs):
"""

config = kwargs

# pylint: disable=C0103
for k, v in self.__dict__.items():
config.update({k: v})
Expand Down
40 changes: 27 additions & 13 deletions core/testcasecontroller/algorithm/paradigm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from core.common.constant import ModuleType, ParadigmType
from .sedna_federated_learning import FederatedLearning


class ParadigmBase:
"""
Paradigm Base
Expand Down Expand Up @@ -96,36 +97,49 @@ def build_paradigm_job(self, paradigm_type):
return IncrementalLearning(
estimator=self.module_instances.get(ModuleType.BASEMODEL.value),
hard_example_mining=self.module_instances.get(
ModuleType.HARD_EXAMPLE_MINING.value))
ModuleType.HARD_EXAMPLE_MINING.value
),
)

if paradigm_type == ParadigmType.LIFELONG_LEARNING.value:
return LifelongLearning(
estimator=self.module_instances.get(
ModuleType.BASEMODEL.value),
estimator=self.module_instances.get(ModuleType.BASEMODEL.value),
task_definition=self.module_instances.get(
ModuleType.TASK_DEFINITION.value),
ModuleType.TASK_DEFINITION.value
),
task_relationship_discovery=self.module_instances.get(
ModuleType.TASK_RELATIONSHIP_DISCOVERY.value),
ModuleType.TASK_RELATIONSHIP_DISCOVERY.value
),
task_allocation=self.module_instances.get(
ModuleType.TASK_ALLOCATION.value),
ModuleType.TASK_ALLOCATION.value
),
task_remodeling=self.module_instances.get(
ModuleType.TASK_REMODELING.value),
ModuleType.TASK_REMODELING.value
),
inference_integrate=self.module_instances.get(
ModuleType.INFERENCE_INTEGRATE.value),
ModuleType.INFERENCE_INTEGRATE.value
),
task_update_decision=self.module_instances.get(
ModuleType.TASK_UPDATE_DECISION.value),
ModuleType.TASK_UPDATE_DECISION.value
),
unseen_task_allocation=self.module_instances.get(
ModuleType.UNSEEN_TASK_ALLOCATION.value),
ModuleType.UNSEEN_TASK_ALLOCATION.value
),
unseen_sample_recognition=self.module_instances.get(
ModuleType.UNSEEN_SAMPLE_RECOGNITION.value),
ModuleType.UNSEEN_SAMPLE_RECOGNITION.value
),
unseen_sample_re_recognition=self.module_instances.get(
ModuleType.UNSEEN_SAMPLE_RE_RECOGNITION.value)
ModuleType.UNSEEN_SAMPLE_RE_RECOGNITION.value
),
)
# pylint: disable=E1101
if paradigm_type == ParadigmType.MULTIEDGE_INFERENCE.value:
return self.modules_funcs.get(ModuleType.BASEMODEL.value)()

if paradigm_type == ParadigmType.FEDERATED_LEARNING.value or paradigm_type == ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value:
if paradigm_type in [
ParadigmType.FEDERATED_LEARNING.value,
ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value,
]:
return FederatedLearning(
estimator=self.module_instances.get(ModuleType.BASEMODEL.value)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,6 @@ def _split_test_dataset(self, split_time):
index += 1
return test_datasets_files

def train_data_partition(self, train_dataset_file):
return super().train_data_partition(train_dataset_file)

def client_train(self, client_idx, train_datasets, validation_datasets, **kwargs):
"""client train function that will be called by the thread
Expand Down Expand Up @@ -241,6 +238,7 @@ def helper_function(self, train_infos):
LOGGER.info("finish helper function")

# pylint: disable=too-many-locals
# pylint: disable=consider-using-enumerat
def evaluation(self, testdataset_files, incremental_round):
"""evaluate the model performance on old classes
Expand All @@ -254,7 +252,7 @@ def evaluation(self, testdataset_files, incremental_round):
"""
if self.accuracy_func is None:
raise ValueError("accuracy function is not defined")
LOGGER.info("*" * 20 + "start evaluation" + "*" * 20)
LOGGER.info("********start evaluation********")
if isinstance(testdataset_files, str):
testdataset_files = [testdataset_files]
job = self.get_global_model()
Expand Down Expand Up @@ -285,7 +283,8 @@ def evaluation(self, testdataset_files, incremental_round):
acc_per_round = self.accuracy_per_round[j]
if i < len(acc_per_round):
LOGGER.info(
f"acc_per_round: {acc_per_round[i]} and diff is {acc_per_round[i] - old_class_acc_list[i]}"
f"acc_per_round: {acc_per_round[i]}"
+ f" and diff is {acc_per_round[i] - old_class_acc_list[i]}"
)
max_acc_diff = max(
max_acc_diff, acc_per_round[i] - old_class_acc_list[i]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ class FederatedLearning(JobBase):

# pylint: disable=too-many-locals
def __init__(self, estimator):
super(FederatedLearning, self).__init__(estimator)
super().__init__()
self.estimator = estimator

# pylint: disable=W0221
def train(self, train_data, vald_data, **kwargs):
"""Local training function
Expand Down
1 change: 1 addition & 0 deletions core/testenvmanager/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

""" Dataset utils to read data from file and partition data """
# pylint: disable=W1203
import random
import numpy as np
Expand Down

0 comments on commit db2c4e5

Please sign in to comment.