From 52ff18a9c899167f599cd5eae4efcd0c6a3f8e90 Mon Sep 17 00:00:00 2001 From: cwj Date: Wed, 30 Aug 2023 14:33:51 +0800 Subject: [PATCH 01/23] LLM-1.3 Update 1. FedIPR: - backdoor watermark dataset - feature-based water mark modules: conv and layernorm - built-in feature-based watermark model: alexnet, resnet18, distilbert, gpt2 - fedipr trainer 2. Offsite-tuning - offsite-tuning models: gpt2 family and bloom-7b - offsite-tuning trainer Signed-off-by: cwj --- python/fate_llm/dataset/__init__.py | 2 +- python/fate_llm/dataset/glm_tokenizer.py | 6 +- python/fate_llm/dataset/llama_tokenizer.py | 20 +- python/fate_llm/dataset/nlp_tokenizer.py | 18 +- python/fate_llm/dataset/watermark.py | 134 ++++ python/fate_llm/model_zoo/__init__.py | 2 +- python/fate_llm/model_zoo/ipr/alexnet.py | 59 ++ python/fate_llm/model_zoo/ipr/distilbert.py | 49 ++ python/fate_llm/model_zoo/ipr/gpt2.py | 41 ++ python/fate_llm/model_zoo/ipr/resnet.py | 146 +++++ .../model_zoo/offsite_tuning/bloom_ot.py | 166 +++++ .../model_zoo/offsite_tuning/gpt2_ot.py | 291 +++++++++ .../offsite_tuning/offsite_tuning_model.py | 346 +++++++++++ python/fate_llm/model_zoo/pellm/albert.py | 8 +- python/fate_llm/model_zoo/pellm/bart.py | 8 +- python/fate_llm/model_zoo/pellm/bert.py | 8 +- python/fate_llm/model_zoo/pellm/chatglm.py | 8 +- python/fate_llm/model_zoo/pellm/deberta.py | 8 +- python/fate_llm/model_zoo/pellm/distilbert.py | 8 +- python/fate_llm/model_zoo/pellm/gpt2.py | 8 +- .../pellm/parameter_efficient_llm.py | 15 +- python/fate_llm/model_zoo/pellm/roberta.py | 8 +- python/fate_llm/trainer/fedipr_trainer.py | 572 ++++++++++++++++++ .../trainer/offsite_tuning_trainer.py | 300 +++++++++ 24 files changed, 2190 insertions(+), 41 deletions(-) create mode 100644 python/fate_llm/dataset/watermark.py create mode 100644 python/fate_llm/model_zoo/ipr/alexnet.py create mode 100644 python/fate_llm/model_zoo/ipr/distilbert.py create mode 100644 python/fate_llm/model_zoo/ipr/gpt2.py create mode 100644 python/fate_llm/model_zoo/ipr/resnet.py create mode 100644 python/fate_llm/model_zoo/offsite_tuning/bloom_ot.py create mode 100644 python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py create mode 100644 python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py create mode 100644 python/fate_llm/trainer/fedipr_trainer.py create mode 100644 python/fate_llm/trainer/offsite_tuning_trainer.py diff --git a/python/fate_llm/dataset/__init__.py b/python/fate_llm/dataset/__init__.py index ef471ba..878d3a9 100644 --- a/python/fate_llm/dataset/__init__.py +++ b/python/fate_llm/dataset/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# \ No newline at end of file +# diff --git a/python/fate_llm/dataset/glm_tokenizer.py b/python/fate_llm/dataset/glm_tokenizer.py index 17970f7..99f6d13 100644 --- a/python/fate_llm/dataset/glm_tokenizer.py +++ b/python/fate_llm/dataset/glm_tokenizer.py @@ -38,7 +38,8 @@ def __init__(self, truncation=True, text_max_length=256, self.truncation = truncation self.max_length = text_max_length self.tokenizer_name_or_path = tokenizer_name_or_path - self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=trust_remote_code) + self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_name_or_path, trust_remote_code=trust_remote_code) self.tokenizer.padding_side = padding_side if pad_token is not None: self.tokenizer.add_special_tokens({'pad_token': pad_token}) @@ -65,7 +66,8 @@ def _process_data(self, line): if len(target_ids) > self.max_length - 2: target_ids = target_ids[: self.max_length - 2] - input_ids = self.tokenizer.build_inputs_with_special_tokens(prompt_ids, target_ids) + input_ids = self.tokenizer.build_inputs_with_special_tokens( + prompt_ids, target_ids) seq_length = input_ids.index(self.tokenizer.bos_token_id) labels = [-100] * seq_length + input_ids[seq_length:] diff --git a/python/fate_llm/dataset/llama_tokenizer.py b/python/fate_llm/dataset/llama_tokenizer.py index eef7502..a71b5e9 100644 --- a/python/fate_llm/dataset/llama_tokenizer.py +++ b/python/fate_llm/dataset/llama_tokenizer.py @@ -41,7 +41,8 @@ def __init__(self, text_max_length=256, self.add_special_tokens = add_special_tokens self.max_length = text_max_length self.tokenizer_name_or_path = tokenizer_name_or_path - self.tokenizer = LlamaTokenizer.from_pretrained(self.tokenizer_name_or_path, add_eos_token=add_eos_token) + self.tokenizer = LlamaTokenizer.from_pretrained( + self.tokenizer_name_or_path, add_eos_token=add_eos_token) self.tokenizer.pad_token_id = pad_token_id self.tokenizer.bos_token_id = bos_token_id self.tokenizer.eos_token_id = eos_token_id @@ -61,19 +62,22 @@ def _process_data(self, line): _response = line[self.response_column] prompt = self.prompt_template.format_map(dict(prompt=_prompt)) - prompt_ids = self.tokenizer.encode(prompt, - add_special_tokens=self.add_special_tokens, - padding=self.padding) - target_ids = self.tokenizer.encode(_response, - add_special_tokens=self.add_special_tokens, - padding=self.padding) + prompt_ids = self.tokenizer.encode( + prompt, + add_special_tokens=self.add_special_tokens, + padding=self.padding) + target_ids = self.tokenizer.encode( + _response, + add_special_tokens=self.add_special_tokens, + padding=self.padding) if len(prompt_ids) > self.max_length - 2: prompt_ids = prompt_ids[: self.max_length - 2] if len(target_ids) > self.max_length - 2: target_ids = target_ids[: self.max_length - 2] - input_ids = self.tokenizer.build_inputs_with_special_tokens(prompt_ids, target_ids) + input_ids = self.tokenizer.build_inputs_with_special_tokens( + prompt_ids, target_ids) seq_length = len(prompt_ids) + 2 labels = [-100] * seq_length + input_ids[seq_length:] diff --git a/python/fate_llm/dataset/nlp_tokenizer.py b/python/fate_llm/dataset/nlp_tokenizer.py index 79c81bb..c506088 100644 --- a/python/fate_llm/dataset/nlp_tokenizer.py +++ b/python/fate_llm/dataset/nlp_tokenizer.py @@ -42,11 +42,16 @@ class TokenizerDataset(Dataset): return_input_ids bool, whether to return input_ids or not, if False, return word_idx['input_ids'] """ - def __init__(self, truncation=True, text_max_length=128, - tokenizer_name_or_path="bert-base-uncased", - return_label=True, padding=True, padding_side="right", pad_token=None, - return_input_ids=True - ): + def __init__( + self, + truncation=True, + text_max_length=128, + tokenizer_name_or_path="bert-base-uncased", + return_label=True, + padding=True, + padding_side="right", + pad_token=None, + return_input_ids=True): super(TokenizerDataset, self).__init__() self.text = None @@ -59,7 +64,8 @@ def __init__(self, truncation=True, text_max_length=128, self.max_length = text_max_length self.with_label = return_label self.tokenizer_name_or_path = tokenizer_name_or_path - self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path) + self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_name_or_path) self.tokenizer.padding_side = padding_side self.return_input_ids = return_input_ids if pad_token is not None: diff --git a/python/fate_llm/dataset/watermark.py b/python/fate_llm/dataset/watermark.py new file mode 100644 index 0000000..3e38959 --- /dev/null +++ b/python/fate_llm/dataset/watermark.py @@ -0,0 +1,134 @@ +import os +import numpy as np +import pandas as pd +from federatedml.nn.dataset.base import Dataset +from federatedml.util import LOGGER +from federatedml.nn.dataset.image import ImageDataset + + +class WaterMarkDataset(Dataset): + + def __init__(self): + super().__init__() + self.normal_dataset = None + self.watermark_dataset = None + + def load(self, path): + raise NotImplementedError() + + def get_normal_dataset(self): + return self.normal_dataset + + def get_watermark_dataset(self): + return self.watermark_dataset + + +class WaterMarkImageDataset(WaterMarkDataset): + + """ + A basic WaterMark Dataset built on pytorch ImageFolder + This Dataset is used for Fed-IPR algorithm, see: https://arxiv.org/abs/2109.13236 for details + It will contain two part: A normal dataset and a watermark dataset + When training, the FedIPR Trainer will retrieve the normal dataset and watermark dataset from it + Given a path to image folder, WaterMarkImageDataset will load images from this folder, by default, + folder named 'normal' will be treated as normal dataset, folder named 'watermark' will be treated as watermark dataset + You can adjust this behavior by setting normal_folder_name and watermark_folder_name in the parameters + + Parameters: + ---------- + normal_folder_name: str, default is 'normal', the folder name of normal dataset + watermark_folder_name: str, default is 'watermark', the folder name of watermark dataset + """ + + def __init__( + self, + normal_folder_name='normal', + watermark_folder_name='watermark', + center_crop=False, + center_crop_shape=None, + generate_id_from_file_name=True, + file_suffix='.jpg', + float64=False, + label_dtype='long'): + + super(WaterMarkImageDataset, self).__init__() + self.normal_folder_name = normal_folder_name + self.watermark_folder_name = watermark_folder_name + + self.normal_dataset = None + self.watermark_dataset = None + + self.center_crop = center_crop + self.size = center_crop_shape + self.generate_id_from_file_name = generate_id_from_file_name + self.file_suffix = file_suffix + self.float64 = float64 + self.label_type = label_dtype + + def __getitem__(self, item): + + if item < 0: + item = len(self) + item + if item < 0: + raise IndexError('index out of range') + + if item < len(self.normal_dataset): + return ('normal', self.normal_dataset[item]) + else: + return ('watermark', + self.watermark_dataset[item - len(self.normal_dataset)]) + + def __len__(self): + len_ = 0 + if self.normal_dataset is not None: + len_ += len(self.normal_dataset) + if self.watermark_dataset is not None: + len_ += len(self.watermark_dataset) + return len_ + + def load(self, file_path): + + # normal dataset path + normal_path = os.path.join(file_path, self.normal_folder_name) + # watermark dataset path + watermark_path = os.path.join(file_path, self.watermark_folder_name) + + # load normal dataset + self.normal_dataset = ImageDataset( + center_crop=self.center_crop, + center_crop_shape=self.size, + generate_id_from_file_name=self.generate_id_from_file_name, + file_suffix=self.file_suffix, + float64=self.float64, + label_dtype=self.label_type + ) + if os.path.exists(normal_path): + self.normal_dataset.load(normal_path) + else: + self.normal_dataset = None + LOGGER.info( + f'normal dataset not found in {normal_path}, will not load normal dataset') + # load watermark dataset + self.watermark_dataset = ImageDataset( + center_crop=self.center_crop, + center_crop_shape=self.size, + generate_id_from_file_name=self.generate_id_from_file_name, + file_suffix=self.file_suffix, + float64=self.float64, + label_dtype=self.label_type + ) + if os.path.exists(watermark_path): + self.watermark_dataset.load(watermark_path) + else: + self.watermark_dataset = None + LOGGER.info( + f'watermark dataset not found in {watermark_path}, will not load watermark dataset') + + def get_normal_dataset(self): + return self.normal_dataset + + def get_watermark_dataset(self): + return self.watermark_dataset + + def get_classes(self): + return self.normal_dataset.get_classes() diff --git a/python/fate_llm/model_zoo/__init__.py b/python/fate_llm/model_zoo/__init__.py index ef471ba..878d3a9 100644 --- a/python/fate_llm/model_zoo/__init__.py +++ b/python/fate_llm/model_zoo/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# \ No newline at end of file +# diff --git a/python/fate_llm/model_zoo/ipr/alexnet.py b/python/fate_llm/model_zoo/ipr/alexnet.py new file mode 100644 index 0000000..57eae5f --- /dev/null +++ b/python/fate_llm/model_zoo/ipr/alexnet.py @@ -0,0 +1,59 @@ +import torch.nn as nn +from fate_llm.model_zoo.sign_block import SignatureConv, ConvBlock + + +class SignAlexNet(nn.Module): + + """ + This is a modified Alexnet: its 4,5,6 layers are replaced by Singnature Conv Block + """ + + def __init__(self, num_classes): + super().__init__() + in_channels = 3 + maxpoolidx = [1, 3, 7] + signed_layer = [4, 5, 6] + layers = [] + inp = in_channels + + # channels & kennel size + # the same setting as the FedIPR paper + oups = { + 0: 64, + 2: 192, + 4: 384, + 5: 256, + 6: 256 + } + kp = { + 0: (5, 2), + 2: (5, 2), + 4: (3, 1), + 5: (3, 1), + 6: (3, 1) + } + + for layeridx in range(8): + if layeridx in maxpoolidx: + layers.append(nn.MaxPool2d(2, 2)) + else: + k = kp[layeridx][0] + p = kp[layeridx][1] + if layeridx in signed_layer: + layers.append(SignatureConv(inp, oups[layeridx], k, 1, p)) + else: + layers.append(ConvBlock(inp, oups[layeridx], k, 1, p)) + inp = oups[layeridx] + + self.features = nn.Sequential(*layers) + self.classifier = nn.Linear(4 * 4 * 256, num_classes) + + def forward(self, x): + for m in self.features: + x = m(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + if self.training: + return x + else: # Sofmax + return nn.functional.softmax(x, dim=1) diff --git a/python/fate_llm/model_zoo/ipr/distilbert.py b/python/fate_llm/model_zoo/ipr/distilbert.py new file mode 100644 index 0000000..5ff429c --- /dev/null +++ b/python/fate_llm/model_zoo/ipr/distilbert.py @@ -0,0 +1,49 @@ +from torch.nn import Module +from transformers import DistilBertForSequenceClassification, DistilBertForTokenClassification +from fate_llm.model_zoo.sign_block import recursive_replace_layernorm + + +class SignDistilBertForTokenClassification(Module): + + def __init__(self, model_path=None, num_labels=4) -> None: + super().__init__() + if model_path is None: + model_path = 'distilbert-base-uncased' + + self.model_path = model_path + self.model = DistilBertForTokenClassification.from_pretrained( + model_path, num_labels=num_labels) + + # replace layernorm by SignatureLayerNorm + sub_distilbert = self.model.distilbert.transformer.layer[3:] + recursive_replace_layernorm( + sub_distilbert, + layer_name_set={'output_layer_norm'}) + + def forward(self, input_dict): + return self.model(**input_dict) + + +class SignDistilBertForSequenceClassification(Module): + + def __init__( + self, + model_path=None, + num_labels=4, + problem_type=None) -> None: + super().__init__() + if model_path is None: + model_path = 'distilbert-base-uncased' + + self.model_path = model_path + self.model = DistilBertForSequenceClassification.from_pretrained( + model_path, num_labels=num_labels, problem_type=problem_type) + + # replace layernorm by SignatureLayerNorm + sub_distilbert = self.model.distilbert.transformer.layer[3:] + recursive_replace_layernorm( + sub_distilbert, + layer_name_set={'output_layer_norm'}) + + def forward(self, input_dict): + return self.model(**input_dict) diff --git a/python/fate_llm/model_zoo/ipr/gpt2.py b/python/fate_llm/model_zoo/ipr/gpt2.py new file mode 100644 index 0000000..ef506bc --- /dev/null +++ b/python/fate_llm/model_zoo/ipr/gpt2.py @@ -0,0 +1,41 @@ +from torch.nn import Module +from transformers import GPT2ForTokenClassification, GPT2ForSequenceClassification +from fate_llm.model_zoo.sign_block import recursive_replace_layernorm + + +class SignGPT2ForTokenClassification(Module): + + def __init__(self, model_path=None, num_labels=4) -> None: + super().__init__() + if model_path is None: + model_path = 'gpt2' + + self.model_path = model_path + self.model = GPT2ForTokenClassification.from_pretrained( + model_path, num_labels=num_labels) + + # replace layernorm by SignatureLayerNorm + sub_gpt2 = self.model.transformer.h[10:] + recursive_replace_layernorm(sub_gpt2) + + def forward(self, input_dict): + return self.model(**input_dict) + + +class SignGPT2ForSequenceClassification(Module): + + def __init__(self, model_path=None, num_labels=2) -> None: + super().__init__() + if model_path is None: + model_path = 'gpt2' + + self.model_path = model_path + self.model = GPT2ForSequenceClassification.from_pretrained( + model_path, num_labels=num_labels) + + # replace layernorm by SignatureLayerNorm + sub_gpt2 = self.model.transformer.h[10:] + recursive_replace_layernorm(sub_gpt2) + + def forward(self, input_dict): + return self.model(**input_dict) diff --git a/python/fate_llm/model_zoo/ipr/resnet.py b/python/fate_llm/model_zoo/ipr/resnet.py new file mode 100644 index 0000000..156206e --- /dev/null +++ b/python/fate_llm/model_zoo/ipr/resnet.py @@ -0,0 +1,146 @@ +import torch.nn as nn +import torch.nn.functional as F +from fate_llm.model_zoo.sign_block import ConvBlock, SignatureConv + + +# The layer define for ResNet18, add signature to last layer +signed_layer_define = { + 'layer1': { + '0': {'convbnrelu_1': {'flag': False}, 'convbn_2': {'flag': False}}, + '1': {'convbnrelu_1': {'flag': False}, 'convbn_2': {'flag': False}} + }, + 'layer2': { + '0': {'convbnrelu_1': {'flag': False}, 'convbn_2': {'flag': False}, 'shortcut': {'flag': False}}, + '1': {'convbnrelu_1': {'flag': False}, 'convbn_2': {'flag': False}} + }, + 'layer3': { + '0': {'convbnrelu_1': {'flag': False}, 'convbn_2': {'flag': False}, 'shortcut': {'flag': False}}, + '1': {'convbnrelu_1': {'flag': False}, 'convbn_2': {'flag': False}} + }, + 'layer4': { + '0': {'convbnrelu_1': {'flag': True}, 'convbn_2': {'flag': True}, 'shortcut': {'flag': False}}, + '1': {'convbnrelu_1': {'flag': True}, 'convbn_2': {'flag': True}} + } +} + + +def get_convblock(passport_kwargs): + def convblock_(*args, **kwargs): + if passport_kwargs['flag']: + return SignatureConv(*args, **kwargs) + else: + return ConvBlock(*args, **kwargs) + + return convblock_ + + +class BasicPrivateBlock(nn.Module): + + expansion = 1 + + def __init__(self, in_planes, planes, stride=1, kwargs={}): # (512, 512, 2) (512, 512, 1) + super(BasicPrivateBlock, self).__init__() + + self.convbnrelu_1 = get_convblock( + kwargs['convbnrelu_1'])( + in_planes, planes, 3, stride, 1) + self.convbn_2 = get_convblock( + kwargs['convbn_2'])( + planes, planes, 3, 1, 1) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = get_convblock( + kwargs['shortcut'])( + in_planes, + self.expansion * planes, + 1, + stride, + 0) # input, output, kernel_size=1 + + def forward(self, x): + + out = self.convbnrelu_1(x) + out = self.convbn_2(out) + + if not isinstance(self.shortcut, nn.Sequential): + out = out + self.shortcut(x) + else: + out = out + x + out = F.relu(out) + return out + + +class SignResnet18(nn.Module): + + # BasicPrivateBlock, [2, 2, 2, 2], **model_kwargs + def __init__(self, num_classes=100): + + super(SignResnet18, self).__init__() + num_blocks = [2, 2, 2, 2] + self.in_planes = 64 + block = BasicPrivateBlock + model_define = signed_layer_define + + self.convbnrelu_1 = ConvBlock(3, 64, 3, 1, 1) + self.layer1 = self._make_layer( + block, + 64, + num_blocks[0], + stride=1, + model_define=model_define['layer1']) + self.layer2 = self._make_layer( + block, + 128, + num_blocks[1], + stride=2, + model_define=model_define['layer2']) + self.layer3 = self._make_layer( + block, + 256, + num_blocks[2], + stride=2, + model_define=model_define['layer3']) + self.layer4 = self._make_layer( + block, + 512, + num_blocks[3], + stride=2, + model_define=model_define['layer4']) + self.linear = nn.Linear(512 * block.expansion, num_classes) + + # BasicPrivateBlock, planes = 512, numblocks = 2, stride =2, **model_kwargs + def _make_layer(self, block, planes, num_blocks, stride, model_define): + strides = [stride] + [1] * (num_blocks - 1) # [2] + [1]*1 = [2, 1] + layers = [] + for i, stride in enumerate(strides): # stride = 2 & 1 + layers.append(block(self.in_planes, planes, stride, + model_define[str(i)])) # (512, 512, 2) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + + out = self.convbnrelu_1(x) + + for block in self.layer1: + out = block(out) + for block in self.layer2: + out = block(out) + for block in self.layer3: + out = block(out) + for block in self.layer4: + out = block(out) + + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + + if self.training: + return out + else: + return F.softmax(out, dim=1) + + +if __name__ == '__main__': + + net = SignResnet18(num_classes=10) diff --git a/python/fate_llm/model_zoo/offsite_tuning/bloom_ot.py b/python/fate_llm/model_zoo/offsite_tuning/bloom_ot.py new file mode 100644 index 0000000..b1586e6 --- /dev/null +++ b/python/fate_llm/model_zoo/offsite_tuning/bloom_ot.py @@ -0,0 +1,166 @@ +from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters, split_numpy_array, recover_numpy_array +from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomModel, BloomConfig +from torch import nn +import torch + + +class BloomMainModel(OffsiteTuningMainModel): + + def __init__( + self, + model_name_or_path, + emulator_layer_num: int, + adapter_top_layer_num: int = 2, + adapter_bottom_layer_num: int = 2): + + self.model_name_or_path = model_name_or_path + super().__init__( + emulator_layer_num, + adapter_top_layer_num, + adapter_bottom_layer_num) + + def get_base_model(self): + return BloomForCausalLM.from_pretrained(self.model_name_or_path) + + def get_model_transformer_blocks(self, model: BloomForCausalLM): + return model.transformer.h + + def forward(self, x): + return self.model(**x) + + def get_additional_param_state_dict(self): + # get parameter of additional parameter + model = self.model + param_dict = { + 'wte': model.transformer.word_embeddings, + 'word_ln': model.transformer.word_embeddings_layernorm, + 'last_ln_f': model.transformer.ln_f + } + + addition_weights = self.get_numpy_state_dict(param_dict) + + wte = addition_weights.pop('wte') + wte_dict = split_numpy_array(wte, 25, 'wte') + addition_weights.update(wte_dict) + return addition_weights + + def load_additional_param_state_dict(self, submodel_weights: dict): + # load additional weights: + model = self.model + param_dict = { + 'wte': model.transformer.word_embeddings, + 'word_ln': model.transformer.word_embeddings_layernorm, + 'last_ln_f': model.transformer.ln_f + } + + new_submodel_weight = {} + new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f'] + new_submodel_weight['word_ln'] = submodel_weights['word_ln'] + wte_dict = {} + for k, v in submodel_weights.items(): + if 'wte' in k: + wte_dict[k] = v + wte = recover_numpy_array(wte_dict, 'wte') + new_submodel_weight['wte'] = wte + self.load_numpy_state_dict(param_dict, new_submodel_weight) + + def forward(self, x): + return self.model(**x) + + +class BloomSubModel(OffsiteTuningSubModel): + + def __init__( + self, + model_name_or_path, + emulator_layer_num: int, + adapter_top_layer_num: int = 2, + adapter_bottom_layer_num: int = 2, + fp16_mix_precision=False, + partial_weight_decay=None): + + self.model_name_or_path = model_name_or_path + self.emulator_layer_num = emulator_layer_num + self.adapter_top_layer_num = adapter_top_layer_num + self.adapter_bottom_layer_num = adapter_bottom_layer_num + super().__init__( + emulator_layer_num, + adapter_top_layer_num, + adapter_bottom_layer_num, + fp16_mix_precision) + self.partial_weight_decay = partial_weight_decay + + # import torch as t + # state_dict = t.load('/data/projects/fate/cwj/shortcut_bloom.pkl') + # self.load_state_dict(state_dict) + + def get_base_model(self): + total_layer_num = self.emulator_layer_num + \ + self.adapter_top_layer_num + self.adapter_bottom_layer_num + config = BloomConfig.from_pretrained(self.model_name_or_path) + config.num_hidden_layers = total_layer_num + # initialize a model without pretrained weights + return BloomForCausalLM(config) + + def get_model_transformer_blocks(self, model: BloomForCausalLM): + return model.transformer.h + + def forward(self, x): + return self.model(**x) + + def get_additional_param_state_dict(self): + # get parameter of additional parameter + model = self.model + param_dict = { + 'wte': model.transformer.word_embeddings, + 'word_ln': model.transformer.word_embeddings_layernorm, + 'last_ln_f': model.transformer.ln_f + } + + addition_weights = self.get_numpy_state_dict(param_dict) + + wte = addition_weights.pop('wte') + wte_dict = split_numpy_array(wte, 25, 'wte') + addition_weights.update(wte_dict) + return addition_weights + + def load_additional_param_state_dict(self, submodel_weights: dict): + # load additional weights: + model = self.model + param_dict = { + 'wte': model.transformer.word_embeddings, + 'word_ln': model.transformer.word_embeddings_layernorm, + 'last_ln_f': model.transformer.ln_f + } + + new_submodel_weight = {} + new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f'] + new_submodel_weight['word_ln'] = submodel_weights['word_ln'] + wte_dict = {} + for k, v in submodel_weights.items(): + if 'wte' in k: + wte_dict[k] = v + wte = recover_numpy_array(wte_dict, 'wte') + new_submodel_weight['wte'] = wte + self.load_numpy_state_dict(param_dict, new_submodel_weight) + + def forward(self, x): + return self.model(**x) + + def parameters(self, recurse=True): + if self.partial_weight_decay is None: + return super().parameters(recurse) + elif isinstance(self.partial_weight_decay, float): + no_decay = ["bias", "layer_norm.weight"] + return [ + { + "params": [ + p for n, p in self.named_parameters() if not any( + nd in n for nd in no_decay)], "weight_decay": self.partial_weight_decay}, { + "params": [ + p for n, p in self.named_parameters() if any( + nd in n for nd in no_decay)], "weight_decay": 0.0}] + else: + raise ValueError( + f"partial_weight_decay should be None or float, but got {self.partial_weight_decay}") + diff --git a/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py b/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py new file mode 100644 index 0000000..2052cf1 --- /dev/null +++ b/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py @@ -0,0 +1,291 @@ +from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters, split_numpy_array, recover_numpy_array +from transformers import GPT2LMHeadModel, GPT2Config +from torch import nn +import torch +import torch as t + + +class GPT2LMHeadMainModel(OffsiteTuningMainModel): + + def __init__( + self, + model_name_or_path, + emulator_layer_num: int, + adapter_top_layer_num: int = 2, + adapter_bottom_layer_num: int = 2): + + self.model_name_or_path = model_name_or_path + super().__init__( + emulator_layer_num, + adapter_top_layer_num, + adapter_bottom_layer_num) + + def get_base_model(self): + return GPT2LMHeadModel.from_pretrained(self.model_name_or_path) + + def get_model_transformer_blocks(self, model: GPT2LMHeadModel): + return model.transformer.h + + def forward(self, x): + return self.model(**x) + + def get_additional_param_state_dict(self): + # get parameter of additional parameter + model = self.model + param_dict = { + 'wte': model.transformer.wte, + 'wpe': model.transformer.wpe, + 'last_ln_f': model.transformer.ln_f + } + + addition_weights = self.get_numpy_state_dict(param_dict) + + wte = addition_weights.pop('wte') + wte_dict = split_numpy_array(wte, 10, 'wte') + wpe = addition_weights.pop('wpe') + wpe_dict = split_numpy_array(wpe, 10, 'wpe') + addition_weights.update(wte_dict) + addition_weights.update(wpe_dict) + return addition_weights + + def load_additional_param_state_dict(self, submodel_weights: dict): + # load additional weights: + model = self.model + param_dict = { + 'wte': model.transformer.wte, + 'wpe': model.transformer.wpe, + 'last_ln_f': model.transformer.ln_f + } + + new_submodel_weight = {} + new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f'] + wte_dict, wpe_dict = {}, {} + for k, v in submodel_weights.items(): + if 'wte' in k: + wte_dict[k] = v + if 'wpe' in k: + wpe_dict[k] = v + wte = recover_numpy_array(wte_dict, 'wte') + wpe = recover_numpy_array(wpe_dict, 'wpe') + new_submodel_weight['wte'] = wte + new_submodel_weight['wpe'] = wpe + + self.load_numpy_state_dict(param_dict, new_submodel_weight) + + def forward(self, x): + return self.model(**x) + + +class GPT2LMHeadSubModel(OffsiteTuningSubModel): + + def __init__( + self, + model_name_or_path, + emulator_layer_num: int, + adapter_top_layer_num: int = 2, + adapter_bottom_layer_num: int = 2, + fp16_mix_precision=False, + partial_weight_decay=None): + + self.model_name_or_path = model_name_or_path + self.emulator_layer_num = emulator_layer_num + self.adapter_top_layer_num = adapter_top_layer_num + self.adapter_bottom_layer_num = adapter_bottom_layer_num + super().__init__( + emulator_layer_num, + adapter_top_layer_num, + adapter_bottom_layer_num, + fp16_mix_precision) + self.partial_weight_decay = partial_weight_decay + + def get_base_model(self): + total_layer_num = self.emulator_layer_num + \ + self.adapter_top_layer_num + self.adapter_bottom_layer_num + config = GPT2Config.from_pretrained(self.model_name_or_path) + config.num_hidden_layers = total_layer_num + # initialize a model without pretrained weights + return GPT2LMHeadModel(config) + + def get_model_transformer_blocks(self, model: GPT2LMHeadModel): + return model.transformer.h + + def forward(self, x): + return self.model(**x) + + def get_additional_param_state_dict(self): + # get parameter of additional parameter + model = self.model + param_dict = { + 'wte': model.transformer.wte, + 'wpe': model.transformer.wpe, + 'last_ln_f': model.transformer.ln_f + } + + addition_weights = self.get_numpy_state_dict(param_dict) + + wte = addition_weights.pop('wte') + wte_dict = split_numpy_array(wte, 10, 'wte') + wpe = addition_weights.pop('wpe') + wpe_dict = split_numpy_array(wpe, 10, 'wpe') + addition_weights.update(wte_dict) + addition_weights.update(wpe_dict) + return addition_weights + + def load_additional_param_state_dict(self, submodel_weights: dict): + # load additional weights: + model = self.model + param_dict = { + 'wte': model.transformer.wte, + 'wpe': model.transformer.wpe, + 'last_ln_f': model.transformer.ln_f + } + + new_submodel_weight = {} + new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f'] + wte_dict, wpe_dict = {}, {} + for k, v in submodel_weights.items(): + if 'wte' in k: + wte_dict[k] = v + if 'wpe' in k: + wpe_dict[k] = v + wte = recover_numpy_array(wte_dict, 'wte') + wpe = recover_numpy_array(wpe_dict, 'wpe') + new_submodel_weight['wte'] = wte + new_submodel_weight['wpe'] = wpe + + self.load_numpy_state_dict(param_dict, new_submodel_weight) + + def forward(self, x): + return self.model(**x) + + def parameters(self, recurse=True): + if self.partial_weight_decay is None: + return super().parameters(recurse) + elif isinstance(self.partial_weight_decay, float): + no_decay = ["bias", "layer_norm.weight"] + return [ + { + "params": [ + p for n, p in self.named_parameters() if not any( + nd in n for nd in no_decay)], "weight_decay": self.partial_weight_decay}, { + "params": [ + p for n, p in self.named_parameters() if any( + nd in n for nd in no_decay)], "weight_decay": 0.0}] + else: + raise ValueError( + f"partial_weight_decay should be None or float, but got {self.partial_weight_decay}") + + +# from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters +# from transformers import GPT2LMHeadModel, GPT2Config +# from torch import nn +# import torch + + +# class GPT2LMHeadMainModel(OffsiteTuningMainModel): + +# def __init__( +# self, +# model_name_or_path, +# emulator_layer_num: int, +# adapter_top_layer_num: int = 2, +# adapter_bottom_layer_num: int = 2): + +# self.model_name_or_path = model_name_or_path +# super().__init__( +# emulator_layer_num, +# adapter_top_layer_num, +# adapter_bottom_layer_num) + +# def get_base_model(self): +# return GPT2LMHeadModel.from_pretrained(self.model_name_or_path) + +# def get_model_transformer_blocks(self, model: GPT2LMHeadModel): +# return model.transformer.h + +# def forward(self, x): +# return self.model(**x) + +# def get_additional_parameter(self, model) -> dict: +# return { +# 'wte': model.transformer.wte, +# 'wpe': model.transformer.wpe, +# 'last_ln_f': model.transformer.ln_f +# } + +# def forward(self, x): +# return self.model(**x) + + +# class GPT2LMHeadSubModel(OffsiteTuningSubModel): + +# def __init__( +# self, +# model_name_or_path, +# emulator_layer_num: int, +# adapter_top_layer_num: int = 2, +# adapter_bottom_layer_num: int = 2, +# fp16_mix_precision=False, +# partial_weight_decay=None): + +# self.model_name_or_path = model_name_or_path +# self.emulator_layer_num = emulator_layer_num +# self.adapter_top_layer_num = adapter_top_layer_num +# self.adapter_bottom_layer_num = adapter_bottom_layer_num +# super().__init__( +# emulator_layer_num, +# adapter_top_layer_num, +# adapter_bottom_layer_num, +# fp16_mix_precision) +# self.partial_weight_decay = partial_weight_decay + +# def get_base_model(self): +# total_layer_num = self.emulator_layer_num + \ +# self.adapter_top_layer_num + self.adapter_bottom_layer_num +# config = GPT2Config.from_pretrained(self.model_name_or_path) +# config.num_hidden_layers = total_layer_num +# # initialize a model without pretrained weights +# return GPT2LMHeadModel(config) + +# def get_model_transformer_blocks(self, model: GPT2LMHeadModel): +# return model.transformer.h + +# def forward(self, x): +# return self.model(**x) + +# def get_additional_parameter(self, model) -> dict: +# return { +# 'wte': model.transformer.wte, +# 'wpe': model.transformer.wpe, +# 'last_ln_f': model.transformer.ln_f +# } + +# def forward(self, x): +# return self.model(**x) + +# def parameters(self, recurse=True): +# if self.partial_weight_decay is None: +# return super().parameters(recurse) +# elif isinstance(self.partial_weight_decay, float): +# no_decay = ["bias", "layer_norm.weight"] +# return [ +# { +# "params": [ +# p for n, p in self.named_parameters() if not any( +# nd in n for nd in no_decay)], "weight_decay": self.partial_weight_decay}, { +# "params": [ +# p for n, p in self.named_parameters() if any( +# nd in n for nd in no_decay)], "weight_decay": 0.0}] +# else: +# raise ValueError( +# f"partial_weight_decay should be None or float, but got {self.partial_weight_decay}") + + +if __name__ == "__main__": + + from transformers import GPT2Model + + model = GPT2LMHeadMainModel('gpt2-xl', 12, 2, 2) + model_sub = GPT2LMHeadSubModel( + 'gpt2-xl', 12, 2, 2, fp16_mix_precision=True) + model_sub.load_submodel_weights(model.get_submodel_weights()) diff --git a/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py b/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py new file mode 100644 index 0000000..1c5ef56 --- /dev/null +++ b/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py @@ -0,0 +1,346 @@ +import torch as t +from torch import nn +from federatedml.util import LOGGER +from transformers import AutoModel +import numpy as np + + + +def get_dropout_emulator_and_adapters( + transformer_layers: nn.ModuleList, + emulator_layer_num: int, + adapter_top_layer_num: int, + adapter_bottom_layer_num: int): + + assert adapter_bottom_layer_num > 0 and adapter_top_layer_num > 0, "adapter layer num must be greater than 0" + assert emulator_layer_num < len( + transformer_layers), "emulator layer num must be less than the number of transformer layers" + assert adapter_bottom_layer_num + adapter_top_layer_num < len( + transformer_layers), "adapter layer num must be less than the number of transformer layers" + assert emulator_layer_num < len( + transformer_layers) and emulator_layer_num > 0, "emulator layer num must be less than the number of transformer layers" + + bottom_idx = adapter_bottom_layer_num + top_idx = len(transformer_layers) - adapter_top_layer_num + bottom_layers = transformer_layers[:bottom_idx] + top_layers = transformer_layers[top_idx:] + kept_layers = transformer_layers[bottom_idx:top_idx] + emulator = nn.ModuleList() + stride = (len(kept_layers) - 1) / (emulator_layer_num - 1) + + layer_idx = [] + for i in range(emulator_layer_num): + idx = int(round(i * stride)) + layer_idx.append(idx) + emulator.append(kept_layers[idx]) + LOGGER.info( + 'take layer {} of the original model as the emulator'.format( + t.Tensor(layer_idx) + + bottom_idx)) + return nn.ModuleList(emulator), nn.ModuleList( + bottom_layers), nn.ModuleList(top_layers) + + + +def split_numpy_array(embedding_matrix, n, suffix): + # Calculate the indices where the splits should occur + embedding_matrix = embedding_matrix['weight'] + indices = np.linspace(0, embedding_matrix.shape[0], n+1, dtype=int) + + # Split the embedding matrix at the calculated indices + slices = [embedding_matrix[indices[i]:indices[i+1]] for i in range(n)] + + # Create a dictionary with the slices + result_dict = {suffix+str(i): slice for i, slice in enumerate(slices)} + return result_dict + + +def recover_numpy_array(slices_dict, suffix=""): + # Get the slices from the dictionary and concatenate them + slices = [slices_dict[suffix + str(i)] for i in range(len(slices_dict))] + complete_array = np.concatenate(slices, axis=0) + return {'weight': complete_array} + + +class OffsiteTuningBaseModel(t.nn.Module): + + def __init__(self, emulator_layer_num: int, adapter_top_layer_num: int = 2, + adapter_bottom_layer_num: int = 2, fp16_mix_precision=False): + super().__init__() + self.fp16_mix_precision = fp16_mix_precision + self.model = self.get_base_model() + self.initialize_model() + self.emulator, self.adapter_bottom, self.adapter_top = get_dropout_emulator_and_adapters( + transformer_layers=self.get_model_transformer_blocks(self.model), + emulator_layer_num=emulator_layer_num, + adapter_top_layer_num=adapter_top_layer_num, + adapter_bottom_layer_num=adapter_bottom_layer_num + ) + self.post_initialization() + + def initialize_model(self): + if self.fp16_mix_precision: + self.model.half() + for param in self.model.parameters(): + param.requires_grad = False + + def post_initialization(self): + pass + + def get_adapter_top(self): + return self.adapter_top + + def get_adapter_bottom(self): + return self.adapter_bottom + + def get_emulator(self): + return self.emulator + + def get_additional_param_state_dict(self): + # get parameter of additional parameter + return {} + + def load_additional_param_state_dict(self, submodel_weights: dict): + # load additional weights: + pass + + def load_numpy_state_dict(self, module_dict, state_dict): + param_dict = module_dict + + for k, v in param_dict.items(): + if k not in state_dict: + continue + addition_weights = { + k: t.tensor(v) for k, + v in state_dict[k].items()} + v.load_state_dict(addition_weights) + + def get_numpy_state_dict(self, module_dict): + + weight_dict = {} + for k, v in module_dict.items(): + weight_dict[k] = { + k: v.detach().cpu().numpy() for k, + v in v.state_dict().items()} + return weight_dict + + def get_submodel_weights(self) -> dict: + submodel_weights = { + "emulator": { + k: v.detach().cpu().numpy() for k, + v in self.get_emulator().state_dict().items()}, + "adapter_top": { + k: v.detach().cpu().numpy() for k, + v in self.get_adapter_top().state_dict().items()}, + "adapter_bottom": { + k: v.detach().cpu().numpy() for k, + v in self.get_adapter_bottom().state_dict().items()}} + addition_weights = self.get_additional_param_state_dict() + submodel_weights.update(addition_weights) + return submodel_weights + + def load_submodel_weights(self, submodel_weights: dict): + + emulator_weights = { + k: t.tensor(v) for k, + v in submodel_weights['emulator'].items()} + adapter_top_weights = { + k: t.tensor(v) for k, + v in submodel_weights['adapter_top'].items()} + adapter_bottom_weights = { + k: t.tensor(v) for k, + v in submodel_weights['adapter_bottom'].items()} + + emulator = self.get_emulator() + adapter_top = self.get_adapter_top() + adapter_bottom = self.get_adapter_bottom() + + emulator.load_state_dict(emulator_weights) + adapter_top.load_state_dict(adapter_top_weights) + adapter_bottom.load_state_dict(adapter_bottom_weights) + self.load_additional_param_state_dict(submodel_weights) + + def forward(self, **kwargs): + raise NotImplementedError() + + def get_base_model(self): + raise NotImplementedError() + + def get_model_transformer_blocks(self, model: t.nn.Module): + raise NotImplementedError() + + +class OffsiteTuningMainModel(OffsiteTuningBaseModel): + + def post_initialization(self): + pass + + +class OffsiteTuningSubModel(OffsiteTuningBaseModel): + + def post_initialization(self): + # mix precision model training + for param in self.adapter_top.parameters(): + param.data = param.data.float() + param.requires_grad = True + for param in self.adapter_bottom.parameters(): + param.data = param.data.float() + param.requires_grad = True + + +# import torch as t +# from torch import nn +# from federatedml.util import LOGGER +# from transformers import AutoModel + + +# def get_dropout_emulator_and_adapters( +# transformer_layers: nn.ModuleList, +# emulator_layer_num: int, +# adapter_top_layer_num: int, +# adapter_bottom_layer_num: int): + +# assert adapter_bottom_layer_num > 0 and adapter_top_layer_num > 0, "adapter layer num must be greater than 0" +# assert emulator_layer_num < len( +# transformer_layers), "emulator layer num must be less than the number of transformer layers" +# assert adapter_bottom_layer_num + adapter_top_layer_num < len( +# transformer_layers), "adapter layer num must be less than the number of transformer layers" +# assert emulator_layer_num < len( +# transformer_layers) and emulator_layer_num > 0, "emulator layer num must be less than the number of transformer layers" + +# bottom_idx = adapter_bottom_layer_num +# top_idx = len(transformer_layers) - adapter_top_layer_num +# bottom_layers = transformer_layers[:bottom_idx] +# top_layers = transformer_layers[top_idx:] +# kept_layers = transformer_layers[bottom_idx:top_idx] +# emulator = nn.ModuleList() +# stride = (len(kept_layers) - 1) / (emulator_layer_num - 1) + +# layer_idx = [] +# for i in range(emulator_layer_num): +# idx = int(round(i * stride)) +# layer_idx.append(idx) +# emulator.append(kept_layers[idx]) +# LOGGER.info( +# 'take layer {} of the original model as the emulator'.format( +# t.Tensor(layer_idx) + +# bottom_idx)) +# return nn.ModuleList(emulator), nn.ModuleList( +# bottom_layers), nn.ModuleList(top_layers) + + +# class OffsiteTuningBaseModel(t.nn.Module): + +# def __init__(self, emulator_layer_num: int, adapter_top_layer_num: int = 2, +# adapter_bottom_layer_num: int = 2, fp16_mix_precision=False): +# super().__init__() +# self.fp16_mix_precision = fp16_mix_precision +# self.model = self.get_base_model() +# self.initialize_model() +# self.emulator, self.adapter_bottom, self.adapter_top = get_dropout_emulator_and_adapters( +# transformer_layers=self.get_model_transformer_blocks(self.model), +# emulator_layer_num=emulator_layer_num, +# adapter_top_layer_num=adapter_top_layer_num, +# adapter_bottom_layer_num=adapter_bottom_layer_num +# ) +# self.addition_param = self.get_additional_parameter(self.model) +# self.post_initialization() + +# def initialize_model(self): +# if self.fp16_mix_precision: +# self.model.half() +# for param in self.model.parameters(): +# param.requires_grad = False + +# def post_initialization(self): +# pass + +# def get_adapter_top(self): +# return self.adapter_top + +# def get_adapter_bottom(self): +# return self.adapter_bottom + +# def get_emulator(self): +# return self.emulator + +# def get_submodel_weights(self) -> dict: +# submodel_weights = { +# "emulator": { +# k: v.detach().cpu().numpy() for k, +# v in self.get_emulator().state_dict().items()}, +# "adapter_top": { +# k: v.detach().cpu().numpy() for k, +# v in self.get_adapter_top().state_dict().items()}, +# "adapter_bottom": { +# k: v.detach().cpu().numpy() for k, +# v in self.get_adapter_bottom().state_dict().items()}} + +# # get parameter of additional parameter +# addition_weights = {} +# for k, v in self.addition_param.items(): +# addition_weights[k] = { +# k: v.detach().cpu().numpy() for k, +# v in v.state_dict().items()} +# submodel_weights.update(addition_weights) + +# return submodel_weights + +# def load_submodel_weights(self, submodel_weights: dict): + +# emulator_weights = { +# k: t.tensor(v) for k, +# v in submodel_weights['emulator'].items()} +# adapter_top_weights = { +# k: t.tensor(v) for k, +# v in submodel_weights['adapter_top'].items()} +# adapter_bottom_weights = { +# k: t.tensor(v) for k, +# v in submodel_weights['adapter_bottom'].items()} + +# emulator = self.get_emulator() +# adapter_top = self.get_adapter_top() +# adapter_bottom = self.get_adapter_bottom() + +# emulator.load_state_dict(emulator_weights) +# adapter_top.load_state_dict(adapter_top_weights) +# adapter_bottom.load_state_dict(adapter_bottom_weights) + +# # load additional weights: +# for k, v in self.addition_param.items(): +# if k not in submodel_weights: +# continue +# addition_weights = { +# k: t.tensor(v) for k, +# v in submodel_weights[k].items()} +# v.load_state_dict(addition_weights) + +# def forward(self, **kwargs): +# raise NotImplementedError() + +# def get_base_model(self): +# raise NotImplementedError() + +# def get_model_transformer_blocks(self, model: t.nn.Module): +# raise NotImplementedError() + +# def get_additional_parameter(self, model) -> dict: +# return {} + + +# class OffsiteTuningMainModel(OffsiteTuningBaseModel): + +# def post_initialization(self): +# pass + + +# class OffsiteTuningSubModel(OffsiteTuningBaseModel): + +# def post_initialization(self): +# # mix precision model training +# for param in self.adapter_top.parameters(): +# param.data = param.data.float() +# param.requires_grad = True +# for param in self.adapter_bottom.parameters(): +# param.data = param.data.float() +# param.requires_grad = True \ No newline at end of file diff --git a/python/fate_llm/model_zoo/pellm/albert.py b/python/fate_llm/model_zoo/pellm/albert.py index 24a9ddd..c033dd0 100644 --- a/python/fate_llm/model_zoo/pellm/albert.py +++ b/python/fate_llm/model_zoo/pellm/albert.py @@ -34,8 +34,12 @@ def __init__(self, config: dict = None, self.check_config(pretain_path=pretrained_path) if config is None and pretrained_path is None: config = AlbertConfig().to_dict() # use default model setting - super().__init__(config=config, pretrained_path=pretrained_path, - peft_type=peft_type, peft_config=peft_config, **kwargs) + super().__init__( + config=config, + pretrained_path=pretrained_path, + peft_type=peft_type, + peft_config=peft_config, + **kwargs) def check_config(self, pretain_path): config = AutoConfig.from_pretrained(pretain_path) diff --git a/python/fate_llm/model_zoo/pellm/bart.py b/python/fate_llm/model_zoo/pellm/bart.py index d401bfb..f6be713 100644 --- a/python/fate_llm/model_zoo/pellm/bart.py +++ b/python/fate_llm/model_zoo/pellm/bart.py @@ -32,8 +32,12 @@ def __init__(self, config: dict = None, self.check_config(pretrain_path=pretrained_path) if config is None and pretrained_path is None: config = BartConfig().to_dict() - super().__init__(config=config, pretrained_path=pretrained_path, - peft_type=peft_type, peft_config=peft_config, **kwargs) + super().__init__( + config=config, + pretrained_path=pretrained_path, + peft_type=peft_type, + peft_config=peft_config, + **kwargs) def check_config(self, pretrain_path): config = AutoConfig.from_pretrained(pretrain_path) diff --git a/python/fate_llm/model_zoo/pellm/bert.py b/python/fate_llm/model_zoo/pellm/bert.py index c95bc92..ec8cd46 100644 --- a/python/fate_llm/model_zoo/pellm/bert.py +++ b/python/fate_llm/model_zoo/pellm/bert.py @@ -32,8 +32,12 @@ def __init__(self, config: dict = None, self.check_config(pretrain_path=pretrained_path) if config is None and pretrained_path is None: config = BertConfig().to_dict() - super().__init__(config=config, pretrained_path=pretrained_path, - peft_type=peft_type, peft_config=peft_config, **kwargs) + super().__init__( + config=config, + pretrained_path=pretrained_path, + peft_type=peft_type, + peft_config=peft_config, + **kwargs) def check_config(self, pretrain_path): config = AutoConfig.from_pretrained(pretrain_path) diff --git a/python/fate_llm/model_zoo/pellm/chatglm.py b/python/fate_llm/model_zoo/pellm/chatglm.py index 8648cf4..98e0c7d 100644 --- a/python/fate_llm/model_zoo/pellm/chatglm.py +++ b/python/fate_llm/model_zoo/pellm/chatglm.py @@ -37,12 +37,16 @@ def __init__(self, peft_config=peft_config) def init_config(self): - self.config = AutoConfig.from_pretrained(self.config_path, trust_remote_code=True) + self.config = AutoConfig.from_pretrained( + self.config_path, trust_remote_code=True) self.config.pre_seq_len = self.pre_seq_len self.config.prefix_projection = self.prefix_projection def init_base_lm(self): - super(ChatGLMForConditionalGeneration, self).init_base_lm(trust_remote_code=True) + super( + ChatGLMForConditionalGeneration, + self).init_base_lm( + trust_remote_code=True) if self.fp16: self._pe_lm.half() diff --git a/python/fate_llm/model_zoo/pellm/deberta.py b/python/fate_llm/model_zoo/pellm/deberta.py index 376dcb2..56f7857 100644 --- a/python/fate_llm/model_zoo/pellm/deberta.py +++ b/python/fate_llm/model_zoo/pellm/deberta.py @@ -33,8 +33,12 @@ def __init__(self, config: dict = None, self.check_config(pretrain_path=pretrained_path) if config is None and pretrained_path is None: config = DebertaConfig().to_dict() - super().__init__(config=config, pretrained_path=pretrained_path, - peft_type=peft_type, peft_config=peft_config, **kwargs) + super().__init__( + config=config, + pretrained_path=pretrained_path, + peft_type=peft_type, + peft_config=peft_config, + **kwargs) def check_config(self, pretrain_path): config = AutoConfig.from_pretrained(pretrain_path) diff --git a/python/fate_llm/model_zoo/pellm/distilbert.py b/python/fate_llm/model_zoo/pellm/distilbert.py index c23e62f..ac44920 100644 --- a/python/fate_llm/model_zoo/pellm/distilbert.py +++ b/python/fate_llm/model_zoo/pellm/distilbert.py @@ -32,8 +32,12 @@ def __init__(self, config: dict = None, self.check_config(pretrain_path=pretrained_path) if config is None and pretrained_path is None: config = DistilBertConfig().to_dict() - super().__init__(config=config, pretrained_path=pretrained_path, - peft_type=peft_type, peft_config=peft_config, **kwargs) + super().__init__( + config=config, + pretrained_path=pretrained_path, + peft_type=peft_type, + peft_config=peft_config, + **kwargs) def check_config(self, pretrain_path): config = AutoConfig.from_pretrained(pretrain_path) diff --git a/python/fate_llm/model_zoo/pellm/gpt2.py b/python/fate_llm/model_zoo/pellm/gpt2.py index dcfa036..aceca10 100644 --- a/python/fate_llm/model_zoo/pellm/gpt2.py +++ b/python/fate_llm/model_zoo/pellm/gpt2.py @@ -32,8 +32,12 @@ def __init__(self, config: dict = None, self.check_config(pretrain_path=pretrained_path) if config is None and pretrained_path is None: config = GPT2Config().to_dict() - super().__init__(config=config, pretrained_path=pretrained_path, - peft_type=peft_type, peft_config=peft_config, **kwargs) + super().__init__( + config=config, + pretrained_path=pretrained_path, + peft_type=peft_type, + peft_config=peft_config, + **kwargs) def check_config(self, pretrain_path): config = AutoConfig.from_pretrained(pretrain_path) diff --git a/python/fate_llm/model_zoo/pellm/parameter_efficient_llm.py b/python/fate_llm/model_zoo/pellm/parameter_efficient_llm.py index a120ce2..540ec0f 100644 --- a/python/fate_llm/model_zoo/pellm/parameter_efficient_llm.py +++ b/python/fate_llm/model_zoo/pellm/parameter_efficient_llm.py @@ -78,9 +78,11 @@ def init_config(self, **kwargs): def init_base_lm(self, **kwargs): model_loader = self.model_loader if self.model_loader is not None else AutoModel if self.config is not None: - self._pe_lm = model_loader.from_pretrained(self.config_path, config=self.config, **kwargs) + self._pe_lm = model_loader.from_pretrained( + self.config_path, config=self.config, **kwargs) elif self.config_path is not None: - self._pe_lm = model_loader.from_pretrained(self.config_path, **kwargs) + self._pe_lm = model_loader.from_pretrained( + self.config_path, **kwargs) else: raise ValueError( 'config_path to pretrained model folder cannot be None') @@ -116,13 +118,14 @@ def forward(self, tokenized_data: dict): def save_pretrained(self, path): if not self.enable_save_pretrained: - raise ValueError("To save trainable parameters only, set enable_save_pretrained=True in your model") + raise ValueError( + "To save trainable parameters only, set enable_save_pretrained=True in your model") from pathlib import Path state_dict = { - k: p.to("cpu") for k, p in self._pe_lm.named_parameters() if p.requires_grad - } + k: p.to("cpu") for k, + p in self._pe_lm.named_parameters() if p.requires_grad} Path.mkdir(Path(path), exist_ok=True) torch.save(state_dict, Path(path).joinpath("adapter_model.bin")) @@ -131,5 +134,3 @@ class AutoPELLM(PELLM): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - - diff --git a/python/fate_llm/model_zoo/pellm/roberta.py b/python/fate_llm/model_zoo/pellm/roberta.py index 33d1079..abcee82 100644 --- a/python/fate_llm/model_zoo/pellm/roberta.py +++ b/python/fate_llm/model_zoo/pellm/roberta.py @@ -32,8 +32,12 @@ def __init__(self, config: dict = None, self.check_config(pretrain_path=pretrained_path) if config is None and pretrained_path is None: config = RobertaConfig().to_dict() - super().__init__(config=config, pretrained_path=pretrained_path, - peft_type=peft_type, peft_config=peft_config, **kwargs) + super().__init__( + config=config, + pretrained_path=pretrained_path, + peft_type=peft_type, + peft_config=peft_config, + **kwargs) def check_config(self, pretrain_path): config = AutoConfig.from_pretrained(pretrain_path) diff --git a/python/fate_llm/trainer/fedipr_trainer.py b/python/fate_llm/trainer/fedipr_trainer.py new file mode 100644 index 0000000..76f3141 --- /dev/null +++ b/python/fate_llm/trainer/fedipr_trainer.py @@ -0,0 +1,572 @@ +import torch as t +import tqdm +import numpy as np +import torch +from typing import Literal +from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer +from federatedml.nn.backend.utils import distributed_util +from torch.utils.data import DataLoader, DistributedSampler +import torch.distributed as dist +from federatedml.nn.dataset.watermark import WaterMarkImageDataset, WaterMarkDataset +from federatedml.util import LOGGER +from fate_llm.model_zoo.sign_block import generate_signature, is_sign_block +from fate_llm.model_zoo.sign_block import SignatureBlock +from sklearn.metrics import accuracy_score +from federatedml.nn.dataset.base import Dataset +from federatedml.util import consts + + +def get_sign_blocks(model: torch.nn.Module): + + record_sign_block = {} + for name, m in model.named_modules(): + if is_sign_block(m): + record_sign_block[name] = m + + return record_sign_block + + +def get_keys(sign_block_dict: dict, num_bits: int): + + key_pairs = {} + param_len = [] + sum_allocated_bits = 0 + # Iterate through each layer and compute the flattened parameter lengths + for k, v in sign_block_dict.items(): + param_len.append(len(v.embeded_param.flatten())) + total_param_len = sum(param_len) + + alloc_bits = {} + + for i, (k, v) in enumerate(sign_block_dict.items()): + allocated_bits = int((param_len[i] / total_param_len) * num_bits) + alloc_bits[k] = allocated_bits + sum_allocated_bits += allocated_bits + + rest_bits = num_bits - sum_allocated_bits + if rest_bits > 0: + alloc_bits[k] += rest_bits + + for k, v in sign_block_dict.items(): + key_pairs[k] = generate_signature(v, alloc_bits[k]) + + return key_pairs + + +""" +Verify Tools +""" + + +def to_cuda(var, device=0): + if hasattr(var, 'cuda'): + return var.cuda(device) + elif isinstance(var, tuple) or isinstance(var, list): + ret = tuple(to_cuda(i) for i in var) + return ret + elif isinstance(var, dict): + for k in var: + if hasattr(var[k], 'cuda'): + var[k] = var[k].cuda(device) + return var + else: + return var + + +def _verify_sign_blocks(sign_blocks, keys, cuda=False, device=None): + + signature_correct_count = 0 + total_bit = 0 + for name, block in sign_blocks.items(): + block: SignatureBlock = block + W, signature = keys[name] + if cuda: + W = to_cuda(W, device=device) + signature = to_cuda(signature, device=device) + extract_bits = block.extract_sign(W) + total_bit += len(extract_bits) + signature_correct_count += (extract_bits == + signature).sum().detach().cpu().item() + + sign_acc = signature_correct_count / total_bit + return sign_acc + + +def _suggest_sign_bit(param_num, client_num): + max_signbit = param_num // client_num + max_signbit -= 1 # not to exceed + if max_signbit <= 0: + raise ValueError( + 'not able to add feature based watermark, param_num is {}, client num is {}, computed max bit is {} <=0'.format( + param_num, client_num, max_signbit)) + return max_signbit + + +def compute_sign_bit(model, client_num): + total_param_num = 0 + blocks = get_sign_blocks(model) + for k, v in blocks.items(): + total_param_num += v.embeded_param_num() + if total_param_num == 0: + return 0 + return _suggest_sign_bit(total_param_num, client_num) + + +def verify_feature_based_signature(model, keys): + + model = model.cpu() + sign_blocks = get_sign_blocks(model) + return _verify_sign_blocks(sign_blocks, keys, cuda=False) + + +class FedIPRTrainer(FedAVGTrainer): + + def __init__(self, + epochs=10, + noraml_dataset_batch_size=32, + watermark_dataset_batch_size=2, + early_stop=None, + tol=0.0001, + secure_aggregate=True, + weighted_aggregation=True, + aggregate_every_n_epoch=None, + cuda=None, + pin_memory=True, + shuffle=True, + data_loader_worker=0, + validation_freqs=None, + checkpoint_save_freqs=None, + task_type='auto', + save_to_local_dir=False, + collate_fn=None, + collate_fn_params=None, + alpha=0.01, + verify_freqs=1, + backdoor_verify_method: Literal['accuracy', + 'loss'] = 'accuracy'): + + super().__init__( + epochs, + noraml_dataset_batch_size, + early_stop, + tol, + secure_aggregate, + weighted_aggregation, + aggregate_every_n_epoch, + cuda, + pin_memory, + shuffle, + data_loader_worker, + validation_freqs, + checkpoint_save_freqs, + task_type, + save_to_local_dir, + collate_fn, + collate_fn_params) + + self.normal_train_set = None + self.watermark_set = None + self.data_loader = None + self.normal_dataset_batch_size = noraml_dataset_batch_size + self.watermark_dataset_batch_size = watermark_dataset_batch_size + self.alpha = alpha + self.verify_freqs = verify_freqs + self.backdoor_verify_method = backdoor_verify_method + self._sign_keys = None + self._sign_blocks = None + self._client_num = None + self._sign_bits = None + + assert self.alpha > 0, 'alpha must be greater than 0' + assert self.verify_freqs > 0 and isinstance( + self.verify_freqs, int), 'verify_freqs must be greater than 0' + assert self.backdoor_verify_method in [ + 'accuracy', 'loss'], 'backdoor_verify_method must be accuracy or loss' + + def local_mode(self): + self.fed_mode = False + self._client_num = 1 + + def _handle_dataset(self, train_set, collate_fn): + + if not distributed_util.is_distributed() or distributed_util.get_num_workers() <= 1: + return DataLoader( + train_set, + batch_size=self.batch_size, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + num_workers=self.data_loader_worker, + collate_fn=collate_fn + ) + else: + train_sampler = DistributedSampler( + train_set, + num_replicas=dist.get_world_size(), + rank=dist.get_rank() + ) + return DataLoader( + train_set, + batch_size=self.batch_size, + pin_memory=self.pin_memory, + num_workers=self.data_loader_worker, + collate_fn=collate_fn, + sampler=train_sampler + ) + + def _get_train_data_loader(self, train_set): + + collate_fn = self._get_collate_fn(train_set) + + if isinstance(train_set, WaterMarkDataset): + LOGGER.info( + 'detect watermark dataset, split watermark dataset and normal dataset') + normal_train_set = train_set.get_normal_dataset() + watermark_set = train_set.get_watermark_dataset() + if normal_train_set is None: + raise ValueError( + 'normal dataset must not be None in FedIPR algo') + train_dataloder = self._handle_dataset( + normal_train_set, collate_fn) + + if watermark_set is not None: + watermark_dataloader = self._handle_dataset( + watermark_set, collate_fn) + else: + watermark_dataloader = None + self.normal_train_set = normal_train_set + self.watermark_set = watermark_set + dataloaders = { + 'train': train_dataloder, + 'watermark': watermark_dataloader} + return dataloaders + else: + LOGGER.info('detect non-watermark dataset') + train_dataloder = self._handle_dataset(train_set, collate_fn) + dataloaders = {'train': train_dataloder, 'watermark': None} + return dataloaders + + def _get_device(self): + if self.cuda is not None or self._enable_deepspeed: + device = self.cuda_main_device if self.cuda_main_device is not None else self.model.device + return device + else: + return None + + def verify(self, sign_blocks: dict, keys: dict): + + return _verify_sign_blocks( + sign_blocks, + keys, + self.cuda is not None, + self._get_device()) + + def get_loss_from_pred(self, loss, pred, batch_label): + + if not loss and hasattr(pred, "loss"): + batch_loss = pred.loss + + elif loss is not None: + if batch_label is None: + raise ValueError( + "When loss is set, please provide label to calculate loss" + ) + if not isinstance(pred, torch.Tensor) and hasattr(pred, "logits"): + pred = pred.logits + batch_loss = loss(pred, batch_label) + else: + raise ValueError( + 'Trainer requires a loss function, but got None, please specify loss function in the' + ' job configuration') + + return batch_loss + + def _get_keys(self, sign_blocks): + + if self._sign_keys is None: + self._sign_keys = get_keys(sign_blocks, self._sign_bits) + return self._sign_keys + + def _get_sign_blocks(self): + if self._sign_blocks is None: + sign_blocks = get_sign_blocks(self.model) + self._sign_blocks = sign_blocks + + return self._sign_blocks + + def train( + self, + train_set: Dataset, + validate_set: Dataset = None, + optimizer=None, + loss=None, + extra_dict={}): + + if 'keys' in extra_dict: + self._sign_keys = extra_dict['keys'] + self._sign_bits = extra_dict['num_bits'] + else: + LOGGER.info('computing feature based sign bits') + if self._client_num is None and self.party_id_list is not None: + self._client_num = len(self.party_id_list) + self._sign_bits = compute_sign_bit(self.model, self._client_num) + + LOGGER.info( + 'client num {}, party id list {}'.format( + self._client_num, + self.party_id_list)) + LOGGER.info( + 'will assign {} bits for feature based watermark'.format( + self._sign_bits)) + return super().train(train_set, validate_set, optimizer, loss, extra_dict) + + def train_an_epoch( + self, + epoch_idx, + model, + train_set, + optimizer, + loss_func): + + epoch_loss = 0.0 + batch_idx = 0 + acc_num = 0 + + sign_blocks = self._get_sign_blocks() + keys = self._get_keys(sign_blocks) + + dl, watermark_dl = self.data_loader['train'], self.data_loader['watermark'] + if isinstance(dl, DistributedSampler): + dl.sampler.set_epoch(epoch_idx) + if isinstance(watermark_dl, DistributedSampler): + watermark_dl.sampler.set_epoch(epoch_idx) + + if not self.fed_mode: + trainset_iterator = tqdm.tqdm(dl) + else: + trainset_iterator = dl + batch_label = None + + # collect watermark data and mix them into the training data + watermark_collect = [] + if watermark_dl is not None: + for watermark_batch in watermark_dl: + watermark_collect.append(watermark_batch) + + for _batch_iter in trainset_iterator: + + _batch_iter = self._decode(_batch_iter) + + if isinstance(_batch_iter, list) or isinstance(_batch_iter, tuple): + batch_data, batch_label = _batch_iter + else: + batch_data = _batch_iter + + if watermark_dl is not None: + # Mix the backdoor sample into the training data + wm_batch_idx = int(batch_idx % len(watermark_collect)) + wm_batch = watermark_collect[wm_batch_idx] + if isinstance(wm_batch, list): + wm_batch_data, wm_batch_label = wm_batch + batch_data = torch.cat([batch_data, wm_batch_data], dim=0) + batch_label = torch.cat( + [batch_label, wm_batch_label], dim=0) + else: + wm_batch_data = wm_batch + batch_data = torch.cat([batch_data, wm_batch_data], dim=0) + + if self.cuda is not None or self._enable_deepspeed: + device = self.cuda_main_device if self.cuda_main_device is not None else self.model.device + batch_data = self.to_cuda(batch_data, device) + if batch_label is not None: + batch_label = self.to_cuda(batch_label, device) + + if not self._enable_deepspeed: + optimizer.zero_grad() + else: + model.zero_grad() + + pred = model(batch_data) + + sign_loss = 0 + # Get the sign loss of model + for name, block in sign_blocks.items(): + + block: SignatureBlock = block + W, signature = keys[name] + if self.cuda is not None: + device = self._get_device() + W = self.to_cuda(W, device) + signature = self.to_cuda(signature, device) + sign_loss += self.alpha * block.sign_loss(W, signature) + + batch_loss = self.get_loss_from_pred(loss_func, pred, batch_label) + batch_loss += sign_loss + + if not self._enable_deepspeed: + + batch_loss.backward() + optimizer.step() + batch_loss_np = np.array( + batch_loss.detach().tolist()) if self.cuda is None else np.array( + batch_loss.cpu().detach().tolist()) + + if acc_num + self.batch_size > len(train_set): + batch_len = len(train_set) - acc_num + else: + batch_len = self.batch_size + + epoch_loss += batch_loss_np * batch_len + else: + batch_loss = model.backward(batch_loss) + batch_loss_np = np.array(batch_loss.cpu().detach().tolist()) + model.step() + batch_loss_np = self._sync_loss( + batch_loss_np * self._get_batch_size(batch_data)) + if distributed_util.is_rank_0(): + epoch_loss += batch_loss_np + + batch_idx += 1 + + if self.fed_mode: + LOGGER.debug( + 'epoch {} batch {} finished'.format(epoch_idx, batch_idx)) + + epoch_loss = epoch_loss / len(train_set) + + # verify the sign of model during training + if epoch_idx % self.verify_freqs == 0: + # verify feature-based signature + sign_acc = self.verify(sign_blocks, keys) + LOGGER.info(f"epoch {epoch_idx} sign accuracy: {sign_acc}") + # verify backdoor signature + if self.watermark_set is not None: + _, pred, label = self._predict(self.watermark_set) + pred = pred.detach().cpu() + label = label.detach().cpu() + if self.backdoor_verify_method == 'accuracy': + if not isinstance( + pred, torch.Tensor) and hasattr( + pred, "logits"): + pred = pred.logits + pred = pred.numpy().reshape((len(label), -1)) + label = label.numpy() + pred_label = np.argmax(pred, axis=1) + metric = accuracy_score( + pred_label.flatten(), label.flatten()) + else: + metric = self.get_loss_from_pred(loss_func, pred, label) + + LOGGER.info( + f"epoch {epoch_idx} backdoor {self.backdoor_verify_method}: {metric}") + + return epoch_loss + + def _predict(self, dataset: Dataset): + pred_result = [] + + # switch eval mode + dataset.eval() + model = self._select_model() + model.eval() + + if not dataset.has_sample_ids(): + dataset.init_sid_and_getfunc(prefix=dataset.get_type()) + + labels = [] + with torch.no_grad(): + for _batch_iter in DataLoader( + dataset, self.batch_size + ): + if isinstance(_batch_iter, list): + batch_data, batch_label = _batch_iter + else: + batch_label = _batch_iter.pop("labels") + batch_data = _batch_iter + + if self.cuda is not None or self._enable_deepspeed: + device = self.cuda_main_device if self.cuda_main_device is not None else self.model.device + batch_data = self.to_cuda(batch_data, device) + + pred = model(batch_data) + + if not isinstance( + pred, torch.Tensor) and hasattr( + pred, "logits"): + pred = pred.logits + + pred_result.append(pred) + labels.append(batch_label) + + ret_rs = torch.concat(pred_result, axis=0) + ret_label = torch.concat(labels, axis=0) + + # switch back to train mode + dataset.train() + model.train() + + return dataset.get_sample_ids(), ret_rs, ret_label + + def predict(self, dataset: Dataset): + + if self.task_type in [consts.CAUSAL_LM, consts.SEQ_2_SEQ_LM]: + LOGGER.warning( + f"Not support prediction of task_types={[consts.CAUSAL_LM, consts.SEQ_2_SEQ_LM]}") + return + + if distributed_util.is_distributed() and not distributed_util.is_rank_0(): + return + + if isinstance(dataset, WaterMarkDataset): + normal_train_set = dataset.get_normal_dataset() + if normal_train_set is None: + raise ValueError( + 'normal train set is None in FedIPR algo predict function') + else: + normal_train_set = normal_train_set + + ids, ret_rs, ret_label = self._predict(normal_train_set) + + if self.fed_mode: + return self.format_predict_result( + ids, ret_rs, ret_label, task_type=self.task_type) + else: + return ret_rs, ret_label + + def save( + self, + model=None, + epoch_idx=-1, + optimizer=None, + converge_status=False, + loss_history=None, + best_epoch=-1, + extra_data={}): + + extra_data = {'keys': self._sign_keys, 'num_bits': self._sign_bits} + super().save( + model, + epoch_idx, + optimizer, + converge_status, + loss_history, + best_epoch, + extra_data) + + def local_save(self, + model=None, + epoch_idx=-1, + optimizer=None, + converge_status=False, + loss_history=None, + best_epoch=-1, + extra_data={}): + + extra_data = {'keys': self._sign_keys, 'num_bits': self._sign_bits} + super().local_save( + model, + epoch_idx, + optimizer, + converge_status, + loss_history, + best_epoch, + extra_data) diff --git a/python/fate_llm/trainer/offsite_tuning_trainer.py b/python/fate_llm/trainer/offsite_tuning_trainer.py new file mode 100644 index 0000000..ba9c1bd --- /dev/null +++ b/python/fate_llm/trainer/offsite_tuning_trainer.py @@ -0,0 +1,300 @@ +import torch as t +from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer +from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient as SecureAggClient +from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorServer as SecureAggServer +from federatedml.util import LOGGER +from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel +from federatedml.util import consts +from federatedml.nn.backend.utils import deepspeed_util +from federatedml.nn.backend.utils import distributed_util +import torch.distributed as dist + + + +def count_parameters(model: t.nn.Module): + return sum(p.numel() for p in model.parameters()) + + +def count_trainable_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +class OffsiteTuningTrainer(FedAVGTrainer): + + def __init__(self, epochs=10, batch_size=512, # training parameter + early_stop=None, tol=0.0001, # early stop parameters + secure_aggregate=True, weighted_aggregation=True, aggregate_every_n_epoch=None, # federation + cuda=None, + pin_memory=True, shuffle=True, data_loader_worker=0, # GPU & dataloader + validation_freqs=None, # validation configuration + checkpoint_save_freqs=None, # checkpoint configuration + task_type='auto', # task type + save_to_local_dir=False, # save model to local path + collate_fn=None, + collate_fn_params=None, + need_aggregate=False + ): + + super().__init__( + epochs=epochs, + batch_size=batch_size, + early_stop=early_stop, + tol=tol, + secure_aggregate=secure_aggregate, + weighted_aggregation=weighted_aggregation, + aggregate_every_n_epoch=aggregate_every_n_epoch, + cuda=cuda, + pin_memory=pin_memory, + shuffle=shuffle, + data_loader_worker=data_loader_worker, + validation_freqs=validation_freqs, + checkpoint_save_freqs=checkpoint_save_freqs, + task_type=task_type, + save_to_local_dir=save_to_local_dir, + collate_fn=collate_fn, + collate_fn_params=collate_fn_params) + + self.need_aggregate = need_aggregate + self.model_transvar = None + + + def _send_submodel_weights(self, state_dict, send_func, suffix='start'): + from fate_arch.session import computing_session as session + emulator = state_dict['emulator'] + adapter_top = state_dict['adapter_top'] + adapter_bottom = state_dict['adapter_bottom'] + tb1 = session.parallelize([(key, value) for key, value in emulator.items()], include_key=True, partition=1) + tb2 = session.parallelize([(key, value) for key, value in adapter_top.items()], include_key=True, partition=1) + tb3 = session.parallelize([(key, value) for key, value in adapter_bottom.items()], include_key=True, partition=1) + state_dict.pop('emulator', None) + state_dict.pop('adapter_top', None) + state_dict.pop('adapter_bottom', None) + tb4 = session.parallelize([(key, value) for key, value in state_dict.items()], include_key=True, partition=4) + send_func( + tb1, + suffix='emulator_'+suffix) + send_func( + tb2, + suffix='adapter_top_'+suffix) + send_func( + tb3, + suffix='adapter_bottom_'+suffix) + send_func( + tb4, + suffix='other_param_'+suffix) + + + def _get_submodel_weights(self, get_func, suffix='start'): + + client_agg: SecureAggregatorClient = self.client_agg + tb1 = get_func(suffix='emulator_'+suffix)[0] + tb2 = get_func(suffix='adapter_top_'+suffix)[0] + tb3 = get_func(suffix='adapter_bottom_'+suffix)[0] + tb4 = get_func(suffix='other_param_'+suffix)[0] + + got_state_dict = {} + got_state_dict['emulator'] = dict(tb1.collect()) + got_state_dict['adapter_top'] = dict(tb2.collect()) + got_state_dict['adapter_bottom'] = dict(tb3.collect()) + other_param = dict(tb4.collect()) + got_state_dict.update(other_param) + + return got_state_dict + + + def on_loop_begin_client(self): + + unwarp_model = self.unwrap_model(self.model) + if not isinstance(unwarp_model, OffsiteTuningSubModel): + raise ValueError( + 'Client must provide a model subclassing "OffsiteTuningSubModel" in the offsite-tuning trainer, but got {}'.format( + type( + unwarp_model))) + + model: OffsiteTuningSubModel = unwarp_model + + if self.fed_mode: + + if (distributed_util.is_distributed() and distributed_util.is_rank_0()) or (not distributed_util.is_distributed()): + # receive parameters from model provider and load emulator, adapter + client_agg: SecureAggregatorClient = self.client_agg + ret = self._get_submodel_weights(self.model_transvar.server_to_client.get, suffix='start') + LOGGER.info('loaded weights keys are {}'.format(ret.keys())) + # client_agg: SecureAggregatorClient = self.client_agg + # param = client_agg.get('sub_model_parameter') + model.load_submodel_weights(ret) + + if distributed_util.is_distributed(): + self._share_model(sync_trainable_only=False) + # reinitalize deepspeed + deepspeed_util.init_deepspeed_env(self._ds_config) + model = self.unwrap_model(self.model) + self._model, self._optimizer = deepspeed_util.deepspeed_init(model, self._ds_config) + if deepspeed_util.is_zero3(self._ds_config): + self._model.train() + + LOGGER.info( + 'adapter parameters num: {}'.format( + count_parameters( + model.get_adapter_top()) + + count_parameters( + model.get_adapter_bottom()))) + LOGGER.info( + 'trainable parameters num {}'.format( + count_trainable_parameters(model))) + + def on_loop_begin_server(self): + + if self.model is None: + raise ValueError( + 'Server must provide a main model in the offsite-tuning trainer, got None model, \ + please set server_init to True and provide the model config') + + unwrap_model = self.unwrap_model(self.model) + if not isinstance(unwrap_model, OffsiteTuningMainModel): + raise ValueError( + 'Server must provide a model subclassing "OffsiteTuningMainModel" in the offsite-tuning trainer, but got {}'.format( + type( + unwrap_model))) + + model: OffiteTuningMainModel = unwrap_model + sub_model_state_dict = model.get_submodel_weights() + self._send_submodel_weights(sub_model_state_dict, self.model_transvar.server_to_client.remote, suffix='start') + # server_agg: SecureAggregatorServer = self.server_agg + # server_agg.broadcast( + # sub_model_state_dict, + # suffix='sub_model_parameter') + + LOGGER.info( + 'adapter parameters num: {}'.format( + count_parameters( + model.get_adapter_top()) + + count_parameters( + model.get_adapter_bottom()))) + LOGGER.info( + 'emulator parameters num: {}'.format( + count_parameters( + model.get_emulator()))) + + def on_loop_end_client(self): + + if self.fed_mode: + if (distributed_util.is_distributed() and distributed_util.is_rank_0()) or (not distributed_util.is_distributed()): + model: OffsiteTuningSubModel = self.unwrap_model(self.model) + sub_model_state_dict = model.get_submodel_weights() + # client_agg = self.client_agg + # client_agg.send( + # sub_model_state_dict, + # suffix='final_sub_model_parameter') + self._send_submodel_weights(sub_model_state_dict, self.model_transvar.client_to_server.remote, suffix='end') + + def on_loop_end_server(self): + + model: OffsiteTuningMainModel = self.model + ret_state_dict = self._get_submodel_weights(self.model_transvar.client_to_server.get, suffix='end') + model.load_submodel_weights(ret_state_dict) + # server_agg = self.server_agg + # sub_model_state_dict = server_agg.collect( + # suffix='final_sub_model_parameter')[0] + # model.load_submodel_weights(sub_model_state_dict) + + + def _client_sends_data(self, epoch_idx, epoch_loss, cur_agg_round): + if self.need_aggregate: + return super()._client_sends_data(epoch_idx, epoch_loss, cur_agg_round) + else: + return False + + def _server_aggregates_data( + self, + epoch_idx, + check_converge, + converge_func): + if self.need_aggregate: + return super()._server_aggregates_data(epoch_idx, check_converge, converge_func) + else: + return False + + def _init_aggregator(self, train_set): + # compute round to aggregate + cur_agg_round = 0 + if self.aggregate_every_n_epoch is not None: + aggregate_round = self.epochs // self.aggregate_every_n_epoch + else: + aggregate_round = self.epochs + + # initialize fed avg client + if self.fed_mode: + if self.weighted_aggregation: + sample_num = len(train_set) + else: + sample_num = 1.0 + + if not distributed_util.is_distributed() or distributed_util.is_rank_0(): + if len(self.party_id_list) == 1: # guest only: + clients = (consts.GUEST, ) + else: + clients = (consts.GUEST, consts.HOST) + client_agg = SecureAggClient( + self.secure_aggregate, + aggregate_weight=sample_num, + communicate_match_suffix=self.comm_suffix, + clients=clients) + # init model transvar + from federatedml.framework.homo.blocks import CommunicatorTransVar + self.model_transvar = CommunicatorTransVar(clients=clients, prefix='model', disable_gc=True) + else: + client_agg = None + else: + client_agg = None + + return client_agg, aggregate_round + + def server_aggregate_procedure(self, extra_data={}): + + # converge status + check_converge = False + converge_func = None + if self.early_stop: + check_converge = True + converge_func = converge_func_factory( + self.early_stop, self.tol).is_converge + LOGGER.info( + 'check early stop, converge func is {}'.format(converge_func)) + + LOGGER.info('server running aggregate procedure') + if len(self.party_id_list) == 1: # guest only: + clients = (consts.GUEST, ) + else: + clients = (consts.GUEST, consts.HOST) + + self.server_agg = SecureAggServer( + self.secure_aggregate, + communicate_match_suffix=self.comm_suffix, + clients=clients) + from federatedml.framework.homo.blocks import CommunicatorTransVar + self.model_transvar = CommunicatorTransVar(clients=clients, prefix='model', disable_gc=True) + + self.on_loop_begin_server() + # aggregate and broadcast models + for i in range(self.epochs): + + need_stop = self._server_aggregates_data( + i, check_converge, converge_func) + if need_stop: + break + + self.on_loop_end_server() + LOGGER.info('server aggregation process done') + if self._model is not None: + if self.save_to_local_dir: + self.local_save( + model=self.model, + epoch_idx=i, + converge_status=need_stop) + else: + self.save( + model=self.model, + epoch_idx=i, + converge_status=need_stop) + LOGGER.info('sever side model saved') From 680ea02bdecc75d29baa563a947d0fcca5ef298d Mon Sep 17 00:00:00 2001 From: cwj Date: Wed, 30 Aug 2023 15:17:40 +0800 Subject: [PATCH 02/23] Support LM Table aggregate Signed-off-by: cwj --- .../fate_llm/trainer/offsite_tuning_trainer.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/python/fate_llm/trainer/offsite_tuning_trainer.py b/python/fate_llm/trainer/offsite_tuning_trainer.py index ba9c1bd..1dd8336 100644 --- a/python/fate_llm/trainer/offsite_tuning_trainer.py +++ b/python/fate_llm/trainer/offsite_tuning_trainer.py @@ -23,7 +23,7 @@ class OffsiteTuningTrainer(FedAVGTrainer): def __init__(self, epochs=10, batch_size=512, # training parameter early_stop=None, tol=0.0001, # early stop parameters - secure_aggregate=True, weighted_aggregation=True, aggregate_every_n_epoch=None, # federation + secure_aggregate=False, weighted_aggregation=True, aggregate_every_n_epoch=None, # federation, offsite tuning need to aggregate large model, default is False cuda=None, pin_memory=True, shuffle=True, data_loader_worker=0, # GPU & dataloader validation_freqs=None, # validation configuration @@ -63,9 +63,9 @@ def _send_submodel_weights(self, state_dict, send_func, suffix='start'): emulator = state_dict['emulator'] adapter_top = state_dict['adapter_top'] adapter_bottom = state_dict['adapter_bottom'] - tb1 = session.parallelize([(key, value) for key, value in emulator.items()], include_key=True, partition=1) - tb2 = session.parallelize([(key, value) for key, value in adapter_top.items()], include_key=True, partition=1) - tb3 = session.parallelize([(key, value) for key, value in adapter_bottom.items()], include_key=True, partition=1) + tb1 = session.parallelize([(key, value) for key, value in emulator.items()], include_key=True, partition=4) + tb2 = session.parallelize([(key, value) for key, value in adapter_top.items()], include_key=True, partition=4) + tb3 = session.parallelize([(key, value) for key, value in adapter_bottom.items()], include_key=True, partition=4) state_dict.pop('emulator', None) state_dict.pop('adapter_top', None) state_dict.pop('adapter_bottom', None) @@ -86,7 +86,6 @@ def _send_submodel_weights(self, state_dict, send_func, suffix='start'): def _get_submodel_weights(self, get_func, suffix='start'): - client_agg: SecureAggregatorClient = self.client_agg tb1 = get_func(suffix='emulator_'+suffix)[0] tb2 = get_func(suffix='adapter_top_'+suffix)[0] tb3 = get_func(suffix='adapter_bottom_'+suffix)[0] @@ -117,7 +116,6 @@ def on_loop_begin_client(self): if (distributed_util.is_distributed() and distributed_util.is_rank_0()) or (not distributed_util.is_distributed()): # receive parameters from model provider and load emulator, adapter - client_agg: SecureAggregatorClient = self.client_agg ret = self._get_submodel_weights(self.model_transvar.server_to_client.get, suffix='start') LOGGER.info('loaded weights keys are {}'.format(ret.keys())) # client_agg: SecureAggregatorClient = self.client_agg @@ -239,7 +237,9 @@ def _init_aggregator(self, train_set): self.secure_aggregate, aggregate_weight=sample_num, communicate_match_suffix=self.comm_suffix, - clients=clients) + clients=clients, + lm_aggregate=True + ) # init model transvar from federatedml.framework.homo.blocks import CommunicatorTransVar self.model_transvar = CommunicatorTransVar(clients=clients, prefix='model', disable_gc=True) @@ -271,7 +271,9 @@ def server_aggregate_procedure(self, extra_data={}): self.server_agg = SecureAggServer( self.secure_aggregate, communicate_match_suffix=self.comm_suffix, - clients=clients) + clients=clients, + lm_aggregate=True + ) from federatedml.framework.homo.blocks import CommunicatorTransVar self.model_transvar = CommunicatorTransVar(clients=clients, prefix='model', disable_gc=True) From b76873d9d4a200a451d2e75998918fd40f1d993a Mon Sep 17 00:00:00 2001 From: cwj Date: Wed, 30 Aug 2023 16:16:39 +0800 Subject: [PATCH 03/23] Fix typo Signed-off-by: cwj --- python/fate_llm/trainer/offsite_tuning_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/fate_llm/trainer/offsite_tuning_trainer.py b/python/fate_llm/trainer/offsite_tuning_trainer.py index 1dd8336..718cf20 100644 --- a/python/fate_llm/trainer/offsite_tuning_trainer.py +++ b/python/fate_llm/trainer/offsite_tuning_trainer.py @@ -8,6 +8,7 @@ from federatedml.nn.backend.utils import deepspeed_util from federatedml.nn.backend.utils import distributed_util import torch.distributed as dist +from federatedml.optim.convergence import converge_func_factory @@ -155,7 +156,7 @@ def on_loop_begin_server(self): type( unwrap_model))) - model: OffiteTuningMainModel = unwrap_model + model: OffsiteTuningMainModel = unwrap_model sub_model_state_dict = model.get_submodel_weights() self._send_submodel_weights(sub_model_state_dict, self.model_transvar.server_to_client.remote, suffix='start') # server_agg: SecureAggregatorServer = self.server_agg From 4c2c11af6519c8cfeb56669871f30f82e3a8ed4a Mon Sep 17 00:00:00 2001 From: cwj Date: Thu, 31 Aug 2023 17:56:21 +0800 Subject: [PATCH 04/23] Add llama support remove debug codes Signed-off-by: cwj --- .../model_zoo/offsite_tuning/gpt2_ot.py | 115 ------------- .../model_zoo/offsite_tuning/llama_ot.py | 153 +++++++++++++++++ .../offsite_tuning/offsite_tuning_model.py | 158 ------------------ 3 files changed, 153 insertions(+), 273 deletions(-) create mode 100644 python/fate_llm/model_zoo/offsite_tuning/llama_ot.py diff --git a/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py b/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py index 2052cf1..ebc7ded 100644 --- a/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py +++ b/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py @@ -174,118 +174,3 @@ def parameters(self, recurse=True): else: raise ValueError( f"partial_weight_decay should be None or float, but got {self.partial_weight_decay}") - - -# from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters -# from transformers import GPT2LMHeadModel, GPT2Config -# from torch import nn -# import torch - - -# class GPT2LMHeadMainModel(OffsiteTuningMainModel): - -# def __init__( -# self, -# model_name_or_path, -# emulator_layer_num: int, -# adapter_top_layer_num: int = 2, -# adapter_bottom_layer_num: int = 2): - -# self.model_name_or_path = model_name_or_path -# super().__init__( -# emulator_layer_num, -# adapter_top_layer_num, -# adapter_bottom_layer_num) - -# def get_base_model(self): -# return GPT2LMHeadModel.from_pretrained(self.model_name_or_path) - -# def get_model_transformer_blocks(self, model: GPT2LMHeadModel): -# return model.transformer.h - -# def forward(self, x): -# return self.model(**x) - -# def get_additional_parameter(self, model) -> dict: -# return { -# 'wte': model.transformer.wte, -# 'wpe': model.transformer.wpe, -# 'last_ln_f': model.transformer.ln_f -# } - -# def forward(self, x): -# return self.model(**x) - - -# class GPT2LMHeadSubModel(OffsiteTuningSubModel): - -# def __init__( -# self, -# model_name_or_path, -# emulator_layer_num: int, -# adapter_top_layer_num: int = 2, -# adapter_bottom_layer_num: int = 2, -# fp16_mix_precision=False, -# partial_weight_decay=None): - -# self.model_name_or_path = model_name_or_path -# self.emulator_layer_num = emulator_layer_num -# self.adapter_top_layer_num = adapter_top_layer_num -# self.adapter_bottom_layer_num = adapter_bottom_layer_num -# super().__init__( -# emulator_layer_num, -# adapter_top_layer_num, -# adapter_bottom_layer_num, -# fp16_mix_precision) -# self.partial_weight_decay = partial_weight_decay - -# def get_base_model(self): -# total_layer_num = self.emulator_layer_num + \ -# self.adapter_top_layer_num + self.adapter_bottom_layer_num -# config = GPT2Config.from_pretrained(self.model_name_or_path) -# config.num_hidden_layers = total_layer_num -# # initialize a model without pretrained weights -# return GPT2LMHeadModel(config) - -# def get_model_transformer_blocks(self, model: GPT2LMHeadModel): -# return model.transformer.h - -# def forward(self, x): -# return self.model(**x) - -# def get_additional_parameter(self, model) -> dict: -# return { -# 'wte': model.transformer.wte, -# 'wpe': model.transformer.wpe, -# 'last_ln_f': model.transformer.ln_f -# } - -# def forward(self, x): -# return self.model(**x) - -# def parameters(self, recurse=True): -# if self.partial_weight_decay is None: -# return super().parameters(recurse) -# elif isinstance(self.partial_weight_decay, float): -# no_decay = ["bias", "layer_norm.weight"] -# return [ -# { -# "params": [ -# p for n, p in self.named_parameters() if not any( -# nd in n for nd in no_decay)], "weight_decay": self.partial_weight_decay}, { -# "params": [ -# p for n, p in self.named_parameters() if any( -# nd in n for nd in no_decay)], "weight_decay": 0.0}] -# else: -# raise ValueError( -# f"partial_weight_decay should be None or float, but got {self.partial_weight_decay}") - - -if __name__ == "__main__": - - from transformers import GPT2Model - - model = GPT2LMHeadMainModel('gpt2-xl', 12, 2, 2) - model_sub = GPT2LMHeadSubModel( - 'gpt2-xl', 12, 2, 2, fp16_mix_precision=True) - model_sub.load_submodel_weights(model.get_submodel_weights()) diff --git a/python/fate_llm/model_zoo/offsite_tuning/llama_ot.py b/python/fate_llm/model_zoo/offsite_tuning/llama_ot.py new file mode 100644 index 0000000..2473c32 --- /dev/null +++ b/python/fate_llm/model_zoo/offsite_tuning/llama_ot.py @@ -0,0 +1,153 @@ +from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters, split_numpy_array, recover_numpy_array +from transformers import LlamaConfig, LlamaForCausalLM + + +class LlamaMainModel(OffsiteTuningMainModel): + + def __init__( + self, + model_name_or_path, + emulator_layer_num: int, + adapter_top_layer_num: int = 2, + adapter_bottom_layer_num: int = 2): + + self.model_name_or_path = model_name_or_path + super().__init__( + emulator_layer_num, + adapter_top_layer_num, + adapter_bottom_layer_num) + + def get_base_model(self): + return LlamaForCausalLM.from_pretrained(self.model_name_or_path) + + def get_model_transformer_blocks(self, model: LlamaForCausalLM): + return model.model.layers + + def forward(self, x): + return self.model(**x) + + def get_additional_param_state_dict(self): + # get parameter of additional parameter + model = self.model + param_dict = { + 'wte': model.model.embed_tokens, + 'last_ln_f': model.model.norm + } + + addition_weights = self.get_numpy_state_dict(param_dict) + + wte = addition_weights.pop('wte') + wte_dict = split_numpy_array(wte, 25, 'wte') + addition_weights.update(wte_dict) + return addition_weights + + def load_additional_param_state_dict(self, submodel_weights: dict): + # load additional weights: + model = self.model + param_dict = { + 'wte': model.model.embed_tokens, + 'last_ln_f': model.model.norm + } + + new_submodel_weight = {} + new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f'] + wte_dict = {} + for k, v in submodel_weights.items(): + if 'wte' in k: + wte_dict[k] = v + wte = recover_numpy_array(wte_dict, 'wte') + new_submodel_weight['wte'] = wte + self.load_numpy_state_dict(param_dict, new_submodel_weight) + + def forward(self, x): + return self.model(**x) + + +class LlamaSubModel(OffsiteTuningSubModel): + + def __init__( + self, + model_name_or_path, + emulator_layer_num: int, + adapter_top_layer_num: int = 2, + adapter_bottom_layer_num: int = 2, + fp16_mix_precision=False, + partial_weight_decay=None): + + self.model_name_or_path = model_name_or_path + self.emulator_layer_num = emulator_layer_num + self.adapter_top_layer_num = adapter_top_layer_num + self.adapter_bottom_layer_num = adapter_bottom_layer_num + super().__init__( + emulator_layer_num, + adapter_top_layer_num, + adapter_bottom_layer_num, + fp16_mix_precision) + self.partial_weight_decay = partial_weight_decay + + def get_base_model(self): + total_layer_num = self.emulator_layer_num + \ + self.adapter_top_layer_num + self.adapter_bottom_layer_num + config = LlamaConfig.from_pretrained(self.model_name_or_path) + config.num_layers = total_layer_num + # initialize a model without pretrained weights + return LlamaForCausalLM(config) + + def get_model_transformer_blocks(self, model: LlamaForCausalLM): + return model.model.layers + + def forward(self, x): + return self.model(**x) + + def get_additional_param_state_dict(self): + # get parameter of additional parameter + model = self.model + param_dict = { + 'wte': model.model.embed_tokens, + 'last_ln_f': model.model.norm + } + + addition_weights = self.get_numpy_state_dict(param_dict) + + wte = addition_weights.pop('wte') + wte_dict = split_numpy_array(wte, 25, 'wte') + addition_weights.update(wte_dict) + return addition_weights + + def load_additional_param_state_dict(self, submodel_weights: dict): + # load additional weights: + model = self.model + param_dict = { + 'wte': model.model.embed_tokens, + 'last_ln_f': model.model.norm + } + + new_submodel_weight = {} + new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f'] + wte_dict = {} + for k, v in submodel_weights.items(): + if 'wte' in k: + wte_dict[k] = v + wte = recover_numpy_array(wte_dict, 'wte') + new_submodel_weight['wte'] = wte + self.load_numpy_state_dict(param_dict, new_submodel_weight) + + def forward(self, x): + return self.model(**x) + + def parameters(self, recurse=True): + if self.partial_weight_decay is None: + return super().parameters(recurse) + elif isinstance(self.partial_weight_decay, float): + no_decay = ["bias", "layer_norm.weight"] + return [ + { + "params": [ + p for n, p in self.named_parameters() if not any( + nd in n for nd in no_decay)], "weight_decay": self.partial_weight_decay}, { + "params": [ + p for n, p in self.named_parameters() if any( + nd in n for nd in no_decay)], "weight_decay": 0.0}] + else: + raise ValueError( + f"partial_weight_decay should be None or float, but got {self.partial_weight_decay}") \ No newline at end of file diff --git a/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py b/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py index 1c5ef56..a81f03e 100644 --- a/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py +++ b/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py @@ -186,161 +186,3 @@ def post_initialization(self): for param in self.adapter_bottom.parameters(): param.data = param.data.float() param.requires_grad = True - - -# import torch as t -# from torch import nn -# from federatedml.util import LOGGER -# from transformers import AutoModel - - -# def get_dropout_emulator_and_adapters( -# transformer_layers: nn.ModuleList, -# emulator_layer_num: int, -# adapter_top_layer_num: int, -# adapter_bottom_layer_num: int): - -# assert adapter_bottom_layer_num > 0 and adapter_top_layer_num > 0, "adapter layer num must be greater than 0" -# assert emulator_layer_num < len( -# transformer_layers), "emulator layer num must be less than the number of transformer layers" -# assert adapter_bottom_layer_num + adapter_top_layer_num < len( -# transformer_layers), "adapter layer num must be less than the number of transformer layers" -# assert emulator_layer_num < len( -# transformer_layers) and emulator_layer_num > 0, "emulator layer num must be less than the number of transformer layers" - -# bottom_idx = adapter_bottom_layer_num -# top_idx = len(transformer_layers) - adapter_top_layer_num -# bottom_layers = transformer_layers[:bottom_idx] -# top_layers = transformer_layers[top_idx:] -# kept_layers = transformer_layers[bottom_idx:top_idx] -# emulator = nn.ModuleList() -# stride = (len(kept_layers) - 1) / (emulator_layer_num - 1) - -# layer_idx = [] -# for i in range(emulator_layer_num): -# idx = int(round(i * stride)) -# layer_idx.append(idx) -# emulator.append(kept_layers[idx]) -# LOGGER.info( -# 'take layer {} of the original model as the emulator'.format( -# t.Tensor(layer_idx) + -# bottom_idx)) -# return nn.ModuleList(emulator), nn.ModuleList( -# bottom_layers), nn.ModuleList(top_layers) - - -# class OffsiteTuningBaseModel(t.nn.Module): - -# def __init__(self, emulator_layer_num: int, adapter_top_layer_num: int = 2, -# adapter_bottom_layer_num: int = 2, fp16_mix_precision=False): -# super().__init__() -# self.fp16_mix_precision = fp16_mix_precision -# self.model = self.get_base_model() -# self.initialize_model() -# self.emulator, self.adapter_bottom, self.adapter_top = get_dropout_emulator_and_adapters( -# transformer_layers=self.get_model_transformer_blocks(self.model), -# emulator_layer_num=emulator_layer_num, -# adapter_top_layer_num=adapter_top_layer_num, -# adapter_bottom_layer_num=adapter_bottom_layer_num -# ) -# self.addition_param = self.get_additional_parameter(self.model) -# self.post_initialization() - -# def initialize_model(self): -# if self.fp16_mix_precision: -# self.model.half() -# for param in self.model.parameters(): -# param.requires_grad = False - -# def post_initialization(self): -# pass - -# def get_adapter_top(self): -# return self.adapter_top - -# def get_adapter_bottom(self): -# return self.adapter_bottom - -# def get_emulator(self): -# return self.emulator - -# def get_submodel_weights(self) -> dict: -# submodel_weights = { -# "emulator": { -# k: v.detach().cpu().numpy() for k, -# v in self.get_emulator().state_dict().items()}, -# "adapter_top": { -# k: v.detach().cpu().numpy() for k, -# v in self.get_adapter_top().state_dict().items()}, -# "adapter_bottom": { -# k: v.detach().cpu().numpy() for k, -# v in self.get_adapter_bottom().state_dict().items()}} - -# # get parameter of additional parameter -# addition_weights = {} -# for k, v in self.addition_param.items(): -# addition_weights[k] = { -# k: v.detach().cpu().numpy() for k, -# v in v.state_dict().items()} -# submodel_weights.update(addition_weights) - -# return submodel_weights - -# def load_submodel_weights(self, submodel_weights: dict): - -# emulator_weights = { -# k: t.tensor(v) for k, -# v in submodel_weights['emulator'].items()} -# adapter_top_weights = { -# k: t.tensor(v) for k, -# v in submodel_weights['adapter_top'].items()} -# adapter_bottom_weights = { -# k: t.tensor(v) for k, -# v in submodel_weights['adapter_bottom'].items()} - -# emulator = self.get_emulator() -# adapter_top = self.get_adapter_top() -# adapter_bottom = self.get_adapter_bottom() - -# emulator.load_state_dict(emulator_weights) -# adapter_top.load_state_dict(adapter_top_weights) -# adapter_bottom.load_state_dict(adapter_bottom_weights) - -# # load additional weights: -# for k, v in self.addition_param.items(): -# if k not in submodel_weights: -# continue -# addition_weights = { -# k: t.tensor(v) for k, -# v in submodel_weights[k].items()} -# v.load_state_dict(addition_weights) - -# def forward(self, **kwargs): -# raise NotImplementedError() - -# def get_base_model(self): -# raise NotImplementedError() - -# def get_model_transformer_blocks(self, model: t.nn.Module): -# raise NotImplementedError() - -# def get_additional_parameter(self, model) -> dict: -# return {} - - -# class OffsiteTuningMainModel(OffsiteTuningBaseModel): - -# def post_initialization(self): -# pass - - -# class OffsiteTuningSubModel(OffsiteTuningBaseModel): - -# def post_initialization(self): -# # mix precision model training -# for param in self.adapter_top.parameters(): -# param.data = param.data.float() -# param.requires_grad = True -# for param in self.adapter_bottom.parameters(): -# param.data = param.data.float() -# param.requires_grad = True \ No newline at end of file From 7145938366d7e2de64ef7f262f52ff8e881f02b6 Mon Sep 17 00:00:00 2001 From: cwj Date: Sun, 3 Sep 2023 21:18:51 +0800 Subject: [PATCH 05/23] Fix llama num layer Signed-off-by: cwj --- python/fate_llm/model_zoo/offsite_tuning/llama_ot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/fate_llm/model_zoo/offsite_tuning/llama_ot.py b/python/fate_llm/model_zoo/offsite_tuning/llama_ot.py index 2473c32..cb8de1c 100644 --- a/python/fate_llm/model_zoo/offsite_tuning/llama_ot.py +++ b/python/fate_llm/model_zoo/offsite_tuning/llama_ot.py @@ -89,7 +89,7 @@ def get_base_model(self): total_layer_num = self.emulator_layer_num + \ self.adapter_top_layer_num + self.adapter_bottom_layer_num config = LlamaConfig.from_pretrained(self.model_name_or_path) - config.num_layers = total_layer_num + config.num_hidden_layers = total_layer_num # initialize a model without pretrained weights return LlamaForCausalLM(config) From 96930aaa1aa4bd8d9c76e361d89cdf4cde96575f Mon Sep 17 00:00:00 2001 From: cwj Date: Mon, 4 Sep 2023 12:41:16 +0800 Subject: [PATCH 06/23] Fix Ipr import & update log Signed-off-by: cwj --- python/fate_llm/model_zoo/ipr/alexnet.py | 2 +- python/fate_llm/model_zoo/ipr/distilbert.py | 2 +- python/fate_llm/model_zoo/ipr/gpt2.py | 2 +- python/fate_llm/model_zoo/ipr/resnet.py | 2 +- python/fate_llm/model_zoo/ipr/sign_block.py | 182 ++++++++++++++++++++ python/fate_llm/trainer/fedipr_trainer.py | 17 +- 6 files changed, 198 insertions(+), 9 deletions(-) create mode 100644 python/fate_llm/model_zoo/ipr/sign_block.py diff --git a/python/fate_llm/model_zoo/ipr/alexnet.py b/python/fate_llm/model_zoo/ipr/alexnet.py index 57eae5f..bcaee84 100644 --- a/python/fate_llm/model_zoo/ipr/alexnet.py +++ b/python/fate_llm/model_zoo/ipr/alexnet.py @@ -1,5 +1,5 @@ import torch.nn as nn -from fate_llm.model_zoo.sign_block import SignatureConv, ConvBlock +from fate_llm.model_zoo.ipr.sign_block import SignatureConv, ConvBlock class SignAlexNet(nn.Module): diff --git a/python/fate_llm/model_zoo/ipr/distilbert.py b/python/fate_llm/model_zoo/ipr/distilbert.py index 5ff429c..4455417 100644 --- a/python/fate_llm/model_zoo/ipr/distilbert.py +++ b/python/fate_llm/model_zoo/ipr/distilbert.py @@ -1,6 +1,6 @@ from torch.nn import Module from transformers import DistilBertForSequenceClassification, DistilBertForTokenClassification -from fate_llm.model_zoo.sign_block import recursive_replace_layernorm +from fate_llm.model_zoo.ipr.sign_block import recursive_replace_layernorm class SignDistilBertForTokenClassification(Module): diff --git a/python/fate_llm/model_zoo/ipr/gpt2.py b/python/fate_llm/model_zoo/ipr/gpt2.py index ef506bc..39d72e2 100644 --- a/python/fate_llm/model_zoo/ipr/gpt2.py +++ b/python/fate_llm/model_zoo/ipr/gpt2.py @@ -1,6 +1,6 @@ from torch.nn import Module from transformers import GPT2ForTokenClassification, GPT2ForSequenceClassification -from fate_llm.model_zoo.sign_block import recursive_replace_layernorm +from fate_llm.model_zoo.ipr.sign_block import recursive_replace_layernorm class SignGPT2ForTokenClassification(Module): diff --git a/python/fate_llm/model_zoo/ipr/resnet.py b/python/fate_llm/model_zoo/ipr/resnet.py index 156206e..4e3b0e4 100644 --- a/python/fate_llm/model_zoo/ipr/resnet.py +++ b/python/fate_llm/model_zoo/ipr/resnet.py @@ -1,6 +1,6 @@ import torch.nn as nn import torch.nn.functional as F -from fate_llm.model_zoo.sign_block import ConvBlock, SignatureConv +from fate_llm.model_zoo.ipr.sign_block import ConvBlock, SignatureConv # The layer define for ResNet18, add signature to last layer diff --git a/python/fate_llm/model_zoo/ipr/sign_block.py b/python/fate_llm/model_zoo/ipr/sign_block.py new file mode 100644 index 0000000..45e701d --- /dev/null +++ b/python/fate_llm/model_zoo/ipr/sign_block.py @@ -0,0 +1,182 @@ +import torch +import torch.nn as nn +import torch.nn.init as init +from torch.nn import functional as F +from federatedml.util import LOGGER + +""" +Base +""" + + +class SignatureBlock(nn.Module): + + def __init__(self) -> None: + super().__init__() + + @property + def embeded_param(self): + return None + + def embeded_param_num(self): + return None + + def extract_sign(self, W): + pass + + def sign_loss(self, W, sign): + pass + + +def is_sign_block(block): + return issubclass(type(block), SignatureBlock) + + +class ConvBlock(nn.Module): + def __init__(self, i, o, ks=3, s=1, pd=1, relu=True): + super().__init__() + + self.conv = nn.Conv2d(i, o, ks, s, pd, bias= False) + + if relu: + self.relu = nn.ReLU(inplace=True) + else: + self.relu = None + + self.reset_parameters() + + def reset_parameters(self): + init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu') + + def forward(self, x): + x = self.conv(x) + if self.relu is not None: + x = self.relu(x) + return x + + +def generate_signature(conv_block: SignatureBlock, num_bits): + + sign = torch.sign(torch.rand(num_bits) - 0.5) + W = torch.randn(len(conv_block.embeded_param.flatten()), num_bits) + + return (W, sign) + + +""" +Function & Class for Conv Layer +""" + + +class SignatureConv(SignatureBlock): + + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False): + super(SignatureConv, self).__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) + self.weight = self.conv.weight + + self.init_scale() + self.init_bias() + self.bn = nn.BatchNorm2d(out_channels, affine=False) + self.relu = nn.ReLU(inplace=True) + self.reset_parameters() + + def init_bias(self): + self.bias = nn.Parameter(torch.Tensor(self.conv.out_channels).to(self.weight.device)) + init.zeros_(self.bias) + + def init_scale(self): + self.scale = nn.Parameter(torch.Tensor(self.conv.out_channels).to(self.weight.device)) + init.ones_(self.scale) + + def reset_parameters(self): + init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu') + + @property + def embeded_param(self): + # embedded in the BatchNorm param, as the same in the paper + return self.scale + + def embeded_param_num(self): + return len(self.scale) + + def extract_sign(self, W): + # W is the linear weight for extracting signature + with torch.no_grad(): + return self.scale.view([1, -1]).mm(W).sign().flatten() + + def sign_loss(self, W, sign): + loss = F.relu(-self.scale.view([1, -1]).mm(W).mul(sign.view(-1))).sum() + return loss + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = x * self.scale[None, :, None, None] + self.bias[None, :, None, None] + x = self.relu(x) + return x + + +""" +Function & Class for LM +""" + + +def recursive_replace_layernorm(module, layer_name_set=None): + + """ + Recursively replaces the LayerNorm layers of a given module with SignatureLayerNorm layers. + + Parameters: + module (torch.nn.Module): The module in which LayerNorm layers should be replaced. + layer_name_set (set[str], optional): A set of layer names to be replaced. If None, + all LayerNorm layers in the module will be replaced. + """ + + for name, sub_module in module.named_children(): + if isinstance(sub_module, nn.LayerNorm): + if layer_name_set is not None and name not in layer_name_set: + continue + setattr(module, name, SignatureLayerNorm.from_layer_norm_layer(sub_module)) + LOGGER.debug(f"Replace {name} with SignatureLayerNorm") + recursive_replace_layernorm(sub_module, layer_name_set) + + +class SignatureLayerNorm(SignatureBlock): + + def __init__(self, normalized_shape=None, eps=1e-5, elementwise_affine=True, layer_norm_inst=None): + super(SignatureLayerNorm, self).__init__() + if layer_norm_inst is not None and isinstance(layer_norm_inst, nn.LayerNorm): + self.ln = layer_norm_inst + else: + self.ln = nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + @property + def embeded_param(self): + return self.ln.weight + + def embeded_param_num(self): + return self.ln.weight.numel() + + @staticmethod + def from_layer_norm_layer(layer_norm_layer: nn.LayerNorm): + return SignatureLayerNorm(layer_norm_inst=layer_norm_layer) + + def extract_sign(self, W): + # W is the linear weight for extracting signature + with torch.no_grad(): + return self.ln.weight.view([1, -1]).mm(W).sign().flatten() + + def sign_loss(self, W, sign): + loss = F.relu(-self.ln.weight.view([1, -1]).mm(W).mul(sign.view(-1))).sum() + return loss + + def forward(self, x): + return self.ln(x) + + +if __name__ == "__main__": + conv = SignatureConv(3, 384, 3, 1, 1) + layer_norm = SignatureLayerNorm((768, )) + layer_norm_2 = SignatureLayerNorm.from_layer_norm_layer(layer_norm.ln) \ No newline at end of file diff --git a/python/fate_llm/trainer/fedipr_trainer.py b/python/fate_llm/trainer/fedipr_trainer.py index 76f3141..bc2b9e5 100644 --- a/python/fate_llm/trainer/fedipr_trainer.py +++ b/python/fate_llm/trainer/fedipr_trainer.py @@ -7,10 +7,10 @@ from federatedml.nn.backend.utils import distributed_util from torch.utils.data import DataLoader, DistributedSampler import torch.distributed as dist -from federatedml.nn.dataset.watermark import WaterMarkImageDataset, WaterMarkDataset +from fate_llm.dataset.watermark import WaterMarkImageDataset, WaterMarkDataset from federatedml.util import LOGGER -from fate_llm.model_zoo.sign_block import generate_signature, is_sign_block -from fate_llm.model_zoo.sign_block import SignatureBlock +from fate_llm.model_zoo.ipr.sign_block import generate_signature, is_sign_block +from fate_llm.model_zoo.ipr.sign_block import SignatureBlock from sklearn.metrics import accuracy_score from federatedml.nn.dataset.base import Dataset from federatedml.util import consts @@ -352,6 +352,9 @@ def train_an_epoch( for watermark_batch in watermark_dl: watermark_collect.append(watermark_batch) + total_batch_len = len(dl) + LOGGER.info('total batch len is {}'.format(total_batch_len)) + for _batch_iter in trainset_iterator: _batch_iter = self._decode(_batch_iter) @@ -428,8 +431,12 @@ def train_an_epoch( batch_idx += 1 if self.fed_mode: - LOGGER.debug( - 'epoch {} batch {} finished'.format(epoch_idx, batch_idx)) + if total_batch_len > 100: + if batch_idx % (total_batch_len // 100) == 0: + percentage = (batch_idx / total_batch_len) * 100 + LOGGER.debug(f"Training progress of epoch {epoch_idx}: {percentage:.1f}%") + else: + LOGGER.debug("Training epoch {}:batch {}".format(epoch_idx, batch_idx)) epoch_loss = epoch_loss / len(train_set) From f779ced4ec252714e0b814a7f74533c4ed2c17fa Mon Sep 17 00:00:00 2001 From: cwj Date: Mon, 4 Sep 2023 14:28:53 +0800 Subject: [PATCH 07/23] Update Doc Signed-off-by: cwj --- doc/tutorial/fed_ipr/FedIPR-tutorial.ipynb | 828 ++++++++++++++++++ .../ChatGLM-6B_ds.ipynb | 0 .../GPT2-example.ipynb | 0 3 files changed, 828 insertions(+) create mode 100644 doc/tutorial/fed_ipr/FedIPR-tutorial.ipynb rename doc/tutorial/{ => parameter_efficient_llm}/ChatGLM-6B_ds.ipynb (100%) rename doc/tutorial/{ => parameter_efficient_llm}/GPT2-example.ipynb (100%) diff --git a/doc/tutorial/fed_ipr/FedIPR-tutorial.ipynb b/doc/tutorial/fed_ipr/FedIPR-tutorial.ipynb new file mode 100644 index 0000000..05a4f5f --- /dev/null +++ b/doc/tutorial/fed_ipr/FedIPR-tutorial.ipynb @@ -0,0 +1,828 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# FedIPR Tutorial: Guide to Adding Watermarks to Image and Language Models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this tutorial, you'll learn how to add both backdoor-based and feature-based watermarks to your models in the federated training. \n", + "We'll dive into using backdoor-w datasets for backdoor-based watermarking and exploring signblock—a tool that learns feature-based watermarks during traning. We will show you how to apply these techniques to both computer vision and language models. We'll also offer a hands-on example with a CV task, share how to verify the watermarks you've embedded, and introduce some ready-to-use models provided by the FATE framework. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## FedIPR Introduction\n", + "FedIPR stands for Federated Intellectual Property Rights, a technology designed to protect the ownership of models developed under federated conditions. At its core, the FedIPR approach is described in the original paper [FedIPR](https://arxiv.org/pdf/2109.13236.pdf), introducing two primary watermarking techniques to safeguard your model: Backdoor-based and Feature-based watermarks.\n", + "\n", + "Backdoor-based methods: These methods use specific input triggers to produce intentional, incorrect labels. The goal here is to create a unique \"signature\" for the model, allowing for ownership verification through remote APIs, without requiring access to the model's internal parameters.\n", + "\n", + "Feature-based methods: These techniques encode designated binary strings as watermarks directly into the model's layer parameters. Various schemes have been proposed, such as embedding these watermarks into convolution layer weights using a binary cross-entropy loss function, or into normalization layer scale parameters using a hinge-like regularization term. In our implementations, we embed signatures into normalization layers as the same as \n", + "\n", + "Through these watermarking techniques, FedIPR ensures a robust way to assert ownership of your federated models without compromising their performance." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preliminary\n", + "\n", + "We strongly recommend you finish reading our NN tutorial to get familiar with Model and Dataset customizations: [NN Tutorials](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/README.md)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Backdoor Dataset for Backdoor Watermark" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can develop your own backdoor dataset and use it in FedIPRTrainer. If watermark dataset is detected, it will be used to train models along with your task dataset. If not provided, it will perform normal training.\n", + "\n", + "You can add python path so that you can run codes in the notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "your_path_to_fate_python = 'xxx/fate/python'\n", + "sys.path.append(your_path_to_fate_python)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Interfaces" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The WaterMarkDataset class serves as a base class for handling watermark datasets in federated learning environments. It’s crucial for you to implement the load method. The primary task when subclassing WaterMarkDataset is to fill in the load method. This method should take a path argument and use it to load both your normal and watermark datasets.\n", + "\n", + "Besides you need to implement other interfaces like get_item, len like using a pytorch dataset to make it work correctly in FATE.\n", + "You can refer to this tutorial: [Dataset Customization](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Homo-NN-Customize-your-Dataset.ipynb)\n", + "\n", + "Here show you the source code of the watermark dataset class." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from federatedml.nn.dataset.base import Dataset\n", + "from federatedml.util import LOGGER\n", + "from federatedml.nn.dataset.image import ImageDataset\n", + "\n", + "\n", + "class WaterMarkDataset(Dataset):\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.normal_dataset = None\n", + " self.watermark_dataset = None\n", + "\n", + " def load(self, path):\n", + " raise NotImplementedError()\n", + "\n", + " def get_normal_dataset(self):\n", + " return self.normal_dataset\n", + "\n", + " def get_watermark_dataset(self):\n", + " return self.watermark_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To make you better understand how our watermark dataset work, here we show the implementation of load function of our built-in WaterMarkImageDataset.\n", + "The WaterMarkImageDataset class is designed to automatically identify and load two distinct folders from the specified file path: one containing 'normal' training samples and another containing 'watermark' trigger samples." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def load(self, file_path):\n", + "\n", + " # normal dataset path\n", + " normal_path = os.path.join(file_path, self.normal_folder_name)\n", + " # watermark dataset path\n", + " watermark_path = os.path.join(file_path, self.watermark_folder_name)\n", + "\n", + " # load normal dataset\n", + " self.normal_dataset = ImageDataset(\n", + " center_crop=self.center_crop,\n", + " center_crop_shape=self.size,\n", + " generate_id_from_file_name=self.generate_id_from_file_name,\n", + " file_suffix=self.file_suffix,\n", + " float64=self.float64,\n", + " label_dtype=self.label_type\n", + " )\n", + " if os.path.exists(normal_path):\n", + " self.normal_dataset.load(normal_path)\n", + " else:\n", + " self.normal_dataset = None\n", + " LOGGER.info(\n", + " f'normal dataset not found in {normal_path}, will not load normal dataset')\n", + " # load watermark dataset\n", + " self.watermark_dataset = ImageDataset(\n", + " center_crop=self.center_crop,\n", + " center_crop_shape=self.size,\n", + " generate_id_from_file_name=self.generate_id_from_file_name,\n", + " file_suffix=self.file_suffix,\n", + " float64=self.float64,\n", + " label_dtype=self.label_type\n", + " )\n", + " if os.path.exists(watermark_path):\n", + " self.watermark_dataset.load(watermark_path)\n", + " else:\n", + " self.watermark_dataset = None\n", + " LOGGER.info(\n", + " f'watermark dataset not found in {watermark_path}, will not load watermark dataset')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can try our WaterMarkImageDataset: use it load our provided cifar-10 watermarked dataset which contains 100 trigger samples.Each image in these folders has been augmented with a pattern of structured noise in one corner. Download the dataset and place it in example/data folder in your fate project: [Dowload Path]()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.dataset.watermark import WaterMarkImageDataset\n", + "\n", + "ds = WaterMarkImageDataset()\n", + "ds.load('../../examples/data/cifar_10_ipr/fedipr_cifar10_guest/')" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset ImageFolder\n", + " Number of datapoints: 25000\n", + " Root location: ../../examples/data/cifar_10_ipr/fedipr_cifar10_guest/normal\n", + " StandardTransform\n", + "Transform: Compose(\n", + " ToTensor()\n", + " )" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.get_normal_dataset()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset ImageFolder\n", + " Number of datapoints: 100\n", + " Root location: ../../examples/data/cifar_10_ipr/fedipr_cifar10_guest/watermark\n", + " StandardTransform\n", + "Transform: Compose(\n", + " ToTensor()\n", + " )" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.get_watermark_dataset() # water mark dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "25100" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(ds)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "At this point, you can now customize a watermark dataset for your own tasks to add watermarks to your models. In the upcoming CIFAR-10 task, we will be using FATE's built-in image watermark dataset class." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Built-in BacthNorm and LayerNorm Blocks for Feature-based Watermark" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this section, we will delve into the workings of feature-based watermarking. Feature-based watermarking involves embedding binary watermarks vectors into specific model parameters. In FATE implementations, we use the same design as the FATE-IPR paper: In the case of CNN, binary water mark are embeded into BatchNorm Layer. In transformers, watermarks are embeded into LayerNorm layers.\n", + "\n", + "You can use SignatureConv, SignatureLayerNorm to build your model. Once these blocks are detected in the FedIPR trainer, trainer will automatically assign binary watermark vector whose bit length is computed by Equation (15) in the origin paper.\n", + "\n", + "You can import them from:model's proprietary elements." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.model_zoo.ipr.sign_block import SignatureConv, SignatureLayerNorm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we show you the source code of our built in alexnet and distilbert to show you how to quickly build a model with featurebased watermark:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "from fate_llm.model_zoo.ipr.sign_block import SignatureConv, ConvBlock\n", + "\n", + "\n", + "class SignAlexNet(nn.Module):\n", + "\n", + " \"\"\"\n", + " This is a modified Alexnet: its 4,5,6 layers are replaced by Singnature Conv Block\n", + " \"\"\"\n", + "\n", + " def __init__(self, num_classes):\n", + " super().__init__()\n", + " in_channels = 3\n", + " maxpoolidx = [1, 3, 7]\n", + " signed_layer = [4, 5, 6]\n", + " layers = []\n", + " inp = in_channels\n", + "\n", + " # channels & kennel size\n", + " # the same setting as the FedIPR paper\n", + " oups = {\n", + " 0: 64,\n", + " 2: 192,\n", + " 4: 384,\n", + " 5: 256,\n", + " 6: 256\n", + " }\n", + " kp = {\n", + " 0: (5, 2),\n", + " 2: (5, 2),\n", + " 4: (3, 1),\n", + " 5: (3, 1),\n", + " 6: (3, 1)\n", + " }\n", + "\n", + " for layeridx in range(8):\n", + " if layeridx in maxpoolidx:\n", + " layers.append(nn.MaxPool2d(2, 2))\n", + " else:\n", + " k = kp[layeridx][0]\n", + " p = kp[layeridx][1]\n", + " if layeridx in signed_layer:\n", + " layers.append(SignatureConv(inp, oups[layeridx], k, 1, p))\n", + " else:\n", + " layers.append(ConvBlock(inp, oups[layeridx], k, 1, p))\n", + " inp = oups[layeridx]\n", + "\n", + " self.features = nn.Sequential(*layers)\n", + " self.classifier = nn.Linear(4 * 4 * 256, num_classes)\n", + "\n", + " def forward(self, x):\n", + " for m in self.features:\n", + " x = m(x)\n", + " x = x.view(x.size(0), -1)\n", + " x = self.classifier(x)\n", + " if self.training:\n", + " return x\n", + " else: # Sofmax\n", + " return nn.functional.softmax(x, dim=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By inserting signconv block you can easily build a cv model with feature-based signature, in the case of NLP models, by useing 'recursive_replace_layernorm' you can quickly replace the original LayerNorm with our sign layernorm. Codes below show that you can quickly add feature-based watermarks to a huggingface pretraind model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.nn import Module\n", + "from transformers import DistilBertForSequenceClassification, DistilBertForTokenClassification\n", + "from fate_llm.model_zoo.ipr.sign_block import recursive_replace_layernorm\n", + "\n", + "\n", + "class SignDistilBertForTokenClassification(Module):\n", + "\n", + " def __init__(self, model_path=None, num_labels=4) -> None:\n", + " super().__init__()\n", + " if model_path is None:\n", + " model_path = 'distilbert-base-uncased'\n", + "\n", + " self.model_path = model_path\n", + " self.model = DistilBertForTokenClassification.from_pretrained(\n", + " model_path, num_labels=num_labels)\n", + "\n", + " # replace layernorm by SignatureLayerNorm\n", + " sub_distilbert = self.model.distilbert.transformer.layer[3:] # replace layernorm by SingLayerNorm in the last 3 layer\n", + " recursive_replace_layernorm(\n", + " sub_distilbert,\n", + " layer_name_set={'output_layer_norm'})\n", + "\n", + " def forward(self, input_dict):\n", + " return self.model(**input_dict)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Verify Feature-based watermark with our tools\n", + "\n", + "After training is done, feature-based watermarks' signatures will be saved together with model. You can use our tool to verify the model ownership." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.trainer.fedipr_trainer import verify_feature_based_signature" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See the example below for usage." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## FedIPR on FATE" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In FATE-LLM-1.3’s model_zoo we have these built-in models which are automatically integrated with feature-based watermarking capabilities:\n", + "\n", + "#### Model List\n", + "\n", + "- `alexnet.py` - Alexnet\n", + "- `resnet.py` - Resnet18\n", + "- `distilbert.py` - Distilbert (Configurations match those in the FedIPR paper)\n", + "- `gpt2.py` - Standard GPT-2 (Watermarks are added to the last 2 transformer layers)\n", + "t.py`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We have verified the effectiveness of our watermarking features through a series of tests:\n", + "- For computer vision tasks, we evaluated both backdoor watermarking and feature-based watermarking techniques on the CIFAR-10 and CIFAR-100 datasets. Our testing involved the use of ResNet and AlexNet models.\n", + "- For natural language processing tasks, we assessed the performance of DistilBERT and GPT2 models on the IMDB & CoNLL-2003 datasets, which are sequence classification tasn and token classification respectively. \n", + "During the testing phase, the sign bit was automatically allocated, and the data was evenly divided between the guest and host parties. For backdoor watermarking, each party supplied 100 trigger samples, all of which were augmented with noises.\n", + "\n", + "Here we display the results of the experiments:\n", + "\n", + "AlexNet & Resnet:\n", + "\n", + "| Test Configuration | AlexNet Feature-Based Watermark Accuracy | AlexNet Backdoor Watermark Accuracy | ResNet18 Feature-Based Watermark Accuracy | ResNet18 Backdoor Watermark Accuracy |\n", + "|--------------------|-----------------------------------------|------------------------------------|------------------------------------------|-------------------------------------|\n", + "| Two-party federation on CIFAR-10 with 100 trigger samples, SignBit auto-assigned | 1.0 (All Parties) | 1.0 (All Parties) | 1.0 (All Parties) | 1.0 (All Parties) |\n", + "| Two-party federation on CIFAR-100 with 100 trigger samples, SignBit auto-assigned | 1.0 (All Parties) | 1.0 (Guest), 0.991 (Host) | 1.0 (All Parties) | 1.0 (All Parties) |\n", + "\n", + "DistilBert & GPT2:\n", + "\n", + "| Test Configuration | DistillBERT Feature-Based Watermark Accuracy | GPT-2 Feature-Based Watermark Accuracy |\n", + "|--------------------|----------------------------------------------|---------------------------------------|\n", + "| Two-party federation on CoNLL-2003 Token Classification with SignBit auto-assigned | 1.0 (All Parties) | 1.0 (All Parties) |\n", + "| Two-party federation on IMDB Classification with SignBit auto-assigned | 1.0 (All Parties) | 1.0 (All Parties) |\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## A Cifar-10 Example & Verifying Watermark\n", + "\n", + "At last, we will show you a CV example: we will train a AlexNet with backdoor watermark & feature-based watermark at the same time. And after training is done, we use built in tools to verify feature-based watermark. You can verify the backdoor watermark yourself by simply predicting trigger samples with your models." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### FedIPR Parameters\n", + "\n", + "The FedIPRTrainer's parameters are basically the same as the FedAVGTrainer except for 3 parameters: alpha, verify_freq and backdoor_verify_method\n", + "alpha is the weight for sign loss; verify_freq is the frequency of verifying your watermark during training(you can check result in logs) and backdoor_verify_method allows you to choose the method for verifying your datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class FedIPRTrainer(FedAVGTrainer):\n", + "\n", + " def __init__(self,\n", + " epochs=10,\n", + " noraml_dataset_batch_size=32,\n", + " watermark_dataset_batch_size=2,\n", + " early_stop=None,\n", + " tol=0.0001,\n", + " secure_aggregate=True,\n", + " weighted_aggregation=True,\n", + " aggregate_every_n_epoch=None,\n", + " cuda=None,\n", + " pin_memory=True,\n", + " shuffle=True,\n", + " data_loader_worker=0,\n", + " validation_freqs=None,\n", + " checkpoint_save_freqs=None,\n", + " task_type='auto',\n", + " save_to_local_dir=False,\n", + " collate_fn=None,\n", + " collate_fn_params=None,\n", + " alpha=0.01,\n", + " verify_freqs=1,\n", + " backdoor_verify_method: Literal['accuracy',\n", + " 'loss'] = 'accuracy'):\n", + " ..." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Submit a pipeline to run FedIPR CV task\n", + "\n", + "This a standalone version example, if you are running on the cluster version, you have to bind name&namespace on guest&host machines correspondingly" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch as t\n", + "from torch import nn\n", + "from pipeline import fate_torch_hook\n", + "from pipeline.component import HomoNN\n", + "from pipeline.backend.pipeline import PipeLine\n", + "from pipeline.component import Reader, Evaluation, DataTransform\n", + "from pipeline.interface import Data, Model\n", + "\n", + "t = fate_torch_hook(t)\n", + "\n", + "import os\n", + "# bind data path to name & namespace\n", + "fate_project_path = os.path.abspath('../../')\n", + "host = 9997\n", + "guest = 9997\n", + "arbiter = 9997\n", + "pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,\n", + " arbiter=arbiter)\n", + "\n", + "data_0 = {\"name\": \"watermark_cifar10_guest\", \"namespace\": \"experiment\"}\n", + "data_1 = {\"name\": \"watermark_cifar10_host\", \"namespace\": \"experiment\"}\n", + "\n", + "data_path_0 = fate_project_path + '/examples/data/cifar_10_ipr/fedipr_cifar10_guest'\n", + "data_path_1 = fate_project_path + '/examples/data/cifar_10_ipr/fedipr_cifar10_host'\n", + "pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)\n", + "pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)\n", + "\n", + "reader_0 = Reader(name=\"reader_0\")\n", + "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)\n", + "reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)\n", + "\n", + "from pipeline.component.nn import DatasetParam\n", + "\n", + "dataset_param = DatasetParam(dataset_name='watermark')\n", + "\n", + "from pipeline.component.homo_nn import TrainerParam # Interface\n", + "\n", + "# our simple classification model:\n", + "model = t.nn.CustModel(module_name='ipr.alexnet', class_name='SignAlexNet', num_classes=10)\n", + "\n", + "nn_component = HomoNN(name='nn_0',\n", + " model=model, # model\n", + " dataset=dataset_param, # dataset\n", + " # Notice that for the convenience of getting result model we set save_to_local_dir=True\n", + " trainer=TrainerParam(trainer_name='fedipr_trainer', epochs=5, save_to_local_dir=True, cuda=0),\n", + " optimizer=t.optim.Adam(lr=0.001),\n", + " loss=t.nn.CrossEntropyLoss(),\n", + " torch_seed=100 # random seed\n", + " )\n", + "\n", + "\n", + "pipeline.add_component(reader_0)\n", + "pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))\n", + "pipeline.compile()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.fit() # submit!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load Model and Verify\n", + "\n", + "Since we enable 'save_to_local_dir', we can directly load trained model from fateflow job folder, and verify its watermarks" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.trainer.fedipr_trainer import verify_feature_based_signature" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "job_id = '202309041103336933850' # your job id\n", + "import os\n", + "fate_project_path = os.path.abspath('../../')\n", + "local_dir = fate_project_path + '/fateflow/jobs/{}/guest/9997/nn_0/'.format(job_id)\n", + "state_dict = t.load(local_dir + 'model.pkl')" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.model_zoo.ipr.alexnet import SignAlexNet\n", + "\n", + "model = SignAlexNet(num_classes=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.load_state_dict(state_dict['model'])" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "keys = state_dict['extra_data']['keys'] # W and watermark vectors" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'features.4': (tensor([[-7.3380e-02, 1.6275e+00, -1.5404e+00, ..., 3.4250e-01,\n", + " -1.0067e+00, -5.4504e-01],\n", + " [ 2.9928e-01, -4.0935e-01, -6.1239e-01, ..., 7.2356e-01,\n", + " 2.7019e-01, -9.1200e-01],\n", + " [-4.3889e-02, 2.1774e+00, -1.3706e+00, ..., -8.5879e-01,\n", + " 2.3445e-01, 2.0458e+00],\n", + " ...,\n", + " [-5.1755e-01, 5.9240e-01, 2.6353e-01, ..., -1.0465e+00,\n", + " -5.3456e-01, -6.0439e-01],\n", + " [-2.4679e-01, -1.4290e+00, -5.9567e-01, ..., 7.7682e-01,\n", + " -6.2445e-01, 1.3682e+00],\n", + " [ 1.1148e+00, -8.7518e-01, 7.6818e-01, ..., 6.5654e-01,\n", + " -1.8362e+00, -5.5355e-04]]),\n", + " tensor([-1., -1., 1., 1., 1., -1., 1., 1., -1., -1., -1., 1., -1., 1.,\n", + " -1., -1., 1., -1., -1., -1., -1., 1., -1., 1., -1., -1., 1., -1.,\n", + " 1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., -1., -1., 1.,\n", + " 1., -1., 1., -1., -1., 1., -1., 1., -1., -1., -1., 1., -1., -1.,\n", + " -1., -1., -1., -1., 1., -1., -1., 1., 1., -1., -1., 1., -1., -1.,\n", + " 1., 1., 1., -1., -1., -1., -1., 1., -1., 1., -1., 1., -1., 1.,\n", + " -1., -1., 1., -1., -1., -1., -1., 1., -1., 1., 1., -1., -1., 1.,\n", + " -1., -1., 1., -1., -1., -1., -1., 1., 1., 1., -1., -1., 1., -1.,\n", + " -1., -1., 1., -1., 1., -1., -1., -1., -1., 1., -1., 1., 1., -1.,\n", + " 1., 1., 1., -1., 1., 1., -1., -1., 1., 1., -1., -1., -1., 1.,\n", + " -1., -1., -1., -1., 1., -1., -1., 1., 1., 1., -1., 1., -1., 1.,\n", + " 1., 1., -1., 1., 1., 1., 1., 1., 1., 1., -1., -1., -1., 1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., 1., -1., -1., 1., 1.,\n", + " 1., -1., 1., -1., -1., -1., 1., 1., 1.])),\n", + " 'features.5': (tensor([[-1.2336, 0.1894, -0.3584, ..., -0.5398, 0.5318, -1.6536],\n", + " [ 0.1128, 0.3999, 1.2841, ..., 1.6082, -0.1920, -0.0636],\n", + " [-0.9447, -0.2025, 0.4786, ..., 1.5100, -0.7834, 0.8102],\n", + " ...,\n", + " [-0.7941, 2.0311, -0.9690, ..., -1.1630, 0.6953, 1.6115],\n", + " [ 0.0314, 0.3718, 0.5974, ..., -1.6695, 1.8833, -0.1461],\n", + " [ 0.4956, 0.7747, -0.0847, ..., -0.3533, 0.0763, 0.0952]]),\n", + " tensor([-1., 1., -1., 1., 1., 1., 1., 1., 1., -1., -1., -1., 1., -1.,\n", + " -1., -1., -1., 1., -1., 1., -1., -1., 1., -1., -1., -1., 1., -1.,\n", + " 1., 1., 1., -1., -1., -1., 1., 1., 1., 1., 1., 1., -1., 1.,\n", + " -1., -1., -1., -1., 1., -1., 1., -1., 1., -1., 1., -1., -1., 1.,\n", + " 1., 1., 1., 1., -1., 1., 1., 1., -1., -1., -1., -1., -1., -1.,\n", + " -1., 1., 1., -1., 1., -1., -1., 1., -1., -1., 1., 1., -1., -1.,\n", + " -1., -1., -1., 1., 1., 1., 1., 1., 1., -1., 1., -1., -1., 1.,\n", + " 1., 1., 1., -1., 1., -1., 1., 1., -1., -1., 1., 1., -1., -1.,\n", + " 1., -1., -1., 1., 1., -1., 1., 1., 1., 1., -1., -1., -1., -1.,\n", + " 1.])),\n", + " 'features.6': (tensor([[ 2.6993e+00, 1.0507e+00, -6.6219e-01, ..., 6.3679e-01,\n", + " 7.7061e-01, 1.4231e+00],\n", + " [-1.0477e+00, 2.0904e-01, -3.4522e-01, ..., -4.9581e-01,\n", + " 1.4211e+00, -2.1041e+00],\n", + " [ 1.0036e+00, 1.0025e+00, -2.5215e-03, ..., 1.1413e+00,\n", + " -1.8600e+00, 2.0058e-02],\n", + " ...,\n", + " [ 1.2943e+00, 5.6073e-01, -1.9590e+00, ..., -1.4320e+00,\n", + " -1.6486e+00, -3.0871e-01],\n", + " [ 4.2747e-01, 1.8310e+00, -2.7685e-01, ..., -1.0765e+00,\n", + " -4.6004e-01, 3.6701e-02],\n", + " [-4.9978e-01, 4.4728e-01, -7.3183e-01, ..., 7.5242e-01,\n", + " 8.4118e-01, 8.3414e-02]]),\n", + " tensor([ 1., -1., 1., 1., -1., -1., 1., -1., -1., 1., 1., 1., 1., -1.,\n", + " -1., -1., -1., 1., 1., 1., 1., -1., -1., -1., 1., 1., 1., 1.,\n", + " -1., 1., 1., -1., 1., -1., -1., 1., -1., -1., 1., 1., -1., 1.,\n", + " -1., -1., -1., -1., -1., 1., 1., -1., 1., -1., -1., -1., -1., -1.,\n", + " -1., -1., 1., -1., 1., 1., -1., 1., -1., -1., 1., -1., -1., 1.,\n", + " 1., 1., -1., -1., -1., 1., -1., -1., 1., -1., 1., -1., 1., -1.,\n", + " -1., 1., 1., 1., -1., 1., 1., 1., 1., 1., 1., -1., 1., -1.,\n", + " -1., 1., 1., 1., -1., 1., 1., -1., -1., 1., 1., 1., 1., 1.,\n", + " -1., -1., -1., -1., 1., 1., 1., -1., 1., -1., 1., -1., 1., 1.,\n", + " -1., -1., -1.]))}" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "keys" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.0\n" + ] + } + ], + "source": [ + "acc = verify_feature_based_signature(model, keys)\n", + "print(acc)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The accuracy is 100%! Congratulations. Now you can use FATE to build your own IPR protected models." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/tutorial/ChatGLM-6B_ds.ipynb b/doc/tutorial/parameter_efficient_llm/ChatGLM-6B_ds.ipynb similarity index 100% rename from doc/tutorial/ChatGLM-6B_ds.ipynb rename to doc/tutorial/parameter_efficient_llm/ChatGLM-6B_ds.ipynb diff --git a/doc/tutorial/GPT2-example.ipynb b/doc/tutorial/parameter_efficient_llm/GPT2-example.ipynb similarity index 100% rename from doc/tutorial/GPT2-example.ipynb rename to doc/tutorial/parameter_efficient_llm/GPT2-example.ipynb From 08f54d6dce8e954d05b4ae5d00389f9d0c2bc98c Mon Sep 17 00:00:00 2001 From: cwj Date: Mon, 4 Sep 2023 21:11:27 +0800 Subject: [PATCH 08/23] Add dataset tools for 3 QA task Signed-off-by: cwj --- python/fate_llm/dataset/qa_dataset.py | 173 ++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 python/fate_llm/dataset/qa_dataset.py diff --git a/python/fate_llm/dataset/qa_dataset.py b/python/fate_llm/dataset/qa_dataset.py new file mode 100644 index 0000000..0dfd2c3 --- /dev/null +++ b/python/fate_llm/dataset/qa_dataset.py @@ -0,0 +1,173 @@ +from datasets import load_from_disk, load_dataset +from transformers import AutoTokenizer +import torch as t +import os +from federatedml.nn.dataset.base import Dataset + + +class PIQA: + def __init__(self): + self._template = "Question: {}\nAnswer:" + + def get_context(self, examples): + ctx = examples['goal'] + return [self._template.format(c) for c in ctx] + + def get_target(self, examples): + if -1 in examples["label"]: # test set + return [""] * len(examples["label"]) + else: + gt_tuples = [("sol{}".format(label + 1), idx) + for idx, label in enumerate(examples['label'])] + return [examples[k][i] for k, i in gt_tuples] + +class SciQ: + def __init__(self): + self._template = "{}\nQuestion: {}\nAnswer:" + + def get_context(self, examples): + sources = examples['support'] + queries = examples['question'] + return [self._template.format(s, q) for s, q in zip(sources, queries)] + + def get_target(self, examples): + return examples['correct_answer'] + + +class OpenBookQA: + def get_context(self, examples): + return examples['question_stem'] + + def get_target(self, examples): + choices = examples['choices'] + answers = examples['answerKey'] + targets = [] + for choice, answer in zip(choices, answers): + answer = ord(answer.strip()) - ord('A') + targets.append(choice['text'][answer]) + return targets + + +task_dict = { + "piqa": PIQA(), + "sciq": SciQ(), + "openbookqa": OpenBookQA() +} + +def tokenize_qa_dataset(dataset_name, tokenizer, save_path, seq_max_len=1000): + + max_len = seq_max_len + assert dataset_name in ['piqa', 'sciq', 'openbookqa'], "dataset name must be one of ['piqa', 'sciq', 'openbookqa']" + raw_datasets = load_from_disk(dataset_name) + task = task_dict[dataset_name] + + column_names = raw_datasets["train"].column_names + + def tokenize_function(examples): + context = task.get_context(examples) + target = task.get_target(examples) + + context = tokenizer(context) + target = tokenizer(target) + + # if context is ending with special token, remove it + if len(context['input_ids'][0]) > 0 and context['input_ids'][0][-1] in tokenizer.all_special_ids: + context['input_ids'] = [i[:-1] for i in context['input_ids']] + context['attention_mask'] = [a[:-1] + for a in context['attention_mask']] + + # if target is starting with special token, remove it + if len(target['input_ids'][0]) > 0 and target['input_ids'][0][0] in tokenizer.all_special_ids: + target['input_ids'] = [i[1:] for i in target['input_ids']] + target['attention_mask'] = [a[1:] + for a in target['attention_mask']] + + out = {} + out['input_ids'] = [i1 + i2 for i1, + i2 in zip(context['input_ids'], target['input_ids'])] + out['attention_mask'] = [a1 + a2 for a1, + a2 in zip(context['attention_mask'], target['attention_mask'])] + + # set -100 for context tokens + out["labels"] = [ + [-100] * len(i1) + i2 for i1, i2 in zip(context['input_ids'], target['input_ids'])] + + return out + + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=32, + remove_columns=column_names, + load_from_cache_file=True, + desc="Running tokenizer on dataset", + ) + + # pad all instances in lm_datasets to the max length of the dataset + max_length = -1 + for v in tokenized_datasets.values(): + for x in v: + max_length = max(max_length, len(x['input_ids'])) + + # pad to the multiple of 8 + max_length = (max_length // 8 + 1) * 8 + + block_size = max_len + max_length = min(max_length, block_size) + + def pad_function(examples): + examples["input_ids"] = [i + [tokenizer.pad_token_id] * + (max_length - len(i)) for i in examples["input_ids"]] + examples["attention_mask"] = [[1] * len(i) + [0] * + (max_length - len(i)) for i in examples["attention_mask"]] + examples["labels"] = [i + [-100] * + (max_length - len(i)) for i in examples["labels"]] + # truncate to max_length + examples["input_ids"] = [i[:max_length] for i in examples["input_ids"]] + examples["attention_mask"] = [a[:max_length] + for a in examples["attention_mask"]] + examples["labels"] = [l[:max_length] for l in examples["labels"]] + return examples + + + tokenized_datasets = tokenized_datasets.map( + pad_function, + batched=True, + num_proc=32, + load_from_cache_file=True, + desc=f"Padding dataset to max length {max_length}", + ) + assert os.path.exists(save_path), "save_path must be a valid path" + tokenized_datasets.save_to_disk(save_path) + return tokenized_datasets + + +class QaDataset(Dataset): + + def __init__(self, tokenizer_name_or_path, select_num=None, start_idx=None): + self.select_num = select_num + self.start_idx = start_idx + self.ds = None + if 'llama' in tokenizer_name_or_path: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, unk_token="", bos_token="", eos_token="", add_eos_token=True) + self.tokenizer.pad_token = self.tokenizer.eos_token + else: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + if 'gpt2' in tokenizer_name_or_path: + self.tokenizer.pad_token = self.tokenizer.eos_token + + + def load(self, path): + loaded = load_from_disk(path) + self.ds = loaded['train'] + if self.select_num is not None: + if self.start_idx is not None: + self.ds = self.ds.select(range(self.start_idx, min(len(self.ds), self.start_idx + self.select_num))) + else: + self.ds = self.ds.select(range(self.select_num)) + + def __len__(self): + return len(self.ds) + + def __getitem__(self, idx): + return self.ds[idx] From de5a2cdf806fb0b895f9e0a35047a263118a0354 Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 5 Sep 2023 10:33:43 +0800 Subject: [PATCH 09/23] Fix doc Signed-off-by: cwj --- doc/tutorial/offsite_tuning/Offsite_tuning_tutorial_0.ipynb | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 doc/tutorial/offsite_tuning/Offsite_tuning_tutorial_0.ipynb diff --git a/doc/tutorial/offsite_tuning/Offsite_tuning_tutorial_0.ipynb b/doc/tutorial/offsite_tuning/Offsite_tuning_tutorial_0.ipynb new file mode 100644 index 0000000..e69de29 From 36099d1975633e79c7cea35469d9088430ad437e Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 5 Sep 2023 11:26:42 +0800 Subject: [PATCH 10/23] fix doc typo Signed-off-by: cwj --- doc/tutorial/fed_ipr/FedIPR-tutorial.ipynb | 8 ++++---- python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py | 3 --- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/doc/tutorial/fed_ipr/FedIPR-tutorial.ipynb b/doc/tutorial/fed_ipr/FedIPR-tutorial.ipynb index 05a4f5f..8b32781 100644 --- a/doc/tutorial/fed_ipr/FedIPR-tutorial.ipynb +++ b/doc/tutorial/fed_ipr/FedIPR-tutorial.ipynb @@ -12,7 +12,7 @@ "metadata": {}, "source": [ "In this tutorial, you'll learn how to add both backdoor-based and feature-based watermarks to your models in the federated training. \n", - "We'll dive into using backdoor-w datasets for backdoor-based watermarking and exploring signblock—a tool that learns feature-based watermarks during traning. We will show you how to apply these techniques to both computer vision and language models. We'll also offer a hands-on example with a CV task, share how to verify the watermarks you've embedded, and introduce some ready-to-use models provided by the FATE framework. " + "We'll dive into using backdoor-watermark datasets for backdoor-based watermarking and exploring signblock—a tool that learns feature-based watermarks during traning. We will show you how to apply these techniques to both computer vision and language models. We'll also offer a hands-on example with a CV task, share how to verify the watermarks you've embedded, and introduce some ready-to-use models provided by the FATE framework. " ] }, { @@ -181,7 +181,7 @@ "from fate_llm.dataset.watermark import WaterMarkImageDataset\n", "\n", "ds = WaterMarkImageDataset()\n", - "ds.load('../../examples/data/cifar_10_ipr/fedipr_cifar10_guest/')" + "ds.load('../../../examples/data/cifar_10_ipr/fedipr_cifar10_guest/')" ] }, { @@ -567,7 +567,7 @@ "\n", "import os\n", "# bind data path to name & namespace\n", - "fate_project_path = os.path.abspath('../../')\n", + "fate_project_path = os.path.abspath('../../../')\n", "host = 9997\n", "guest = 9997\n", "arbiter = 9997\n", @@ -647,7 +647,7 @@ "source": [ "job_id = '202309041103336933850' # your job id\n", "import os\n", - "fate_project_path = os.path.abspath('../../')\n", + "fate_project_path = os.path.abspath('../../../')\n", "local_dir = fate_project_path + '/fateflow/jobs/{}/guest/9997/nn_0/'.format(job_id)\n", "state_dict = t.load(local_dir + 'model.pkl')" ] diff --git a/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py b/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py index ebc7ded..dfb50cb 100644 --- a/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py +++ b/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py @@ -109,9 +109,6 @@ def get_base_model(self): def get_model_transformer_blocks(self, model: GPT2LMHeadModel): return model.transformer.h - def forward(self, x): - return self.model(**x) - def get_additional_param_state_dict(self): # get parameter of additional parameter model = self.model From a81cbb036c162952f4f02042e80168a878bbe774 Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 5 Sep 2023 14:40:58 +0800 Subject: [PATCH 11/23] Update & Remove redundent codes Signed-off-by: cwj --- python/fate_llm/dataset/qa_dataset.py | 4 ++++ python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py | 4 ---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/fate_llm/dataset/qa_dataset.py b/python/fate_llm/dataset/qa_dataset.py index 0dfd2c3..79e0e94 100644 --- a/python/fate_llm/dataset/qa_dataset.py +++ b/python/fate_llm/dataset/qa_dataset.py @@ -5,6 +5,10 @@ from federatedml.nn.dataset.base import Dataset +""" +These Data pre-processing codes are from https://github.com/mit-han-lab/offsite-tuning +""" + class PIQA: def __init__(self): self._template = "Question: {}\nAnswer:" diff --git a/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py b/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py index dfb50cb..1c8c1a7 100644 --- a/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py +++ b/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py @@ -72,10 +72,6 @@ def load_additional_param_state_dict(self, submodel_weights: dict): self.load_numpy_state_dict(param_dict, new_submodel_weight) - def forward(self, x): - return self.model(**x) - - class GPT2LMHeadSubModel(OffsiteTuningSubModel): def __init__( From 9898a37f693ed0e737fd666e57e7e4bcadbcf7dc Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 5 Sep 2023 14:44:37 +0800 Subject: [PATCH 12/23] Fix codes Signed-off-by: cwj --- python/fate_llm/dataset/qa_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/fate_llm/dataset/qa_dataset.py b/python/fate_llm/dataset/qa_dataset.py index 79e0e94..32377de 100644 --- a/python/fate_llm/dataset/qa_dataset.py +++ b/python/fate_llm/dataset/qa_dataset.py @@ -58,11 +58,12 @@ def get_target(self, examples): "openbookqa": OpenBookQA() } + def tokenize_qa_dataset(dataset_name, tokenizer, save_path, seq_max_len=1000): max_len = seq_max_len assert dataset_name in ['piqa', 'sciq', 'openbookqa'], "dataset name must be one of ['piqa', 'sciq', 'openbookqa']" - raw_datasets = load_from_disk(dataset_name) + raw_datasets = load_dataset(dataset_name) task = task_dict[dataset_name] column_names = raw_datasets["train"].column_names From faa68e2bbc7a08e93577895ec60382f290cd0bf9 Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 5 Sep 2023 14:58:55 +0800 Subject: [PATCH 13/23] remove constraint Signed-off-by: cwj --- python/fate_llm/dataset/qa_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/fate_llm/dataset/qa_dataset.py b/python/fate_llm/dataset/qa_dataset.py index 32377de..0027a72 100644 --- a/python/fate_llm/dataset/qa_dataset.py +++ b/python/fate_llm/dataset/qa_dataset.py @@ -142,7 +142,7 @@ def pad_function(examples): load_from_cache_file=True, desc=f"Padding dataset to max length {max_length}", ) - assert os.path.exists(save_path), "save_path must be a valid path" + tokenized_datasets.save_to_disk(save_path) return tokenized_datasets From e63cf2a927ca13826d4c9f687c377f1c8b9db594 Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 5 Sep 2023 16:36:42 +0800 Subject: [PATCH 14/23] Add file head Signed-off-by: cwj --- python/fate_llm/dataset/qa_dataset.py | 17 ++++++++++++++++- python/fate_llm/model_zoo/ipr/alexnet.py | 15 +++++++++++++++ python/fate_llm/model_zoo/ipr/distilbert.py | 15 +++++++++++++++ python/fate_llm/model_zoo/ipr/gpt2.py | 15 +++++++++++++++ python/fate_llm/model_zoo/ipr/resnet.py | 15 +++++++++++++++ python/fate_llm/model_zoo/ipr/sign_block.py | 15 +++++++++++++++ .../model_zoo/offsite_tuning/bloom_ot.py | 15 +++++++++++++++ .../model_zoo/offsite_tuning/gpt2_ot.py | 15 +++++++++++++++ .../model_zoo/offsite_tuning/llama_ot.py | 15 +++++++++++++++ .../offsite_tuning/offsite_tuning_model.py | 15 +++++++++++++++ python/fate_llm/trainer/fedipr_trainer.py | 15 +++++++++++++++ .../fate_llm/trainer/offsite_tuning_trainer.py | 15 +++++++++++++++ 12 files changed, 181 insertions(+), 1 deletion(-) diff --git a/python/fate_llm/dataset/qa_dataset.py b/python/fate_llm/dataset/qa_dataset.py index 0027a72..017c241 100644 --- a/python/fate_llm/dataset/qa_dataset.py +++ b/python/fate_llm/dataset/qa_dataset.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from datasets import load_from_disk, load_dataset from transformers import AutoTokenizer import torch as t @@ -6,7 +21,7 @@ """ -These Data pre-processing codes are from https://github.com/mit-han-lab/offsite-tuning +These Data pre-processing templates are from https://github.com/mit-han-lab/offsite-tuning """ class PIQA: diff --git a/python/fate_llm/model_zoo/ipr/alexnet.py b/python/fate_llm/model_zoo/ipr/alexnet.py index bcaee84..28c6dc0 100644 --- a/python/fate_llm/model_zoo/ipr/alexnet.py +++ b/python/fate_llm/model_zoo/ipr/alexnet.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import torch.nn as nn from fate_llm.model_zoo.ipr.sign_block import SignatureConv, ConvBlock diff --git a/python/fate_llm/model_zoo/ipr/distilbert.py b/python/fate_llm/model_zoo/ipr/distilbert.py index 4455417..063fd13 100644 --- a/python/fate_llm/model_zoo/ipr/distilbert.py +++ b/python/fate_llm/model_zoo/ipr/distilbert.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from torch.nn import Module from transformers import DistilBertForSequenceClassification, DistilBertForTokenClassification from fate_llm.model_zoo.ipr.sign_block import recursive_replace_layernorm diff --git a/python/fate_llm/model_zoo/ipr/gpt2.py b/python/fate_llm/model_zoo/ipr/gpt2.py index 39d72e2..26c9b4b 100644 --- a/python/fate_llm/model_zoo/ipr/gpt2.py +++ b/python/fate_llm/model_zoo/ipr/gpt2.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from torch.nn import Module from transformers import GPT2ForTokenClassification, GPT2ForSequenceClassification from fate_llm.model_zoo.ipr.sign_block import recursive_replace_layernorm diff --git a/python/fate_llm/model_zoo/ipr/resnet.py b/python/fate_llm/model_zoo/ipr/resnet.py index 4e3b0e4..d03e870 100644 --- a/python/fate_llm/model_zoo/ipr/resnet.py +++ b/python/fate_llm/model_zoo/ipr/resnet.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import torch.nn as nn import torch.nn.functional as F from fate_llm.model_zoo.ipr.sign_block import ConvBlock, SignatureConv diff --git a/python/fate_llm/model_zoo/ipr/sign_block.py b/python/fate_llm/model_zoo/ipr/sign_block.py index 45e701d..5cef62e 100644 --- a/python/fate_llm/model_zoo/ipr/sign_block.py +++ b/python/fate_llm/model_zoo/ipr/sign_block.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import torch import torch.nn as nn import torch.nn.init as init diff --git a/python/fate_llm/model_zoo/offsite_tuning/bloom_ot.py b/python/fate_llm/model_zoo/offsite_tuning/bloom_ot.py index b1586e6..2fd97cf 100644 --- a/python/fate_llm/model_zoo/offsite_tuning/bloom_ot.py +++ b/python/fate_llm/model_zoo/offsite_tuning/bloom_ot.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters, split_numpy_array, recover_numpy_array from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomModel, BloomConfig from torch import nn diff --git a/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py b/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py index 1c8c1a7..c122d43 100644 --- a/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py +++ b/python/fate_llm/model_zoo/offsite_tuning/gpt2_ot.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters, split_numpy_array, recover_numpy_array from transformers import GPT2LMHeadModel, GPT2Config from torch import nn diff --git a/python/fate_llm/model_zoo/offsite_tuning/llama_ot.py b/python/fate_llm/model_zoo/offsite_tuning/llama_ot.py index cb8de1c..7acb02c 100644 --- a/python/fate_llm/model_zoo/offsite_tuning/llama_ot.py +++ b/python/fate_llm/model_zoo/offsite_tuning/llama_ot.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters, split_numpy_array, recover_numpy_array from transformers import LlamaConfig, LlamaForCausalLM diff --git a/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py b/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py index a81f03e..1334148 100644 --- a/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py +++ b/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import torch as t from torch import nn from federatedml.util import LOGGER diff --git a/python/fate_llm/trainer/fedipr_trainer.py b/python/fate_llm/trainer/fedipr_trainer.py index bc2b9e5..bc0f919 100644 --- a/python/fate_llm/trainer/fedipr_trainer.py +++ b/python/fate_llm/trainer/fedipr_trainer.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import torch as t import tqdm import numpy as np diff --git a/python/fate_llm/trainer/offsite_tuning_trainer.py b/python/fate_llm/trainer/offsite_tuning_trainer.py index 718cf20..19b2cd0 100644 --- a/python/fate_llm/trainer/offsite_tuning_trainer.py +++ b/python/fate_llm/trainer/offsite_tuning_trainer.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import torch as t from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient as SecureAggClient From b965ced870419ac5b357467509f2c6f0fcafff22 Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 5 Sep 2023 17:37:14 +0800 Subject: [PATCH 15/23] Add offsite-tuning tutorial Signed-off-by: cwj --- .../Offsite_tuning_tutorial.ipynb | 802 ++++++++++++++++++ .../Offsite_tuning_tutorial_0.ipynb | 0 2 files changed, 802 insertions(+) create mode 100644 doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb delete mode 100644 doc/tutorial/offsite_tuning/Offsite_tuning_tutorial_0.ipynb diff --git a/doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb b/doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb new file mode 100644 index 0000000..23bbb04 --- /dev/null +++ b/doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb @@ -0,0 +1,802 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c2345e19-83eb-4196-9606-74658c8fbdc5", + "metadata": {}, + "source": [ + "# Offsite-tuning Tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "9f1d728c-09e1-418e-8d80-53dd0ec467b1", + "metadata": {}, + "source": [ + "In this tutorial, we'll focus on how to leverage Offsite-Tuning framework in FATE to fine-tune your LLM. You'll learn how to:\n", + "\n", + "1. Define models, including main models(which are at server side and will offer adapters and emulators) and submodel(which are at client side and will load adapters and emulators for local fine-tuning) compatible with Offsite-Tuning framework.\n", + "2. Get hands-on experience with the Offsite-Tuning trainer.\n", + "3. Define configurations for advanced setup(Using Deepspeed, offsite-tuning + federation) through FATE-pipeline." + ] + }, + { + "cell_type": "markdown", + "id": "31432345-5cce-4efa-9a9b-844f997f14ad", + "metadata": {}, + "source": [ + "## Introduction of Offsite-tuning\n", + "\n", + "Offsite-Tuning is a novel approach designed for the efficient and privacy-preserving adaptation of large foundational models for specific downstream tasks. The framework allows data owners to fine-tune models locally without uploading sensitive data to the LLM owner's servers. Specifically, the LLM owner sends a lightweight \"Adapter\" and a lossy compressed \"Emulator\" to the data owner. Using these smaller components, the data owner can then fine-tune the model solely on their private data. The Adapter, once fine-tuned, is returned to the model owner and integrated back into the large model to enhance its performance on the specific dataset.\n", + "\r\n", + "Offsite-Tuning addresses the challenge of unequal distribution of computational power and data. It allows thLLMel owner to enhance the model's capabilities without direct access to private data, while also enabling data owners who may not have the resources to train a full-scale model to fine-tune a portion of it using less computational power. This mutually beneficial arrangement accommodates both parties involve.\r\n", + "\r\n", + "Beyond the standard two-party setup involving the model owner and the data ownin FATE-LLM, er, Offsite-Tunframework ing is also extendable to scenarios with multiple data owners. FATE supports multi-party Offsite-Tuning, allowing multiple data owners to fine-tune and aggregate their Adapters locally, further enhancing the flexibility and applicability of this framewrFor more details of Offsite-tuning, please refer to the [original paper](https://arxiv.org/pdf/2302.04870.pdf).\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n" + ] + }, + { + "cell_type": "markdown", + "id": "2e7ac467-e5df-4bf3-8571-0a477ab4612d", + "metadata": {}, + "source": [ + "## Preliminary\n", + "\n", + "We strongly recommend you finish reading our NN tutorial to get familiar with Model and Dataset customizations: [NN Tutorials](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/README.md)\n", + "You can add python path so that you can run codes in the notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f33516e8-0d28-4c97-bc38-ba28d60acf37", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "your_path_to_fate_python = '/data/projects/fate/fate/python'\n", + "sys.path.append(your_path_to_fate_python)" + ] + }, + { + "cell_type": "markdown", + "id": "7309281b-5956-4158-9256-d6db230e086d", + "metadata": {}, + "source": [ + "## Define Main Model and Sub Model\n", + "\n", + "Main models are at server side and will provides weights of adapters and emulators to client sides, while Sub Models are at client side and will load adapters and emulators for local fine-tuning. In this chapter we will take a standard GPT2 as the example and show you how to quickly develop main model class and sub model class for offsite-tuning.\n", + "\n", + "### Base Classes and Interfaces\n", + "\n", + "The base classes for the Main and Sub Models are OffsiteTuningMainModel and OffsiteTuningSubModel, respectively. To build your own models upon these base classes, you need to:\n", + "\n", + "1. Implement three key interfaces: get_base_model, get_model_transformer_blocks, and forward. The get_base_model interface should return the full Main or Sub Model. Meanwhile, the get_model_transformer_blocks function should return a ModuleList of all transformer blocks present in your language model, enabling the extraction of emulators and adapters from these blocks. Finally, you're required to implement the forward process for model inference.\n", + "\n", + "2. Supply the parameters emulator_layer_num, adapter_top_layer_num, and adapter_bottom_layer_num to the parent class. This allows the framework to automatically generate the top and bottom adapters as well as the dropout emulator for you. Specifically, the top adapters are taken from the top of the transformer blocks, while the bottom adapters are taken from the bottom. The emulator uses a dropout emulator consistent with the paper's specifications. Once the adapter layers are removed, the emulator is formed by selecting transformer blocks at fixed intervals and finally stack them to make a dropout emulator.\n", + "\n", + "Our framework will automatically detect the emulator and adapters of a main model, and send them to clients. Clients' models them load the weights of emulators and adapters to get trainable models.\n", + "\n", + "### Example\n", + "\n", + "Let us take a look of our built-in GPT-2 model. It will be easy for you to build main models and sub models based on the framework. Please notice that the GPT2LMHeadSubModel's base model is intialized from a GPTConfig, that is to say, it's weights are random and need to load pretrained weights from server." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8611c115-0321-458f-b190-49dcb127a653", + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel\n", + "from transformers import GPT2LMHeadModel, GPT2Config\n", + "from torch import nn\n", + "import torch as t\n", + "\n", + "\n", + "class GPT2LMHeadMainModel(OffsiteTuningMainModel):\n", + "\n", + " def __init__(\n", + " self,\n", + " model_name_or_path,\n", + " emulator_layer_num: int,\n", + " adapter_top_layer_num: int = 2,\n", + " adapter_bottom_layer_num: int = 2):\n", + "\n", + " self.model_name_or_path = model_name_or_path\n", + " super().__init__(\n", + " emulator_layer_num,\n", + " adapter_top_layer_num,\n", + " adapter_bottom_layer_num)\n", + "\n", + " def get_base_model(self):\n", + " return GPT2LMHeadModel.from_pretrained(self.model_name_or_path)\n", + "\n", + " def get_model_transformer_blocks(self, model: GPT2LMHeadModel):\n", + " return model.transformer.h\n", + "\n", + " def forward(self, x):\n", + " return self.model(**x)\n", + "\n", + "class GPT2LMHeadSubModel(OffsiteTuningSubModel):\n", + "\n", + " def __init__(\n", + " self,\n", + " model_name_or_path,\n", + " emulator_layer_num: int,\n", + " adapter_top_layer_num: int = 2,\n", + " adapter_bottom_layer_num: int = 2,\n", + " fp16_mix_precision=False,\n", + " partial_weight_decay=None):\n", + "\n", + " self.model_name_or_path = model_name_or_path\n", + " self.emulator_layer_num = emulator_layer_num\n", + " self.adapter_top_layer_num = adapter_top_layer_num\n", + " self.adapter_bottom_layer_num = adapter_bottom_layer_num\n", + " super().__init__(\n", + " emulator_layer_num,\n", + " adapter_top_layer_num,\n", + " adapter_bottom_layer_num,\n", + " fp16_mix_precision)\n", + " self.partial_weight_decay = partial_weight_decay\n", + "\n", + " def get_base_model(self):\n", + " total_layer_num = self.emulator_layer_num + \\\n", + " self.adapter_top_layer_num + self.adapter_bottom_layer_num\n", + " config = GPT2Config.from_pretrained(self.model_name_or_path)\n", + " config.num_hidden_layers = total_layer_num\n", + " # initialize a model without pretrained weights\n", + " return GPT2LMHeadModel(config)\n", + "\n", + " def get_model_transformer_blocks(self, model: GPT2LMHeadModel):\n", + " return model.transformer.h\n", + " \n", + " def forward(self, x):\n", + " return self.model(**x)\n" + ] + }, + { + "cell_type": "markdown", + "id": "abd1f63f-afa7-4f09-a67e-63812ddcd801", + "metadata": {}, + "source": [ + "We can define a server side model and a client side model that can work together in the offsite-tuning:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04870e76-11cc-4d79-a09e-b6fd16ed2f23", + "metadata": {}, + "outputs": [], + "source": [ + "model_main = GPT2LMHeadMainModel('gpt2', 4, 2, 2)\n", + "model_sub = GPT2LMHeadSubModel('gpt2', 4, 2, 2)" + ] + }, + { + "cell_type": "markdown", + "id": "19d34937-b4ae-436e-b4ea-1620fb80bed4", + "metadata": {}, + "source": [ + "### Share additional parameters with clients\n", + "\n", + "Additionally, beyond the weights of emulators and adapters, you may also want to share other model parameters, such as embedding weights, with your client partners. To achieve this, you'll need to implement two more interfaces: get_additional_param_state_dict and load_additional_param_state_dict for both the Main and Sub Models.\n", + "\n", + "### Special Attention for Large Objects\n", + "\n", + "Please note that special attention is required when you need to share large objects, any object potentially exceeding 2GB, such as embedding weights. You should slice these large objects to manage them more efficiently. Below is a code snippet demonstrating this practice, taken directly from FATE's native GPT-2 implementation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "189fce0e-8e4d-4368-8e14-907b30ce0a49", + "metadata": {}, + "outputs": [], + "source": [ + "def get_additional_param_state_dict(self):\n", + " # get parameter of additional parameter\n", + " model = self.model\n", + " param_dict = {\n", + " 'wte': model.transformer.wte,\n", + " 'wpe': model.transformer.wpe,\n", + " 'last_ln_f': model.transformer.ln_f\n", + " }\n", + "\n", + " addition_weights = self.get_numpy_state_dict(param_dict)\n", + "\n", + " wte = addition_weights.pop('wte')\n", + " wte_dict = split_numpy_array(wte, 10, 'wte')\n", + " wpe = addition_weights.pop('wpe')\n", + " wpe_dict = split_numpy_array(wpe, 10, 'wpe')\n", + " addition_weights.update(wte_dict)\n", + " addition_weights.update(wpe_dict)\n", + " return addition_weights\n", + "\n", + "def load_additional_param_state_dict(self, submodel_weights: dict):\n", + " # load additional weights:\n", + " model = self.model\n", + " param_dict = {\n", + " 'wte': model.transformer.wte,\n", + " 'wpe': model.transformer.wpe,\n", + " 'last_ln_f': model.transformer.ln_f\n", + " }\n", + "\n", + " new_submodel_weight = {}\n", + " new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']\n", + " wte_dict, wpe_dict = {}, {}\n", + " for k, v in submodel_weights.items():\n", + " if 'wte' in k:\n", + " wte_dict[k] = v\n", + " if 'wpe' in k:\n", + " wpe_dict[k] = v\n", + " wte = recover_numpy_array(wte_dict, 'wte')\n", + " wpe = recover_numpy_array(wpe_dict, 'wpe')\n", + " new_submodel_weight['wte'] = wte\n", + " new_submodel_weight['wpe'] = wpe\n", + "\n", + " self.load_numpy_state_dict(param_dict, new_submodel_weight)" + ] + }, + { + "cell_type": "markdown", + "id": "59d9aa6a-80e9-4130-8af1-c7d2bd0fbba3", + "metadata": {}, + "source": [ + "From these codes we can see that we use 'split_numpy_array, recover_numpy_array' to cut embedding weights into pieces and recover them." + ] + }, + { + "cell_type": "markdown", + "id": "dda6f5e3-d05a-4cdf-afd4-affbc162fce4", + "metadata": {}, + "source": [ + "## Submit a Offsite-tuning Task - A QA Task Sample with GPT2\n", + "\n", + "Now we are going to show you how to run a 2 party(server & client) offsite-tuning task using the GPT-2 model defined above. Before we submit the task we need to prepare the QA dataset.\n", + "\n", + "### Prepare QA Dataset - Sciq\n", + "\n", + "In this example, we use sciq dataset. You can use tools provided in our qa_dataset.py to tokenize the sciq dataset and save the tokenized result. " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "84f6947e-f0a3-4a42-9549-a9776a15b66d", + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'fate_llm'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfate_llm\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdataset\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mqa_dataset\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m tokenize_qa_dataset\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtransformers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m AutoTokenizer\n\u001b[1;32m 3\u001b[0m tokenizer_name_or_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/data/projects/fate/cwj/gpt2\u001b[39m\u001b[38;5;124m'\u001b[39m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'fate_llm'" + ] + } + ], + "source": [ + "from fate_llm.dataset.qa_dataset import tokenize_qa_dataset\n", + "from transformers import AutoTokenizer\n", + "tokenizer_name_or_path = '/data/projects/fate/cwj/gpt2'\n", + "tokenizer = AutoTokenizer.from_pretrained(gpt2_path)\n", + "\n", + "if 'llama' in tokenizer_name_or_path:\n", + " tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, unk_token=\"\", bos_token=\"\", eos_token=\"\", add_eos_token=True) \n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "else:\n", + " tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)\n", + "if 'gpt2' in tokenizer_name_or_path:\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "import os\n", + "# bind data path to name & namespace\n", + "fate_project_path = os.path.abspath('../../../')\n", + "rs = tokenize_qa_dataset('sciq', tokenizer, fate_project_path + '/sciq/', seq_max_len=600) # we save the cache dataset to the fate root folder" + ] + }, + { + "cell_type": "markdown", + "id": "adabe89a-37be-4c64-bd83-4f8c8b80096f", + "metadata": {}, + "source": [ + "We can use our built-in QA dataset to load tokenized dataset, to see if everything is working correctly." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "6500c2ba-bc39-4db4-b2ea-947fb09c334e", + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.dataset.qa_dataset import QaDataset\n", + "\n", + "ds = QaDataset(tokenizer_name_or_path=tokenizer_name_or_path)\n", + "ds.load(fate_project_path + '/sciq/')" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "d6f62b60-eed0-4bd0-874e-ae3feeebb120", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11679\n", + "600\n" + ] + } + ], + "source": [ + "print(len(ds)) # train set length\n", + "print(ds[0]['input_ids'].__len__()) # first sample length" + ] + }, + { + "cell_type": "markdown", + "id": "0609c63d-35a4-43bc-bd4b-f1c61adea587", + "metadata": {}, + "source": [ + "## Submit a Task\n", + "\n", + "Now the model and the dataset is prepared! We can submit a training task. \n", + "After we submit the task below, the following process will occur: The server and client each initialize their respective models. The server extracts shared parameters and sends them to the client. The client then loads these parameters and conducts training on a miniaturized GPT-2 model composed of an emulator and adaptesr onSciqP We speicify the OffsiteTuningTrainer via TrainerParam. If you are not familiar with trainer configuration, please refer to [FATE-NN Tutorial](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/README.md).\n", + " Upon completion of the training, the client sends the adapter parameters back to the server. Since we are directly using Hugging Face's LMHeadGPT2, there's no need to supply a loss function. Simply inputting the preprocessed data and labels into the model will calculate the correct loss and proceed with gradient descent\n", + "\n", + "One thing to pay special attention to is that Offsite-Tuning differs from FedAvg within FATE. In Offsite-Tuning, the server (the arbiter role) needs to initialize the model. Therefore, please refer to the example below and set the 'nn_component' parameters separately for the client and the server. Also, don't forget to add the 'server_init=True' parameter to the server; otherwise, the arbiter side will not initialize the model.\n", + "\n", + "To make this a quick demo, we only select 100 samples from the origin qa datset, see 'select_num=100' in the DatasetParam." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c9113d10-c3e7-4875-9502-ce46aa0b86b1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch as t\n", + "from torch import nn\n", + "from pipeline import fate_torch_hook\n", + "from pipeline.component import HomoNN\n", + "from pipeline.backend.pipeline import PipeLine\n", + "from pipeline.component import Reader, Evaluation, DataTransform\n", + "from pipeline.interface import Data, Model\n", + "\n", + "t = fate_torch_hook(t)\n", + "\n", + "import os\n", + "# bind data path to name & namespace\n", + "fate_project_path = os.path.abspath('../../../')\n", + "guest = 9997\n", + "arbiter = 9997\n", + "pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, arbiter=arbiter)\n", + "\n", + "# bind data path with name & namespace\n", + "data_0 = {\"name\": \"sciq\", \"namespace\": \"experiment\"}\n", + "data_path_0 = fate_project_path + '/sciq/'\n", + "pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)\n", + "\n", + "reader_0 = Reader(name=\"reader_0\")\n", + "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)\n", + "\n", + "gpt2_type = '/data/projects/fate/cwj/gpt2/'\n", + "\n", + "from pipeline.component.nn import DatasetParam\n", + "dataset_param = DatasetParam(dataset_name='qa_dataset', tokenizer_name_or_path=gpt2_type, select_num=100)\n", + "\n", + "from pipeline.component.homo_nn import TrainerParam # Interface\n", + "sub_model_client = t.nn.CustModel(module_name='offsite_tuning.gpt2_ot', class_name='GPT2LMHeadSubModel', model_name_or_path=gpt2_type \\\n", + " ,emulator_layer_num=4, adapter_top_layer_num=2, adapter_bottom_layer_num=2)\n", + "main_model_server = t.nn.CustModel(module_name='offsite_tuning.gpt2_ot', class_name='GPT2LMHeadMainModel', model_name_or_path=gpt2_type \\\n", + " ,emulator_layer_num=4, adapter_top_layer_num=2, adapter_bottom_layer_num=2)\n", + "\n", + "nn_component = HomoNN(name='nn_0')\n", + "\n", + "nn_component.get_party_instance(role='guest', party_id=guest).component_param(model=sub_model_client, dataset=dataset_param, # dataset\n", + " trainer=TrainerParam(trainer_name='offsite_tuning_trainer', epochs=3, batch_size=4, collate_fn='DataCollatorForTokenClassification', task_type='causal_ml', \\\n", + " save_to_local_dir=True, cuda=0),\n", + " optimizer=t.optim.Adam(lr=5e-5)\n", + " )\n", + "nn_component.get_party_instance(role='arbiter', party_id=arbiter).component_param(model=main_model_server, \n", + " trainer=TrainerParam(trainer_name='offsite_tuning_trainer', collate_fn='DataCollatorForTokenClassification', save_to_local_dir=True),\n", + " # Attention here\n", + " server_init=True # This parameter must be set True !!!!!!!!!!!\n", + " )\n", + "pipeline.add_component(reader_0)\n", + "pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))\n", + "pipeline.compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "74497742-4030-4a7a-a13e-2c020da47cd1", + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.fit()" + ] + }, + { + "cell_type": "markdown", + "id": "b33b2e2b-3b53-4881-8db6-a67e1293e88b", + "metadata": {}, + "source": [ + "## Add Deepspeed Setting\n", + "\n", + "By simply adding a ds_config, we can run our task with a deepspeed backend:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6e8f063b-263c-4ba5-b2ba-98a86ce38b94", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch as t\n", + "from torch import nn\n", + "from pipeline import fate_torch_hook\n", + "from pipeline.component import HomoNN\n", + "from pipeline.backend.pipeline import PipeLine\n", + "from pipeline.component import Reader, Evaluation, DataTransform\n", + "from pipeline.interface import Data, Model\n", + "\n", + "t = fate_torch_hook(t)\n", + "\n", + "import os\n", + "# bind data path to name & namespace\n", + "fate_project_path = os.path.abspath('../../../')\n", + "guest = 9997\n", + "arbiter = 9997\n", + "pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, arbiter=arbiter)\n", + "\n", + "# bind data path with name & namespace\n", + "data_0 = {\"name\": \"sciq\", \"namespace\": \"experiment\"}\n", + "data_path_0 = fate_project_path + '/sciq/'\n", + "pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)\n", + "\n", + "reader_0 = Reader(name=\"reader_0\")\n", + "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)\n", + "\n", + "# deepspeed config\n", + "ds_config = {\n", + " \"train_micro_batch_size_per_gpu\": 2,\n", + " \"gradient_accumulation_steps\": 2,\n", + " \"optimizer\": {\n", + " \"type\": \"AdamW\",\n", + " \"params\": {\n", + " \"lr\": 5e-5\n", + " }\n", + " }\n", + " ,\n", + " \"fp16\": {\n", + " \"enabled\": False\n", + " }\n", + " ,\n", + " \"zero_optimization\": {\n", + " \"stage\": 1,\n", + " \"offload_optimizer\": {\n", + " \"device\": \"cpu\"\n", + " },\n", + " \"contiguous_gradients\": True,\n", + " \"overlap_comm\": True\n", + " }\n", + "}\n", + "\n", + "gpt2_type = '/data/projects/fate/cwj/gpt2/'\n", + "\n", + "from pipeline.component.nn import DatasetParam\n", + "dataset_param = DatasetParam(dataset_name='qa_dataset', tokenizer_name_or_path=gpt2_type, select_num=100)\n", + "\n", + "from pipeline.component.homo_nn import TrainerParam # Interface\n", + "sub_model_client = t.nn.CustModel(module_name='offsite_tuning.gpt2_ot', class_name='GPT2LMHeadSubModel', model_name_or_path=gpt2_type \\\n", + " ,emulator_layer_num=4, adapter_top_layer_num=2, adapter_bottom_layer_num=2)\n", + "main_model_server = t.nn.CustModel(module_name='offsite_tuning.gpt2_ot', class_name='GPT2LMHeadMainModel', model_name_or_path=gpt2_type \\\n", + " ,emulator_layer_num=4, adapter_top_layer_num=2, adapter_bottom_layer_num=2)\n", + "\n", + "nn_component = HomoNN(name='nn_0')\n", + "\n", + "nn_component.get_party_instance(role='guest', party_id=guest).component_param(model=sub_model_client, dataset=dataset_param, # dataset\n", + " trainer=TrainerParam(trainer_name='offsite_tuning_trainer', epochs=3, batch_size=4, collate_fn='DataCollatorForTokenClassification', task_type='causal_ml', \\\n", + " save_to_local_dir=True),\n", + " optimizer=t.optim.Adam(lr=5e-5)\n", + " )\n", + "nn_component.get_party_instance(role='arbiter', party_id=arbiter).component_param(model=main_model_server, \n", + " trainer=TrainerParam(trainer_name='offsite_tuning_trainer', collate_fn='DataCollatorForTokenClassification', save_to_local_dir=True),\n", + " # Attention here\n", + " server_init=True # This parameter must be set True !!!!!!!!!!!\n", + " )\n", + "pipeline.add_component(reader_0)\n", + "pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))\n", + "pipeline.compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "23320cb9-d06a-44ac-8966-398b0f7bbaae", + "metadata": {}, + "outputs": [], + "source": [ + "from pipeline.runtime.entity import JobParameters\n", + "pipeline.fit(JobParameters(task_conf={\n", + " \"nn_0\": {\n", + " \"launcher\": \"deepspeed\",\n", + " \"world_size\": 4\n", + " }\n", + "}))" + ] + }, + { + "cell_type": "markdown", + "id": "97249681-c3a3-43bd-8167-7ae3f4e1616b", + "metadata": {}, + "source": [ + "## Offsite-tuning + Multi Client Federation\n", + "\n", + "\n", + "The Offsite-Tuning + FedAVG federation is configured based on the standard Offsite-Tuning. The setup is a bit more complex, but we will walk you through it step by step. The pipeline code below contains detailed comments. When reading, please pay attention to the following points:\n", + "\n", + "1. In a multi-party scenario, please fill in different party_ids based on your deployment.\n", + "2. The operation to bind the data path with the name & namespace needs to be run on the machines of all parties. For convenience, we've placed the code in one location.\n", + "3. When configuring Trainer parameters, make sure to add the 'need_aggregate=True' parameter to the OffsiteTuningTrainer for each client and server. So adapters will be aggregated during training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdbdc60c-a948-4be3-bba6-519d8640b0a9", + "metadata": {}, + "outputs": [], + "source": [ + "import torch as t\n", + "from torch import nn\n", + "from pipeline import fate_torch_hook\n", + "from pipeline.component import HomoNN\n", + "from pipeline.backend.pipeline import PipeLine\n", + "from pipeline.component import Reader, Evaluation, DataTransform\n", + "from pipeline.interface import Data, Model\n", + "\n", + "t = fate_torch_hook(t)\n", + "\n", + "import os\n", + "# bind data path to name & namespace\n", + "fate_project_path = os.path.abspath('../../../')\n", + "guest = 9997\n", + "hosts = [9999, 10000]\n", + "arbiter = 9997\n", + "pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, arbiter=arbiter, host=hosts)\n", + "\n", + "data_9997 = {\"name\": \"sciq-9997-gpt2\", \"namespace\": \"experiment\"}\n", + "data_9999 = {\"name\": \"sciq-9999-gpt2\", \"namespace\": \"experiment\"}\n", + "data_10000 = {\"name\": \"sciq-10000-gpt2\", \"namespace\": \"experiment\"}\n", + "\n", + "# run the binding codes on 9997\n", + "data_path_9997 = fate_project_path + '/sciq/'\n", + "pipeline.bind_table(name=data_9997['name'], namespace=data_9997['namespace'], path=data_path_9997)\n", + "\n", + "# run the binding codes on 9998\n", + "data_path_9999 = fate_project_path + '/sciq/'\n", + "pipeline.bind_table(name=data_9999['name'], namespace=data_9999['namespace'], path=data_path_9999)\n", + "\n", + "# run the binding codes on 10000\n", + "data_path_10000 = fate_project_path + '/sciq/'\n", + "pipeline.bind_table(name=data_10000['name'], namespace=data_10000['namespace'], path=data_path_10000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "253499d2-37a1-4fbe-9427-646d51fd6edd", + "metadata": {}, + "outputs": [], + "source": [ + "# deepspeed config\n", + "ds_config = {\n", + " \"train_micro_batch_size_per_gpu\": 2,\n", + " \"gradient_accumulation_steps\": 2,\n", + " \"optimizer\": {\n", + " \"type\": \"AdamW\",\n", + " \"params\": {\n", + " \"lr\": 5e-5\n", + " }\n", + " }\n", + " ,\n", + " \"fp16\": {\n", + " \"enabled\": False\n", + " }\n", + " ,\n", + " \"zero_optimization\": {\n", + " \"stage\": 1,\n", + " \"offload_optimizer\": {\n", + " \"device\": \"cpu\"\n", + " },\n", + " \"contiguous_gradients\": True,\n", + " \"overlap_comm\": True\n", + " }\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "909dc4fb-8d1e-4831-a6f7-744cf7d826c1", + "metadata": {}, + "outputs": [], + "source": [ + "model_path = 'gpt2'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2283025d-9acf-4ffa-8a25-648aa619528e", + "metadata": {}, + "outputs": [], + "source": [ + "reader_0 = Reader(name=\"reader_0\")\n", + "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_9997)\n", + "reader_0.get_party_instance(role='host', party_id=hosts[0]).component_param(table=data_9999)\n", + "reader_0.get_party_instance(role='host', party_id=hosts[1]).component_param(table=data_10000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ce1cc8a-1003-4379-aa4f-bf3fa28237c8", + "metadata": {}, + "outputs": [], + "source": [ + "from pipeline.component.nn import DatasetParam\n", + "\n", + "# This demo utilizes the same dataset but selects distinct segments to mimic an equal data distribution across different parties. \n", + "# We adopt this strategy for the sake of convenience.\n", + "dataset_param_0 = DatasetParam(dataset_name='qa_ds', tokenizer_name_or_path=model_path, start_idx=0, select_num=3893)\n", + "dataset_param_1 = DatasetParam(dataset_name='qa_ds', tokenizer_name_or_path=model_path, start_idx=3893, select_num=3893)\n", + "dataset_param_2 = DatasetParam(dataset_name='qa_ds', tokenizer_name_or_path=model_path, start_idx=7786, select_num=3893)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50ea1168-417c-41da-b7da-b2625c26af50", + "metadata": {}, + "outputs": [], + "source": [ + "from pipeline.component.homo_nn import TrainerParam # Interface\n", + "\n", + "# define model structure\n", + "sub_model_client = t.nn.CustModel(module_name='offsite_tuning.gpt2_ot', class_name='GPT2LMHeadSubModel', model_name_or_path=model_path \\\n", + " ,emulator_layer_num=4, adapter_top_layer_num=2, adapter_bottom_layer_num=2)\n", + "main_model_server = t.nn.CustModel(module_name='offsite_tuning.gpt2_ot', class_name='GPT2LMHeadMainModel', model_name_or_path=model_path \\\n", + " ,emulator_layer_num=4, adapter_top_layer_num=2, adapter_bottom_layer_num=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dffcace2-0d59-411e-856f-512e7eafd793", + "metadata": {}, + "outputs": [], + "source": [ + "nn_component = HomoNN(name='nn_0')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c854117-3fe1-4a7b-9505-bb131d95f178", + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 8\n", + "# We have 4 party to set\n", + "# Please make sure that need_aggregate is True, and epochs parameter of all parties are the same\n", + "nn_component.get_party_instance(role='guest', party_id=guest).component_param(model=sub_model_client, dataset=dataset_param_0, # dataset\n", + " trainer=TrainerParam(trainer_name='offsite_tuning_trainer', epochs=epochs, batch_size=4, collate_fn='DataCollatorForTokenClassification', task_type='causal_ml', \\\n", + " save_to_local_dir=True, need_aggregate=True), ds_config=ds_config)\n", + "\n", + "nn_component.get_party_instance(role='host', party_id=hosts[0]).component_param(model=sub_model_client, dataset=dataset_param_1, # dataset\n", + " trainer=TrainerParam(trainer_name='offsite_tuning_trainer', epochs=epochs, batch_size=4, collate_fn='DataCollatorForTokenClassification', task_type='causal_ml', \\\n", + " save_to_local_dir=True, need_aggregate=True), ds_config=ds_config)\n", + "\n", + "nn_component.get_party_instance(role='host', party_id=hosts[1]).component_param(model=sub_model_client, dataset=dataset_param_2, # dataset\n", + " trainer=TrainerParam(trainer_name='offsite_tuning_trainer', epochs=epochs, batch_size=4, collate_fn='DataCollatorForTokenClassification', task_type='causal_ml', \\\n", + " save_to_local_dir=True, need_aggregate=True), ds_config=ds_config)\n", + "\n", + "\n", + "nn_component.get_party_instance(role='arbiter', party_id=arbiter).component_param(model=main_model_server,\n", + " trainer=TrainerParam(trainer_name='offsite_tuning_trainer', epochs=epochs, save_to_local_dir=True,\n", + " need_aggregate=True),\n", + " server_init=True\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5d173c1-5d72-4d25-9b78-91e6ef766d8c", + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.add_component(reader_0)\n", + "pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))\n", + "pipeline.compile()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6674178-2c59-43d6-b6ce-888e426f27b3", + "metadata": {}, + "outputs": [], + "source": [ + "from pipeline.runtime.entity import JobParameters\n", + "pipeline.fit(JobParameters(task_conf={\n", + " \"nn_0\": {\n", + " \"launcher\": \"deepspeed\",\n", + " \"world_size\": 4\n", + " }\n", + "}))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/tutorial/offsite_tuning/Offsite_tuning_tutorial_0.ipynb b/doc/tutorial/offsite_tuning/Offsite_tuning_tutorial_0.ipynb deleted file mode 100644 index e69de29..0000000 From 8317f55bc26b3cd7f5f72a0804e4f776c268fd43 Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 5 Sep 2023 17:39:43 +0800 Subject: [PATCH 16/23] remove path Signed-off-by: cwj --- .../Offsite_tuning_tutorial.ipynb | 44 +++++++------------ 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb b/doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb index 23bbb04..b8b90ea 100644 --- a/doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb +++ b/doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb @@ -28,15 +28,15 @@ "## Introduction of Offsite-tuning\n", "\n", "Offsite-Tuning is a novel approach designed for the efficient and privacy-preserving adaptation of large foundational models for specific downstream tasks. The framework allows data owners to fine-tune models locally without uploading sensitive data to the LLM owner's servers. Specifically, the LLM owner sends a lightweight \"Adapter\" and a lossy compressed \"Emulator\" to the data owner. Using these smaller components, the data owner can then fine-tune the model solely on their private data. The Adapter, once fine-tuned, is returned to the model owner and integrated back into the large model to enhance its performance on the specific dataset.\n", - "\r\n", - "Offsite-Tuning addresses the challenge of unequal distribution of computational power and data. It allows thLLMel owner to enhance the model's capabilities without direct access to private data, while also enabling data owners who may not have the resources to train a full-scale model to fine-tune a portion of it using less computational power. This mutually beneficial arrangement accommodates both parties involve.\r\n", - "\r\n", - "Beyond the standard two-party setup involving the model owner and the data ownin FATE-LLM, er, Offsite-Tunframework ing is also extendable to scenarios with multiple data owners. FATE supports multi-party Offsite-Tuning, allowing multiple data owners to fine-tune and aggregate their Adapters locally, further enhancing the flexibility and applicability of this framewrFor more details of Offsite-tuning, please refer to the [original paper](https://arxiv.org/pdf/2302.04870.pdf).\r\n", - "\r\n", - "\r\n", - "\r\n", - "\r\n", - "\r\n" + "\n", + "Offsite-Tuning addresses the challenge of unequal distribution of computational power and data. It allows thLLMel owner to enhance the model's capabilities without direct access to private data, while also enabling data owners who may not have the resources to train a full-scale model to fine-tune a portion of it using less computational power. This mutually beneficial arrangement accommodates both parties involve.\n", + "\n", + "Beyond the standard two-party setup involving the model owner and the data ownin FATE-LLM, er, Offsite-Tunframework ing is also extendable to scenarios with multiple data owners. FATE supports multi-party Offsite-Tuning, allowing multiple data owners to fine-tune and aggregate their Adapters locally, further enhancing the flexibility and applicability of this framewrFor more details of Offsite-tuning, please refer to the [original paper](https://arxiv.org/pdf/2302.04870.pdf).\n", + "\n", + "\n", + "\n", + "\n", + "\n" ] }, { @@ -58,7 +58,7 @@ "outputs": [], "source": [ "import sys\n", - "your_path_to_fate_python = '/data/projects/fate/fate/python'\n", + "your_path_to_fate_python = 'xxx/fate/fate/python'\n", "sys.path.append(your_path_to_fate_python)" ] }, @@ -268,26 +268,14 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 1, "id": "84f6947e-f0a3-4a42-9549-a9776a15b66d", "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'fate_llm'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfate_llm\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdataset\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mqa_dataset\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m tokenize_qa_dataset\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtransformers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m AutoTokenizer\n\u001b[1;32m 3\u001b[0m tokenizer_name_or_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/data/projects/fate/cwj/gpt2\u001b[39m\u001b[38;5;124m'\u001b[39m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'fate_llm'" - ] - } - ], + "outputs": [], "source": [ "from fate_llm.dataset.qa_dataset import tokenize_qa_dataset\n", "from transformers import AutoTokenizer\n", - "tokenizer_name_or_path = '/data/projects/fate/cwj/gpt2'\n", + "tokenizer_name_or_path = 'gpt2'\n", "tokenizer = AutoTokenizer.from_pretrained(gpt2_path)\n", "\n", "if 'llama' in tokenizer_name_or_path:\n", @@ -404,7 +392,7 @@ "reader_0 = Reader(name=\"reader_0\")\n", "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)\n", "\n", - "gpt2_type = '/data/projects/fate/cwj/gpt2/'\n", + "gpt2_type = 'gpt2'\n", "\n", "from pipeline.component.nn import DatasetParam\n", "dataset_param = DatasetParam(dataset_name='qa_dataset', tokenizer_name_or_path=gpt2_type, select_num=100)\n", @@ -520,7 +508,7 @@ " }\n", "}\n", "\n", - "gpt2_type = '/data/projects/fate/cwj/gpt2/'\n", + "gpt2_type = 'gpt2'\n", "\n", "from pipeline.component.nn import DatasetParam\n", "dataset_param = DatasetParam(dataset_name='qa_dataset', tokenizer_name_or_path=gpt2_type, select_num=100)\n", @@ -794,7 +782,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.16" } }, "nbformat": 4, From 47c46399f0243cc4466a9f35f6e9b41a0c4bf85c Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 6 Sep 2023 02:23:45 +0800 Subject: [PATCH 17/23] pellm update: support chatglm2, bloom Signed-off-by: mgqa34 --- python/fate_llm/dataset/glm_tokenizer.py | 11 ++++-- ...llama_tokenizer.py => prompt_tokenizer.py} | 17 +++++---- python/fate_llm/model_zoo/pellm/bloom.py | 36 +++++++++++++++++++ 3 files changed, 55 insertions(+), 9 deletions(-) rename python/fate_llm/dataset/{llama_tokenizer.py => prompt_tokenizer.py} (87%) create mode 100644 python/fate_llm/model_zoo/pellm/bloom.py diff --git a/python/fate_llm/dataset/glm_tokenizer.py b/python/fate_llm/dataset/glm_tokenizer.py index 99f6d13..8a7ddf9 100644 --- a/python/fate_llm/dataset/glm_tokenizer.py +++ b/python/fate_llm/dataset/glm_tokenizer.py @@ -28,7 +28,8 @@ def __init__(self, truncation=True, text_max_length=256, trust_remote_code=True, prompt_template=None, prompt_column="content", - response_column="summary" + response_column="summary", + version=1 ): super(GLMTokenizerDataset, self).__init__() @@ -44,6 +45,8 @@ def __init__(self, truncation=True, text_max_length=256, if pad_token is not None: self.tokenizer.add_special_tokens({'pad_token': pad_token}) + self._version = version + self.prompt_template = prompt_template if prompt_template else PROMPT_TEMPLATE self.prompt_column = prompt_column self.response_column = response_column @@ -69,7 +72,11 @@ def _process_data(self, line): input_ids = self.tokenizer.build_inputs_with_special_tokens( prompt_ids, target_ids) - seq_length = input_ids.index(self.tokenizer.bos_token_id) + if self._version == 1: + seq_length = input_ids.index(self.tokenizer.bos_token_id) + else: + seq_length = len(prompt_ids) + labels = [-100] * seq_length + input_ids[seq_length:] return { diff --git a/python/fate_llm/dataset/llama_tokenizer.py b/python/fate_llm/dataset/prompt_tokenizer.py similarity index 87% rename from python/fate_llm/dataset/llama_tokenizer.py rename to python/fate_llm/dataset/prompt_tokenizer.py index a71b5e9..1aff07b 100644 --- a/python/fate_llm/dataset/llama_tokenizer.py +++ b/python/fate_llm/dataset/prompt_tokenizer.py @@ -14,14 +14,14 @@ # limitations under the License. # import pandas as pd -from transformers import LlamaTokenizer +from transformers import AutoTokenizer from federatedml.nn.dataset.base import Dataset PROMPT_TEMPLATE = "{prompt}" -class LLAMATokenizerDataset(Dataset): +class PromptTokenizerDataset(Dataset): def __init__(self, text_max_length=256, tokenizer_name_or_path=None, padding=False, padding_side='left', @@ -35,17 +35,20 @@ def __init__(self, text_max_length=256, response_column="summary", ): - super(LLAMATokenizerDataset, self).__init__() + super(PromptTokenizerDataset, self).__init__() self.tokenizer = None self.padding = padding self.add_special_tokens = add_special_tokens self.max_length = text_max_length self.tokenizer_name_or_path = tokenizer_name_or_path - self.tokenizer = LlamaTokenizer.from_pretrained( + self.tokenizer = AutoTokenizer.from_pretrained( self.tokenizer_name_or_path, add_eos_token=add_eos_token) - self.tokenizer.pad_token_id = pad_token_id - self.tokenizer.bos_token_id = bos_token_id - self.tokenizer.eos_token_id = eos_token_id + if pad_token_id is not None: + self.tokenizer.pad_token_id = pad_token_id + if bos_token_id is not None: + self.tokenizer.bos_token_id = bos_token_id + if eos_token_id is not None: + self.tokenizer.eos_token_id = eos_token_id self.tokenizer.padding_side = padding_side self.prompt_template = prompt_template if prompt_template else PROMPT_TEMPLATE diff --git a/python/fate_llm/model_zoo/pellm/bloom.py b/python/fate_llm/model_zoo/pellm/bloom.py new file mode 100644 index 0000000..ae48925 --- /dev/null +++ b/python/fate_llm/model_zoo/pellm/bloom.py @@ -0,0 +1,36 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from transformers import BloomConfig +from transformers import BloomForCausalLM +from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM + + +class BloomForCausalLM(PELLM): + + config_class = BloomConfig + model_loader = BloomForCausalLM + + def __init__(self, config: dict = None, + pretrained_path: str = None, + peft_type: str = None, + peft_config: dict = None, + **kwargs + ) -> None: + + if config is None and pretrained_path is None: + config = BloomConfig().to_dict() # use default model setting + super().__init__(config=config, pretrained_path=pretrained_path, + peft_type=peft_type, peft_config=peft_config, **kwargs) From 3e74567bac434a5d0946c501f1224dd4eb135ed7 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 6 Sep 2023 10:56:11 +0800 Subject: [PATCH 18/23] update doc: release note and builtin models Signed-off-by: mgqa34 --- README.md | 2 +- RELEASE.md | 15 +++++++++++++++ doc/tutorial/builtin_models.md | 5 ++++- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ab12b0e..42f01d0 100644 --- a/README.md +++ b/README.md @@ -25,4 +25,4 @@ Use [FATE-LLM deployment packages](https://github.com/FederatedAI/FATE/wiki/Down ## Quick Start - [Federated ChatGLM-6B Training](./doc/tutorial/ChatGLM-6B_ds.ipynb) - [GPT-2 Training](./doc/tutorial/GPT2-example.ipynb) -- [Builtin Models](./doc/tutorial/builtin_models.md) \ No newline at end of file +- [Builtin Models In PELLM](./doc/tutorial/builtin_models.md) \ No newline at end of file diff --git a/RELEASE.md b/RELEASE.md index 9c59130..8beb3b5 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,18 @@ +## Release 1.2.0 +### Major Features and Improvements +* Support Offsite-Tuning: + * Standard 2-party Offsite-Tuning and Offsite-Tuning + FedAVG now supported + * Framework available for Emulator and Adapter development + * New Offsite-Tuning Trainer introduced + * Includes built-in models such as GPT-2 family, Llama7b, and Bloom family +* Support FedIPR: + * Introduced WatermarkDataset as the foundational dataset class for backdoor-based watermarks + * Added SignConv and SignLayerNorm blocks for feature-based watermark models + * New FedIPR Trainer available + * Built-in models with feature-based watermarks include Alexnet, Resnet18, DistilBert, and GPT2 +* More models support parameter-efficient fine-tuning: ChatGLM2-6B and Bloom-7b1 + + ## Release 1.2.0 ### Major Features and Improvements * Support Federated Training of LLaMA-7B with parameter-efficient fine-tuning. diff --git a/doc/tutorial/builtin_models.md b/doc/tutorial/builtin_models.md index 06044e7..5069246 100644 --- a/doc/tutorial/builtin_models.md +++ b/doc/tutorial/builtin_models.md @@ -7,7 +7,10 @@ After reading the training tutorial above, it's easy to use other models listing | Model | ModuleName | ClassName | DataSetName | | -------------- | ----------------- | --------------------------------- | ---------------- | -| LLaMA-7B | pellm.llama | LLAMAForCausalLM | llama_tokenizer | +| Bloom-7B1 | pellm.bloom | BloomForCausalLM | prompt_tokenizer | +| LLaMA-2-7B | pellm.llama | LLAMAForCausalLM | prompt_tokenizer | +| LLaMA-7B | pellm.llama | LLAMAForCausalLM | prompt_tokenizer | +| ChatGLM2-6B | pellm.chatglm | ChatGLMForConditionalGeneration | glm_tokenizer | | ChatGLM-6B | pellm.chatglm | ChatGLMForConditionalGeneration | glm_tokenizer | | GPT-2 | pellm.gpt2 | GPT2 | nlp_tokenizer | | ALBERT | pellm.albert | Albert | nlp_tokenizer | From 086f667480bc6780613f947688d766bcd05e4bcc Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 6 Sep 2023 14:36:13 +0800 Subject: [PATCH 19/23] update release Signed-off-by: mgqa34 --- RELEASE.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 8beb3b5..67652e4 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,16 +1,16 @@ ## Release 1.2.0 ### Major Features and Improvements -* Support Offsite-Tuning: - * Standard 2-party Offsite-Tuning and Offsite-Tuning + FedAVG now supported +* FTL-LLM(Fedrated Learning + Transfer Learning + LLM) + * Standard Offsite-Tuning and Extended Offsite-Tuning(Federated Offsite-Tuning+)now supported * Framework available for Emulator and Adapter development * New Offsite-Tuning Trainer introduced * Includes built-in models such as GPT-2 family, Llama7b, and Bloom family -* Support FedIPR: +* FedIPR * Introduced WatermarkDataset as the foundational dataset class for backdoor-based watermarks * Added SignConv and SignLayerNorm blocks for feature-based watermark models * New FedIPR Trainer available * Built-in models with feature-based watermarks include Alexnet, Resnet18, DistilBert, and GPT2 -* More models support parameter-efficient fine-tuning: ChatGLM2-6B and Bloom-7b1 +* More models support parameter-efficient fine-tuning: ChatGLM2-6B and Bloom-7B1 ## Release 1.2.0 From 7915e4d6bbf7281355397ccbac6872694148caad Mon Sep 17 00:00:00 2001 From: cwj Date: Wed, 6 Sep 2023 15:38:25 +0800 Subject: [PATCH 20/23] fix parameter bug Signed-off-by: cwj --- python/fate_llm/trainer/offsite_tuning_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/fate_llm/trainer/offsite_tuning_trainer.py b/python/fate_llm/trainer/offsite_tuning_trainer.py index 19b2cd0..c48bc30 100644 --- a/python/fate_llm/trainer/offsite_tuning_trainer.py +++ b/python/fate_llm/trainer/offsite_tuning_trainer.py @@ -287,8 +287,7 @@ def server_aggregate_procedure(self, extra_data={}): self.server_agg = SecureAggServer( self.secure_aggregate, communicate_match_suffix=self.comm_suffix, - clients=clients, - lm_aggregate=True + clients=clients ) from federatedml.framework.homo.blocks import CommunicatorTransVar self.model_transvar = CommunicatorTransVar(clients=clients, prefix='model', disable_gc=True) From f6d536fd7828041302b4427d73fa6a2ad71ea00c Mon Sep 17 00:00:00 2001 From: cwj Date: Thu, 7 Sep 2023 19:56:04 +0800 Subject: [PATCH 21/23] model support bf16 Signed-off-by: cwj --- .../offsite_tuning/offsite_tuning_model.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py b/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py index 1334148..1b34292 100644 --- a/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py +++ b/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py @@ -119,6 +119,16 @@ def load_additional_param_state_dict(self, submodel_weights: dict): # load additional weights: pass + def _get_numpy_arr(self, v): + if v.dtype == t.bfloat16: + # float 32 + v = v.detach().cpu().float().cpu().numpy() + else: + v = v.detach().cpu().numpy() + + return v + + def load_numpy_state_dict(self, module_dict, state_dict): param_dict = module_dict @@ -135,20 +145,20 @@ def get_numpy_state_dict(self, module_dict): weight_dict = {} for k, v in module_dict.items(): weight_dict[k] = { - k: v.detach().cpu().numpy() for k, + k: self._get_numpy_arr(v) for k, v in v.state_dict().items()} return weight_dict def get_submodel_weights(self) -> dict: submodel_weights = { "emulator": { - k: v.detach().cpu().numpy() for k, + k: self._get_numpy_arr(v) for k, v in self.get_emulator().state_dict().items()}, "adapter_top": { - k: v.detach().cpu().numpy() for k, + k: self._get_numpy_arr(v) for k, v in self.get_adapter_top().state_dict().items()}, "adapter_bottom": { - k: v.detach().cpu().numpy() for k, + k: self._get_numpy_arr(v) for k, v in self.get_adapter_bottom().state_dict().items()}} addition_weights = self.get_additional_param_state_dict() submodel_weights.update(addition_weights) @@ -200,4 +210,4 @@ def post_initialization(self): param.requires_grad = True for param in self.adapter_bottom.parameters(): param.data = param.data.float() - param.requires_grad = True + param.requires_grad = True \ No newline at end of file From 2517dd737e8abfb2d9bac4b5876f7f3d17baf0a1 Mon Sep 17 00:00:00 2001 From: cwj Date: Thu, 7 Sep 2023 19:59:48 +0800 Subject: [PATCH 22/23] update codes Signed-off-by: cwj --- .../fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py b/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py index 1b34292..6bc6b47 100644 --- a/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py +++ b/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py @@ -122,7 +122,7 @@ def load_additional_param_state_dict(self, submodel_weights: dict): def _get_numpy_arr(self, v): if v.dtype == t.bfloat16: # float 32 - v = v.detach().cpu().float().cpu().numpy() + v = v.detach().cpu().float().numpy() else: v = v.detach().cpu().numpy() From 33d86eaaa2468f70a37356f86ee5c9bf34f869f1 Mon Sep 17 00:00:00 2001 From: cwj Date: Thu, 7 Sep 2023 20:08:08 +0800 Subject: [PATCH 23/23] update doc Signed-off-by: cwj --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 42f01d0..1d8b395 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,14 @@ FATE-LLM is a framework to support federated learning for large language models( ### Standalone deployment Please refer to [FATE-Standalone deployment](https://github.com/FederatedAI/FATE#standalone-deployment). -Deploy FATE-Standalone version with 1.11.2 <= version < 2.0, then copy directory `python/fate_llm` to `{fate_install}/fate/python/fate_llm` +Deploy FATE-Standalone version with 1.11.3 <= version < 2.0, then copy directory `python/fate_llm` to `{fate_install}/fate/python/fate_llm` ### Cluster deployment Use [FATE-LLM deployment packages](https://github.com/FederatedAI/FATE/wiki/Download#llm%E9%83%A8%E7%BD%B2%E5%8C%85) to deploy, refer to [FATE-Cluster deployment](https://github.com/FederatedAI/FATE#cluster-deployment) for more deployment details. ## Quick Start +- [Offsite-tuning Tutorial: Model Definition and Job Submission](./doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb) +- [FedIPR Tutorial: Add Watermarks to Your Model](./doc/tutorial/fed_ipr/FedIPR-tutorial.ipynb) - [Federated ChatGLM-6B Training](./doc/tutorial/ChatGLM-6B_ds.ipynb) - [GPT-2 Training](./doc/tutorial/GPT2-example.ipynb) - [Builtin Models In PELLM](./doc/tutorial/builtin_models.md) \ No newline at end of file