diff --git a/build/crd-samples/sedna/federatedlearningjob_yolo_v1alpha1.yaml b/build/crd-samples/sedna/federatedlearningjob_yolo_v1alpha1.yaml new file mode 100644 index 000000000..8230be0d6 --- /dev/null +++ b/build/crd-samples/sedna/federatedlearningjob_yolo_v1alpha1.yaml @@ -0,0 +1,87 @@ +apiVersion: sedna.io/v1alpha1 +kind: FederatedLearningJob +metadata: + name: yolo-v5 +spec: + pretrainedModel: # option + name: "yolo-v5-pretrained-model" + transmitter: # option + ws: { } # option, by default + s3: # option, but at least one + aggDataPath: "s3://sedna/fl/aggregation_data" + credentialName: mysecret + aggregationWorker: + model: + name: "yolo-v5-model" + template: + spec: + nodeName: "sedna-control-plane" + containers: + - image: kubeedge/sedna-example-federated-learning-mistnet-yolo-aggregator:v0.4.0 + name: agg-worker + imagePullPolicy: IfNotPresent + env: # user defined environments + - name: "cut_layer" + value: "4" + - name: "epsilon" + value: "100" + - name: "aggregation_algorithm" + value: "mistnet" + - name: "batch_size" + value: "32" + resources: # user defined resources + limits: + memory: 8Gi + trainingWorkers: + - dataset: + name: "coco-dataset-1" + template: + spec: + nodeName: "edge-node" + containers: + - image: kubeedge/sedna-example-federated-learning-mistnet-yolo-client:v0.4.0 + name: train-worker + imagePullPolicy: IfNotPresent + args: [ "-i", "1" ] + env: # user defined environments + - name: "cut_layer" + value: "4" + - name: "epsilon" + value: "100" + - name: "aggregation_algorithm" + value: "mistnet" + - name: "batch_size" + value: "32" + - name: "learning_rate" + value: "0.001" + - name: "epochs" + value: "1" + resources: # user defined resources + limits: + memory: 2Gi + - dataset: + name: "coco-dataset-2" + template: + spec: + nodeName: "edge-node" + containers: + - image: kubeedge/sedna-example-federated-learning-mistnet-yolo-client:v0.4.0 + name: train-worker + imagePullPolicy: IfNotPresent + args: [ "-i", "2" ] + env: # user defined environments + - name: "cut_layer" + value: "4" + - name: "epsilon" + value: "100" + - name: "aggregation_algorithm" + value: "mistnet" + - name: "batch_size" + value: "32" + - name: "learning_rate" + value: "0.001" + - name: "epochs" + value: "1" + resources: # user defined resources + limits: + memory: 2Gi \ No newline at end of file diff --git a/examples/build_image.sh b/examples/build_image.sh index 6a154d845..fb05c2c3f 100644 --- a/examples/build_image.sh +++ b/examples/build_image.sh @@ -17,11 +17,13 @@ cd "$(dirname "${BASH_SOURCE[0]}")" IMAGE_REPO=${IMAGE_REPO:-kubeedge} -IMAGE_TAG=${IMAGE_TAG:-v0.3.0} +IMAGE_TAG=${IMAGE_TAG:-v0.4.0} EXAMPLE_REPO_PREFIX=${IMAGE_REPO}/sedna-example- dockerfiles=( +federated-learning-mistnet-yolo-aggregator.Dockerfile +federated-learning-mistnet-yolo-client.Dockerfile federated-learning-surface-defect-detection-aggregation.Dockerfile federated-learning-surface-defect-detection-train.Dockerfile incremental-learning-helmet-detection.Dockerfile diff --git a/examples/federated-learning-mistnet-yolo-aggregator.Dockerfile b/examples/federated-learning-mistnet-yolo-aggregator.Dockerfile new file mode 100644 index 000000000..e316f6eb3 --- /dev/null +++ b/examples/federated-learning-mistnet-yolo-aggregator.Dockerfile @@ -0,0 +1,23 @@ +FROM tensorflow/tensorflow:1.15.4 + +RUN apt update \ + && apt install -y libgl1-mesa-glx git + +COPY ./lib/requirements.txt /home + +RUN python -m pip install --upgrade pip + +RUN pip install -r /home/requirements.txt + +ENV PYTHONPATH 
"/home/lib:/home/plato:/home/plato/packages/yolov5" + +COPY ./lib /home/lib +RUN git clone https://github.com/TL-System/plato.git /home/plato + +RUN pip install -r /home/plato/requirements.txt +RUN pip install -r /home/plato/packages/yolov5/requirements.txt + +WORKDIR /home/work +COPY examples/federated_learning/yolov5_coco128_mistnet /home/work/ + +CMD ["/bin/sh", "-c", "ulimit -n 50000; python aggregate.py"] diff --git a/examples/federated-learning-mistnet-yolo-client.Dockerfile b/examples/federated-learning-mistnet-yolo-client.Dockerfile new file mode 100644 index 000000000..b1e7aa356 --- /dev/null +++ b/examples/federated-learning-mistnet-yolo-client.Dockerfile @@ -0,0 +1,23 @@ +FROM tensorflow/tensorflow:1.15.4 + +RUN apt update \ + && apt install -y libgl1-mesa-glx git + +COPY ./lib/requirements.txt /home + +RUN python -m pip install --upgrade pip + +RUN pip install -r /home/requirements.txt + +ENV PYTHONPATH "/home/lib:/home/plato:/home/plato/packages/yolov5" + +COPY ./lib /home/lib +RUN git clone https://github.com/TL-System/plato.git /home/plato + +RUN pip install -r /home/plato/requirements.txt +RUN pip install -r /home/plato/packages/yolov5/requirements.txt + +WORKDIR /home/work +COPY examples/federated_learning/yolov5_coco128_mistnet /home/work/ + +ENTRYPOINT ["python", "train.py"] diff --git a/examples/federated_learning/surface_defect_detection/training_worker/train.py b/examples/federated_learning/surface_defect_detection/training_worker/train.py index 4fd9a1122..37f21d0cc 100644 --- a/examples/federated_learning/surface_defect_detection/training_worker/train.py +++ b/examples/federated_learning/surface_defect_detection/training_worker/train.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os import numpy as np @@ -74,6 +73,7 @@ def main(): learning_rate=learning_rate, validation_split=validation_split ) + return train_jobs diff --git a/examples/federated_learning/yolov5_coco128_mistnet/README.md b/examples/federated_learning/yolov5_coco128_mistnet/README.md new file mode 100644 index 000000000..0bc26fa0c --- /dev/null +++ b/examples/federated_learning/yolov5_coco128_mistnet/README.md @@ -0,0 +1,249 @@ +# Collaboratively Train Yolo-v5 Using MistNet on COCO128 Dataset + +This case introduces how to train a federated learning job with an aggregation algorithm named MistNet in MNIST +handwritten digit classification scenario. Data is scattered in different places (such as edge nodes, cameras, and +others) and cannot be aggregated at the server due to data privacy and bandwidth. As a result, we cannot use all the +data for training. In some cases, edge nodes have limited computing resources and even have no training capability. The +edge cannot gain the updated weights from the training process. Therefore, traditional algorithms (e.g., federated +average), which usually aggregate the updated weights trained by different edge clients, cannot work in this scenario. +MistNet is proposed to address this issue. + +MistNet partitions a DNN model into two parts, a lightweight feature extractor at the edge side to generate meaningful +features from the raw data, and a classifier including the most model layers at the cloud to be iteratively trained for +specific tasks. MistNet achieves acceptable model utility while greatly reducing privacy leakage from the released +intermediate features. 
+
+## Object Detection Experiment
+
+> Assume that there are two edge nodes and a cloud node. Data on the edge nodes cannot be migrated to the cloud due to privacy issues.
+> Based on this scenario, we will demonstrate this object detection example.
+
+### Prepare Nodes
+
+```
+CLOUD_NODE="cloud-node-name"
+EDGE1_NODE="edge1-node-name"
+EDGE2_NODE="edge2-node-name"
+```
+
+### Install Sedna
+
+Follow the [Sedna installation document](/docs/setup/install.md) to install Sedna.
+
+### Prepare Dataset
+
+Download the [dataset](https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip).
+
+Create data interface for ```EDGE1_NODE```.
+
+```shell
+mkdir -p /data/1
+cd /data/1
+wget https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
+unzip coco128.zip -d COCO
+```
+
+Create data interface for ```EDGE2_NODE```.
+
+```shell
+mkdir -p /data/2
+cd /data/2
+wget https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
+unzip coco128.zip -d COCO
+```
+
+### Prepare Images
+
+This example uses these images:
+
+1. aggregation worker: ```kubeedge/sedna-example-federated-learning-mistnet-yolo-aggregator:v0.4.0```
+2. train worker: ```kubeedge/sedna-example-federated-learning-mistnet-yolo-client:v0.4.0```
+
+These images are generated by the script [build_image.sh](/examples/build_image.sh).
+
+### Create Federated Learning Job
+
+#### Create Dataset
+
+Create the dataset for `$EDGE1_NODE` and `$EDGE2_NODE`:
+
+```bash
+kubectl create -f - < None:
+        self.parameters = {
+            "datasource": "YOLO",
+            "data_params": "./coco128.yaml",
+            # Where the dataset is located
+            "data_path": "./data/COCO",
+            "train_path": "./data/COCO/coco128/images/train2017/",
+            "test_path": "./data/COCO/coco128/images/train2017/",
+            # number of training examples
+            "num_train_examples": 128,
+            # number of testing examples
+            "num_test_examples": 128,
+            # number of classes
+            "num_classes": 80,
+            # image size
+            "image_size": 640,
+            "download_urls": ["https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip",],
+            "classes":
+                [
+                    "person",
+                    "bicycle",
+                    "car",
+                    "motorcycle",
+                    "airplane",
+                    "bus",
+                    "train",
+                    "truck",
+                    "boat",
+                    "traffic light",
+                    "fire hydrant",
+                    "stop sign",
+                    "parking meter",
+                    "bench",
+                    "bird",
+                    "cat",
+                    "dog",
+                    "horse",
+                    "sheep",
+                    "cow",
+                    "elephant",
+                    "bear",
+                    "zebra",
+                    "giraffe",
+                    "backpack",
+                    "umbrella",
+                    "handbag",
+                    "tie",
+                    "suitcase",
+                    "frisbee",
+                    "skis",
+                    "snowboard",
+                    "sports ball",
+                    "kite",
+                    "baseball bat",
+                    "baseball glove",
+                    "skateboard",
+                    "surfboard",
+                    "tennis racket",
+                    "bottle",
+                    "wine glass",
+                    "cup",
+                    "fork",
+                    "knife",
+                    "spoon",
+                    "bowl",
+                    "banana",
+                    "apple",
+                    "sandwich",
+                    "orange",
+                    "broccoli",
+                    "carrot",
+                    "hot dog",
+                    "pizza",
+                    "donut",
+                    "cake",
+                    "chair",
+                    "couch",
+                    "potted plant",
+                    "bed",
+                    "dining table",
+                    "toilet",
+                    "tv",
+                    "laptop",
+                    "mouse",
+                    "remote",
+                    "keyboard",
+                    "cell phone",
+                    "microwave",
+                    "oven",
+                    "toaster",
+                    "sink",
+                    "refrigerator",
+                    "book",
+                    "clock",
+                    "vase",
+                    "scissors",
+                    "teddy bear",
+                    "hair drier",
+                    "toothbrush",
+                ],
+            "partition_size": 128,
+        }
+
+
+class Estimator:
+    def __init__(self) -> None:
+        self.model = None
+        self.pretrained = None
+        self.hyperparameters = {
+            "type": "yolov5",
+            "rounds": 1,
+            "target_accuracy": 0.99,
+            "epochs": 500,
+            "batch_size": 16,
+            "optimizer": "SGD",
+            "linear_lr": False,
+            # The machine learning model
+            "model_name": "yolov5",
+            "model_config": "./yolov5s.yaml",
+            "train_params": "./hyp.scratch.yaml"
+        }
diff --git
a/examples/federated_learning/yolov5_coco128_mistnet/train.py b/examples/federated_learning/yolov5_coco128_mistnet/train.py new file mode 100644 index 000000000..99406dd21 --- /dev/null +++ b/examples/federated_learning/yolov5_coco128_mistnet/train.py @@ -0,0 +1,35 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from interface import mistnet, s3_transmitter +from interface import Dataset, Estimator +from sedna.common.config import BaseConfig +from sedna.core.federated_learning import FederatedLearningV2 + +def main(): + data = Dataset() + estimator = Estimator() + data.parameters["data_path"] = BaseConfig.train_dataset_url.replace("robot.txt", "") + data.parameters["train_path"] = os.path.join(data.parameters["data_path"], "./coco128/images/train2017/") + data.parameters["test_path"] = data.parameters["train_path"] + fl_model = FederatedLearningV2( + data=data, + estimator=estimator, + aggregation=mistnet, + transmitter=s3_transmitter) + + fl_model.train() + +if __name__ == '__main__': + main() diff --git a/examples/federated_learning/yolov5_coco128_mistnet/yolov5s.yaml b/examples/federated_learning/yolov5_coco128_mistnet/yolov5s.yaml new file mode 100644 index 000000000..e4e9e4dde --- /dev/null +++ b/examples/federated_learning/yolov5_coco128_mistnet/yolov5s.yaml @@ -0,0 +1,48 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.50 # layer channel multiple + +# anchors +anchors: + - [ 10,13, 16,30, 33,23 ] # P3/8 + - [ 30,61, 62,45, 59,119 ] # P4/16 + - [ 116,90, 156,198, 373,326 ] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [ [ -1, 1, Focus, [ 64, 3 ] ], # 0-P1/2 + [ -1, 1, Conv, [ 128, 3, 2 ] ], # 1-P2/4 + [ -1, 3, C3, [ 128 ] ], + [ -1, 1, Conv, [ 256, 3, 2 ] ], # 3-P3/8 + [ -1, 9, C3, [ 256 ] ], + [ -1, 1, Conv, [ 512, 3, 2 ] ], # 5-P4/16 + [ -1, 9, C3, [ 512 ] ], + [ -1, 1, Conv, [ 1024, 3, 2 ] ], # 7-P5/32 + [ -1, 1, SPP, [ 1024, [ 5, 9, 13 ] ] ], + [ -1, 3, C3, [ 1024, False ] ], # 9 + ] + +# YOLOv5 head +head: + [ [ -1, 1, Conv, [ 512, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 6 ], 1, Concat, [ 1 ] ], # cat backbone P4 + [ -1, 3, C3, [ 512, False ] ], # 13 + + [ -1, 1, Conv, [ 256, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 4 ], 1, Concat, [ 1 ] ], # cat backbone P3 + [ -1, 3, C3, [ 256, False ] ], # 17 (P3/8-small) + + [ -1, 1, Conv, [ 256, 3, 2 ] ], + [ [ -1, 14 ], 1, Concat, [ 1 ] ], # cat head P4 + [ -1, 3, C3, [ 512, False ] ], # 20 (P4/16-medium) + + [ -1, 1, Conv, [ 512, 3, 2 ] ], + [ [ -1, 10 ], 1, Concat, [ 1 ] ], # cat head P5 + [ -1, 3, C3, [ 1024, False ] ], # 23 (P5/32-large) + + [ [ 17, 20, 23 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5) + ] diff --git a/lib/sedna/algorithms/aggregation/__init__.py b/lib/sedna/algorithms/aggregation/__init__.py index eba0a1881..4725746ab 100644 --- a/lib/sedna/algorithms/aggregation/__init__.py +++ 
b/lib/sedna/algorithms/aggregation/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from . import aggregation
+from .aggregation import FedAvg, MistNet, AggClient
diff --git a/lib/sedna/algorithms/aggregation/aggregation.py b/lib/sedna/algorithms/aggregation/aggregation.py
index 3b998814f..d2116eadd 100644
--- a/lib/sedna/algorithms/aggregation/aggregation.py
+++ b/lib/sedna/algorithms/aggregation/aggregation.py
@@ -104,3 +104,24 @@ def aggregate(self, clients: List[AggClient]):
             updates.append(row.tolist())
         self.weights = deepcopy(updates)
         return updates
+
+
+@ClassFactory.register(ClassType.FL_AGG)
+class MistNet(BaseAggregation, abc.ABC):
+    def __init__(self, cut_layer, epsilon=100):
+        super().__init__()
+        self.parameters = {
+            "type": "mistnet",
+            "cut_layer": cut_layer,
+            "epsilon": epsilon
+        }
+        if isinstance(self.parameters["cut_layer"], str):
+            if self.parameters["cut_layer"].isdigit():
+                self.parameters["cut_layer"] = int(cut_layer)
+
+        if isinstance(self.parameters["epsilon"], str):
+            if self.parameters["epsilon"].isdigit():
+                self.parameters["epsilon"] = int(epsilon)
+
+    def aggregate(self, clients: List[AggClient]):
+        pass
diff --git a/lib/sedna/algorithms/client_choose/__init__.py b/lib/sedna/algorithms/client_choose/__init__.py
new file mode 100644
index 000000000..d5f58d4a0
--- /dev/null
+++ b/lib/sedna/algorithms/client_choose/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .client_choose import SimpleClientChoose
diff --git a/lib/sedna/algorithms/client_choose/client_choose.py b/lib/sedna/algorithms/client_choose/client_choose.py
new file mode 100644
index 000000000..f7e10ca1b
--- /dev/null
+++ b/lib/sedna/algorithms/client_choose/client_choose.py
@@ -0,0 +1,38 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+
+
+class AbstractClientChoose(metaclass=abc.ABCMeta):
+    """
+    Abstract class of ClientChoose, which provides base client choose
+    algorithm interfaces in federated learning.
+    """
+
+    def __init__(self):
+        pass
+
+
+class SimpleClientChoose(AbstractClientChoose):
+    """
+    A Simple Implementation of Client Choose.
+ """ + + def __init__(self, per_round=1): + super().__init__() + self.parameters = { + "per_round": per_round + } diff --git a/lib/sedna/algorithms/transmitter/__init__.py b/lib/sedna/algorithms/transmitter/__init__.py new file mode 100644 index 000000000..b71ccf1f7 --- /dev/null +++ b/lib/sedna/algorithms/transmitter/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .transmitter import S3Transmitter, WSTransmitter diff --git a/lib/sedna/algorithms/transmitter/transmitter.py b/lib/sedna/algorithms/transmitter/transmitter.py new file mode 100644 index 000000000..0aaac0c74 --- /dev/null +++ b/lib/sedna/algorithms/transmitter/transmitter.py @@ -0,0 +1,69 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + + +class AbstractTransmitter(ABC): + """ + Abstract class of Transmitter, which provides base transmission + interfaces between edge and cloud. + """ + + @abstractmethod + def recv(self): + pass + + @abstractmethod + def send(self, data): + pass + + +class WSTransmitter(AbstractTransmitter, ABC): + """ + An implementation of Transmitter based on WebSocket. + """ + + def __init__(self): + self.parameters = {} + + def recv(self): + pass + + def send(self, data): + pass + + +class S3Transmitter(AbstractTransmitter, ABC): + """ + An implementation of Transmitter based on S3 protocol. 
+ """ + + def __init__(self, + s3_endpoint_url, + access_key, + secret_key, + transmitter_url): + self.parameters = { + "s3_endpoint_url": s3_endpoint_url, + "s3_bucket": transmitter_url, + "access_key": access_key, + "secret_key": secret_key + } + + def recv(self): + pass + + def send(self, data): + pass diff --git a/lib/sedna/common/config.py b/lib/sedna/common/config.py index 769ef1a45..ad6c62d4a 100644 --- a/lib/sedna/common/config.py +++ b/lib/sedna/common/config.py @@ -269,9 +269,16 @@ class BaseConfig(ConfigSerializable): # the name of FederatedLearningJob and others Job job_name = os.getenv("JOB_NAME", "sedna") + pretrained_model_url = os.getenv("PRETRAINED_MODEL_URL", "./") model_url = os.getenv("MODEL_URL") model_name = os.getenv("MODEL_NAME") + transmitter = os.getenv("TRANSMITTER", "ws") + agg_data_path = os.getenv("AGG_DATA_PATH", "./") + s3_endpoint_url = os.getenv("S3_ENDPOINT_URL", "") + access_key_id = os.getenv("ACCESS_KEY_ID", "") + secret_access_key = os.getenv("SECRET_ACCESS_KEY", "") + # user parameter parameters = os.getenv("PARAMETERS") diff --git a/lib/sedna/core/federated_learning/__init__.py b/lib/sedna/core/federated_learning/__init__.py index c36fe3b80..e5eaad2d7 100644 --- a/lib/sedna/core/federated_learning/__init__.py +++ b/lib/sedna/core/federated_learning/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .federated_learning import FederatedLearning +from .federated_learning import FederatedLearningV2 diff --git a/lib/sedna/core/federated_learning/federated_learning.py b/lib/sedna/core/federated_learning/federated_learning.py index eec652dd1..999a153ed 100644 --- a/lib/sedna/core/federated_learning/federated_learning.py +++ b/lib/sedna/core/federated_learning/federated_learning.py @@ -13,14 +13,19 @@ # limitations under the License. 
+import asyncio
+import sys
 import time
 
-from sedna.core.base import JobBase
-from sedna.common.config import Context
-from sedna.common.file_ops import FileOps
+from sedna.algorithms.transmitter import S3Transmitter, WSTransmitter
 from sedna.common.class_factory import ClassFactory, ClassType
-from sedna.service.client import AggregationClient
+from sedna.common.config import BaseConfig, Context
 from sedna.common.constant import K8sResourceKindStatus
+from sedna.common.file_ops import FileOps
+from sedna.core.base import JobBase
+from sedna.service.client import AggregationClient
+
+__all__ = ('FederatedLearning', 'FederatedLearningV2')
 
 
 class FederatedLearning(JobBase):
@@ -50,6 +55,7 @@ class FederatedLearning(JobBase):
         aggregation="FedAvg"
     )
     """
+
     def __init__(self, estimator,
                  aggregation="FedAvg"):
         protocol = Context.get_parameters("AGG_PROTOCOL", "ws")
@@ -178,3 +184,64 @@ def train(self, train_data,
             task_info,
             K8sResourceKindStatus.RUNNING.value,
             task_info_res)
+
+
+class FederatedLearningV2:
+    def __init__(self, data=None, estimator=None,
+                 aggregation=None, transmitter=None) -> None:
+
+        from plato.config import Config
+        from plato.clients import registry as client_registry
+        # set parameters
+        server = Config.server._asdict()
+        clients = Config.clients._asdict()
+        datastore = Config.data._asdict()
+        train = Config.trainer._asdict()
+
+        if data is not None:
+            datastore.update(data.parameters)
+            Config.data = Config.namedtuple_from_dict(datastore)
+
+        self.model = None
+        if estimator is not None:
+            self.model = estimator.model
+            train.update(estimator.hyperparameters)
+            Config.trainer = Config.namedtuple_from_dict(train)
+
+        if aggregation is not None:
+            Config.algorithm = Config.namedtuple_from_dict(
+                aggregation.parameters)
+            if aggregation.parameters["type"] == "mistnet":
+                clients["type"] = "mistnet"
+                server["type"] = "mistnet"
+
+        server["address"] = Context.get_parameters("AGG_IP")
+        server["port"] = Context.get_parameters("AGG_PORT")
+
+        if transmitter is not None:
+            server.update(transmitter.parameters)
+
+        Config.server = Config.namedtuple_from_dict(server)
+        Config.clients = Config.namedtuple_from_dict(clients)
+
+        # Config.store()
+        # create a client
+        self.client = client_registry.get(model=self.model)
+        self.client.configure()
+
+    @classmethod
+    def get_transmitter_from_config(cls):
+        if BaseConfig.transmitter == "ws":
+            return WSTransmitter()
+        elif BaseConfig.transmitter == "s3":
+            return S3Transmitter(s3_endpoint_url=BaseConfig.s3_endpoint_url,
+                                 access_key=BaseConfig.access_key_id,
+                                 secret_key=BaseConfig.secret_access_key,
+                                 transmitter_url=BaseConfig.agg_data_path)
+
+    def train(self):
+        if sys.version_info.minor <= 6:
+            loop = asyncio.get_event_loop()
+            loop.run_until_complete(self.client.start_client())
+        else:
+            asyncio.run(self.client.start_client())
diff --git a/lib/sedna/service/server/aggregation.py b/lib/sedna/service/server/aggregation.py
index 5a7bb3091..1586f76ae 100644
--- a/lib/sedna/service/server/aggregation.py
+++ b/lib/sedna/service/server/aggregation.py
@@ -13,27 +13,27 @@
 # limitations under the License.
import time -from typing import List, Optional, Dict, Any - import uuid -from pydantic import BaseModel +from typing import Any, Dict, List, Optional + from fastapi import FastAPI, WebSocket from fastapi.routing import APIRoute +from pydantic import BaseModel +from starlette.endpoints import WebSocketEndpoint from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import WebSocketRoute -from starlette.endpoints import WebSocketEndpoint from starlette.types import ASGIApp, Receive, Scope, Send +from sedna.algorithms.aggregation import AggClient +from sedna.common.config import BaseConfig, Context +from sedna.common.class_factory import ClassFactory, ClassType from sedna.common.log import LOGGER from sedna.common.config import Context from sedna.common.utils import get_host_ip -from sedna.common.class_factory import ClassFactory, ClassType -from sedna.algorithms.aggregation import AggClient - from .base import BaseServer -__all__ = ('AggregationServer',) +__all__ = ('AggregationServer', 'AggregationServerV2') class WSClientInfo(BaseModel): # pylint: disable=too-few-public-methods @@ -269,3 +269,56 @@ async def client_info(self, request: Request): if client_id: return server.get_client(client_id) return WSClientInfoList(clients=server.client_list) + + +class AggregationServerV2(): + def __init__(self, data=None, estimator=None, + aggregation=None, transmitter=None, + chooser=None) -> None: + from plato.config import Config + from plato.servers import registry as server_registry + # set parameters + server = Config.server._asdict() + clients = Config.clients._asdict() + datastore = Config.data._asdict() + train = Config.trainer._asdict() + + if data is not None: + datastore.update(data.parameters) + Config.data = Config.namedtuple_from_dict(datastore) + + self.model = None + if estimator is not None: + self.model = estimator.model + if estimator.pretrained is not None: + LOGGER.info(estimator.pretrained) + Config.params['model_dir'] = estimator.pretrained + train.update(estimator.hyperparameters) + Config.trainer = Config.namedtuple_from_dict(train) + + server["address"] = Context.get_parameters("AGG_BIND_IP", "0.0.0.0") + server["port"] = Context.get_parameters("AGG_BIND_PORT", 7363) + if transmitter is not None: + server.update(transmitter.parameters) + + if aggregation is not None: + Config.algorithm = Config.namedtuple_from_dict( + aggregation.parameters) + if aggregation.parameters["type"] == "mistnet": + clients["type"] = "mistnet" + server["type"] = "mistnet" + + if chooser is not None: + clients["per_round"] = chooser.parameters["per_round"] + + LOGGER.info("address %s, port %s", server["address"], server["port"]) + + Config.server = Config.namedtuple_from_dict(server) + Config.clients = Config.namedtuple_from_dict(clients) + + # Config.store() + # create a server + self.server = server_registry.get(model=self.model) + + def start(self): + self.server.run()
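The aggregator image's `CMD` runs `python aggregate.py`, but that entry point is not part of this patch. The sketch below shows how `AggregationServerV2` could be wired up by analogy with `train.py` above; the `simple_chooser` export and the pretrained-model wiring are assumptions, not code taken from this change.

```python
# aggregate.py -- illustrative sketch only, not the file shipped with the example.
from interface import mistnet, s3_transmitter, simple_chooser  # exports assumed from interface.py
from interface import Dataset, Estimator
from sedna.common.config import BaseConfig
from sedna.service.server.aggregation import AggregationServerV2


def run_server():
    data = Dataset()
    estimator = Estimator()
    # PRETRAINED_MODEL_URL is surfaced by BaseConfig (see lib/sedna/common/config.py);
    # AggregationServerV2 passes estimator.pretrained to Plato as the model directory.
    estimator.pretrained = BaseConfig.pretrained_model_url

    server = AggregationServerV2(
        data=data,
        estimator=estimator,
        aggregation=mistnet,
        transmitter=s3_transmitter,
        chooser=simple_chooser)
    server.start()


if __name__ == '__main__':
    run_server()
```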