diff --git a/sd_mecha/merge_methods.py b/sd_mecha/merge_methods.py index 1a5a5d6..fb51514 100644 --- a/sd_mecha/merge_methods.py +++ b/sd_mecha/merge_methods.py @@ -273,11 +273,48 @@ def train_difference( *, alpha: Hyper = 1.0, **kwargs, +) -> Tensor | SameMergeSpace: + mask = 1.8 * torch.nan_to_num(torch.abs(b - a) / (torch.abs(b - a) + torch.abs(b - c)), nan=0) + return a + (b - c) * alpha * mask + + +@convert_to_recipe +def add_opposite( + a: Tensor | SameMergeSpace, + b: Tensor | SameMergeSpace, + c: Tensor | SameMergeSpace, + *, + alpha: Hyper = 1.0, + **kwargs, ) -> Tensor | SameMergeSpace: threshold = torch.maximum(torch.abs(a - c), torch.abs(b - c)) - dissimilarity = torch.clamp(torch.nan_to_num((c - a) * (b - c) / threshold**2, nan=0), 0) + mask = 1 - torch.nan_to_num((a - c) * (b - c) / threshold**2, nan=0) + return a + (b - c) * alpha * mask - return a + (b - c) * alpha * dissimilarity + +@convert_to_recipe +def clamped_add_opposite( + a: Tensor | SameMergeSpace, + b: Tensor | SameMergeSpace, + c: Tensor | SameMergeSpace, + *, + alpha: Hyper = 1.0, + **kwargs, +) -> Tensor | SameMergeSpace: + threshold = torch.maximum(torch.abs(a - c), torch.abs(b - c)) + mask = torch.clamp(torch.nan_to_num((c - a) * (b - c) / threshold**2, nan=0), 0) * 2 + return a + (b - c) * alpha * mask + + +@convert_to_recipe +def select_max_delta( + a: Tensor | LiftFlag[MergeSpace.DELTA], + b: Tensor | LiftFlag[MergeSpace.DELTA], + *, + alpha: Hyper = 0.5, + **kwargs, +) -> Tensor | LiftFlag[MergeSpace.DELTA]: + return torch.where((1 - alpha) * a.abs() >= alpha * b.abs(), a, b) @convert_to_recipe