Skip to content

Commit

Permalink
fix(federated_class_incremental_learning) support multiple client ini…
Browse files Browse the repository at this point in the history
…t and deepcopy estimator

Signed-off-by: Marchons <[email protected]>
  • Loading branch information
Yoda-wu committed Oct 17, 2024
1 parent 86e1e2e commit 5f08fcd
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,8 @@ def get_task_size(self, train_datasets):
Returns:
int: task size for each task
"""
return np.unique(
[train_datasets[i][1] for i in range(len(train_datasets))]
).shape[0]
LOGGER.info(f"train_datasets: {len(train_datasets[0])}")
return np.unique(train_datasets[0][1]).shape[0]

def split_label_unlabel_data(self, train_datasets):
"""split train dataset into label and unlabel data for semi-supervised learning
Expand Down Expand Up @@ -140,9 +139,6 @@ def init_client(self):
)
for _ in range(self.clients_number)
]
LOGGER.info(
f"init client {self.clients[0].estimator == self.clients[1].estimator}"
)

def run(self):
"""run the Federated Class-Incremental Learning paradigm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ algorithm:
fl_data_setting:
train_ratio: 1.0
splitting_method: "default"
label_data_ratio: 0.3
label_data_ratio: 1.0
data_partition: "iid"
non_iid_ratio: "0.6"
initial_model_url: "/home/wyd/ianvs/project/init_model/cnn.pb"
Expand All @@ -21,7 +21,7 @@ algorithm:
- 0.001
- epochs:
values:
- 16
- 1
- type: "aggregation"
name: "FedAvg"
url: "./examples/cifar100/fci_ssl/fed_ci_match/algorithm/aggregation.py"
Expand Down
4 changes: 2 additions & 2 deletions examples/cifar100/fci_ssl/fed_ci_match/testenv/testenv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ testenv:
- name: "forget_rate"
- name: "task_avg_acc"
# incremental rounds setting of incremental learning; int type; default value is 2;
incremental_rounds: 10
incremental_rounds: 2
round: 1
client_number: 2
client_number: 5
7 changes: 5 additions & 2 deletions examples/cifar100/fci_ssl/glfc/algorithm/GLFC.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,16 @@ def _initialize_classifier(self):
def before_train(self, task_id, train_data, class_learned, old_model):
logging.info(f"------before train task_id: {task_id}------")
# print(f'train data len is :{len(train_data[1])}')

self.need_update = task_id != self.old_task_id
if self.need_update:
self.old_task_id = task_id
self.num_classes = self.task_size * (task_id + 1)
if self.current_classes is not None:
self.last_class = self.current_classes
logging.info(f"self.last_class is , {self.last_class}, {self.num_classes}")
logging.info(
f"self.last_class is , {self.last_class}, {self.num_classes} tasksize is {self.task_size}, task_id is {task_id}"
)
self._initialize_classifier()
self.current_classes = np.unique(train_data["label_y"]).tolist()
self.update_new_set(self.need_update)
Expand Down Expand Up @@ -265,7 +268,7 @@ def _compute_loss(self, imgs, labels):
correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
correct = tf.reduce_sum(correct)
logging.info(
f"current class numbers is {self.num_classes} correct is {correct} and acc is {correct/imgs.shape[0]}"
f"current class numbers is {self.num_classes} correct is {correct} and acc is {correct/imgs.shape[0]} tasksize is {self.task_size} self.old_task_id {self.old_task_id}"
)
# print(f"total_correct: {total_correct}, total_num: {total_num}")
if self.old_model == None:
Expand Down
2 changes: 1 addition & 1 deletion examples/cifar100/fci_ssl/glfc/algorithm/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self):
learning_rate=0.01, num_classes=10, test_data=None
)
self.task_id = -1
self.num_classes = 50
self.num_classes = 10

def aggregate(self, clients):
"""
Expand Down
2 changes: 1 addition & 1 deletion examples/cifar100/fci_ssl/glfc/algorithm/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, **kwargs) -> None:
self.learning_rate = kwargs.get("learning_rate", 0.001)
self.epochs = kwargs.get("epochs", 1)
self.batch_size = kwargs.get("batch_size", 32)
self.task_size = kwargs.get("task_size", 50)
self.task_size = kwargs.get("task_size", 10)
self.memory_size = kwargs.get("memory_size", 2000)
self.encode_model = lenet5(32, 100)
self.encode_model.call(keras.Input(shape=(32, 32, 3)))
Expand Down

0 comments on commit 5f08fcd

Please sign in to comment.