diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py index 046f888..eb8cf4e 100644 --- a/experiments/attack_defense_test.py +++ b/experiments/attack_defense_test.py @@ -173,7 +173,7 @@ def test_attack_defense(): _import_path=EVASION_ATTACK_PARAMETERS_PATH, _config_class="EvasionAttackConfig", _config_kwargs={ - "epsilon": 0.01 * 1, + "epsilon": 0.001 * 12, } ) @@ -223,6 +223,15 @@ def test_attack_defense(): } ) + distillation_evasion_defense_config = ConfigPattern( + _class_name="DistillationDefender", + _import_path=EVASION_DEFENSE_PARAMETERS_PATH, + _config_class="EvasionDefenseConfig", + _config_kwargs={ + "temperature": 0.5 * 20 + } + ) + fgsm_evasion_attack_config0 = ConfigPattern( _class_name="FGSM", _import_path=EVASION_ATTACK_PARAMETERS_PATH, @@ -244,7 +253,7 @@ def test_attack_defense(): # gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config) # gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config) gnn_model_manager.set_evasion_attacker(evasion_attack_config=fgsm_evasion_attack_config) - gnn_model_manager.set_evasion_defender(evasion_defense_config=gradientregularization_evasion_defense_config) + gnn_model_manager.set_evasion_defender(evasion_defense_config=distillation_evasion_defense_config) warnings.warn("Start training") dataset.train_test_split() diff --git a/metainfo/evasion_defense_parameters.json b/metainfo/evasion_defense_parameters.json index b9514be..e9ff771 100644 --- a/metainfo/evasion_defense_parameters.json +++ b/metainfo/evasion_defense_parameters.json @@ -7,6 +7,9 @@ "QuantizationDefender": { "num_levels": ["num_levels", "int", 32, {"min": 2, "step": 1}, "?"] }, + "DistillationDefender": { + "temperature": ["temperature", "float", 5.0, {"min": 1, "step": 0.01}, "?"] + }, "AdvTraining": { "attack_name": ["attack_name", "str", "FGSM", {}, "?"] } diff --git a/src/defense/evasion_defense.py b/src/defense/evasion_defense.py index 
9cd9c5a..2eea79e 100644 --- a/src/defense/evasion_defense.py +++ b/src/defense/evasion_defense.py @@ -108,7 +108,7 @@ def pre_batch( def quantize( self, - x + x: torch.Tensor ): x_min = x.min() x_max = x.max() @@ -118,6 +118,40 @@ def quantize( return x_quantized +class DistillationDefender( + EvasionDefender +): + name = "DistillationDefender" + + def __init__( + self, + temperature: float = 5.0 + ): + """ + """ + super().__init__() + self.temperature = temperature + + def post_batch( + self, + model_manager, + batch, + loss: torch.Tensor + ): + """ + """ + model = model_manager.gnn + logits = model(batch) + soft_targets = torch.softmax(logits / self.temperature, dim=1).detach()  # stop-grad: distillation targets must be constants + distillation_loss = torch.nn.functional.kl_div( + torch.log_softmax(logits / self.temperature, dim=1), + soft_targets, + reduction='batchmean' + ) * (self.temperature ** 2) + modified_loss = loss + distillation_loss  # NOTE(review): KL of a distribution against itself is ~0 with zero gradient, so this term has no effect — soft_targets should come from a separate teacher model's logits; confirm intended design + return {"loss": modified_loss} + + class AdvTraining( EvasionDefender ): diff --git a/src/models_builder/gnn_constructor.py b/src/models_builder/gnn_constructor.py index 3df22e8..bf8b325 100644 --- a/src/models_builder/gnn_constructor.py +++ b/src/models_builder/gnn_constructor.py @@ -617,17 +617,35 @@ def arguments_read( edge_weight = kwargs.get('edge_weight', None) if batch is None: batch = torch.zeros(kwargs['x'].shape[0], dtype=torch.int64, device=x.device) - elif len(args) == 2: - x, edge_index = args[0], args[1] - batch = torch.zeros(args[0].shape[0], dtype=torch.int64, device=x.device) - edge_weight = None - elif len(args) == 3: - x, edge_index, batch = args[0], args[1], args[2] - edge_weight = None - elif len(args) == 4: - x, edge_index, batch, edge_weight = args[0], args[1], args[2], args[3] else: - raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}") + if len(args) == 1: + args = args[0] + if 'x' in args and 'edge_index' in args: + x, edge_index = args.x, args.edge_index + else: + raise ValueError(f"forward's args should contain x and" + 
f" edge_index Tensors but {args.keys} doesn't content this Tensors") + if 'batch' in args: + batch = args.batch + else: + batch = torch.zeros(args.x.shape[0], dtype=torch.int64, device=x.device) + if 'edge_weight' in args: + edge_weight = args.edge_weight + else: + edge_weight = None + else: + if len(args) == 2: + x, edge_index = args[0], args[1] + batch = torch.zeros(args[0].shape[0], dtype=torch.int64, device=x.device) + edge_weight = None + elif len(args) == 3: + x, edge_index, batch = args[0], args[1], args[2] + edge_weight = None + elif len(args) == 4: + x, edge_index, batch, edge_weight = args[0], args[1], args[2], args[3] + else: + raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}") + else: if hasattr(data, "edge_weight"): x, edge_index, batch, edge_weight = data.x, data.edge_index, data.batch, data.edge_weight