From b57d146dc3beab67042bd700ee62b8dd45f62454 Mon Sep 17 00:00:00 2001 From: ljleb Date: Mon, 17 Jun 2024 12:52:46 -0400 Subject: [PATCH] Update custom_merge_method.py --- examples/custom_merge_method.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/custom_merge_method.py b/examples/custom_merge_method.py index a34ffc4..3e5a5ea 100644 --- a/examples/custom_merge_method.py +++ b/examples/custom_merge_method.py @@ -1,6 +1,7 @@ import sd_mecha import torch from sd_mecha.extensions.merge_method import convert_to_recipe, LiftFlag, MergeSpace +from sd_mecha import Hyper sd_mecha.set_log_level() @@ -28,8 +29,10 @@ def custom_sum( b: torch.Tensor | LiftFlag[MergeSpace.BASE], *, # hyperparameters go here - alpha: float = 0.5, # default arguments are honored - beta: float, + # `Hyper` is an union type of `float`, `int` and `dict` (the dict case is for a different weight per block MBW), which is what the caller of the method excpects. + # in practice the method only ever receives numeric types (int or float), so no need to worry about the dict case + alpha: Hyper = 0.5, # default arguments are honored + beta: Hyper, # `@convert_to_recipe` introduces additional kwargs: `device=` and `dtype=` # We must put `**kwargs` to satisfy the type system: **kwargs,