fix type issues (#40)
* fix

* allow to merge in delta space

* allow to merge in delta space

* flip filter post

* double negation ffs

* generator shenanigans
ljleb authored Sep 6, 2024
1 parent b16ba8d commit ef4e793
Showing 5 changed files with 64 additions and 51 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sd-mecha"
version = "0.0.25"
version = "0.0.26"
description = "State dict recipe merger"
readme = "README.md"
authors = [{ name = "ljleb" }]
8 changes: 4 additions & 4 deletions sd_mecha/__init__.py
@@ -376,17 +376,17 @@ def ties_with_dare(
base: RecipeNodeOrPath,
*models: RecipeNodeOrPath,
probability: Hyper = 0.9,
no_rescale: Hyper = 0.0,
rescale: Hyper = 1.0,
alpha: Hyper = 0.5,
seed: Optional[Hyper] = None,
seed: Hyper = -1,
k: Hyper = 0.2,
vote_sgn: Hyper = 0.0,
apply_stock: Hyper = 0.0,
cos_eps: Hyper = 1e-6,
apply_median: Hyper = 0.0,
eps: Hyper = 1e-6,
maxiter: Hyper = 100,
ftol: Hyper =1e-20,
ftol: Hyper = 1e-20,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
) -> recipe_nodes.RecipeNode:
@@ -404,7 +404,7 @@ def ties_with_dare(
res = ties_sum_with_dropout(
*deltas,
probability=probability,
no_rescale=no_rescale,
rescale=rescale,
k=k,
vote_sgn=vote_sgn,
seed=seed,
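
This hunk renames the no_rescale hyperparameter to rescale with clearer semantics (rescale=1.0 keeps the default DARE-style rescaling, rescale=0.0 disables it, matching the old no_rescale=1.0), and seed now defaults to -1 (unseeded) instead of None. A minimal usage sketch — the model paths and hyperparameter values are illustrative only:

    import sd_mecha

    recipe = sd_mecha.ties_with_dare(
        "base.safetensors",  # hypothetical paths, for illustration
        "model_a.safetensors",
        "model_b.safetensors",
        probability=0.9,
        rescale=1.0,  # 1.0 reproduces the old default rescaling; 0.0 turns it off
        seed=42,  # leave at the default -1 to keep the dropout masks unseeded
    )
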
93 changes: 53 additions & 40 deletions sd_mecha/merge_methods/__init__.py
@@ -149,8 +149,8 @@ def ties_sum_extended(  # aka add_difference_ties
apply_stock: Hyper = 0.0,
cos_eps: Hyper = 1e-6,
apply_median: Hyper = 0.0,
eps: Hyper = 1e-6,
maxiter: Hyper = 100,
ftol: Hyper = 1e-20,
**kwargs,
) -> Tensor | LiftFlag[MergeSpace.DELTA]:
@@ -163,11 +163,11 @@ def ties_sum_extended(  # aka add_difference_ties
filtered_delta = filtered_delta.sum(dim=0)

# $$ \tau_m $$
return torch.nan_to_num(filtered_delta * t / param_counts)
else:
# $$ \tau_m $$, but using the geometric median instead of the arithmetic mean. Considered as a replacement for model stock.
filtered_delta = geometric_median_list_of_array(torch.unbind(filtered_delta), eps=eps, maxiter=maxiter, ftol=ftol)

return torch.nan_to_num(filtered_delta)
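
For reference, the arithmetic branch above is the TIES disjoint merge: per parameter, only the deltas whose sign agrees with the elected sign vector survive the filter, and their sum is divided by the number of agreeing models (param_counts); the code applies an additional scale factor t on top of this:

$$ \tau_m = \frac{1}{|\mathcal{A}^p|} \sum_{t \in \mathcal{A}^p} \hat{\tau}^p_t, \qquad \mathcal{A}^p = \{ t \mid \mathrm{sgn}(\hat{\tau}^p_t) = \gamma^p_m \} $$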


@@ -472,8 +472,13 @@ def create_filter(shape: Tuple[int, ...] | torch.Size, alpha: float, tilt: float
if not 0 <= alpha <= 1:
raise ValueError("alpha must be between 0 and 1")

# normalize tilt to the range [0, 2]
tilt -= math.floor(tilt // 2 * 2)
# normalize tilt to the range [0, 4]
tilt -= math.floor(tilt // 4 * 4)
if tilt > 2:
alpha = 1 - alpha
alpha_inverted = True
else:
alpha_inverted = False

gradients = list(reversed([
torch.linspace(0, 1, s, device=device)
@@ -492,12 +497,20 @@
else:
mesh = gradients[0]

if tilt < EPSILON or abs(tilt - 2) < EPSILON:
if tilt < EPSILON or abs(tilt - 4) < EPSILON:
dft_filter = (mesh > 1 - alpha).float()
elif abs(tilt - 2) < EPSILON:
dft_filter = (mesh < 1 - alpha).float()
else:
tilt_cot = 1 / math.tan(math.pi * tilt / 2)
dft_filter = torch.clamp(mesh*tilt_cot + alpha*tilt_cot + alpha - tilt_cot, 0, 1)

if tilt <= 1 or 2 < tilt <= 3:
dft_filter = mesh*tilt_cot + alpha*tilt_cot + alpha - tilt_cot
else: # 1 < tilt <= 2 or 3 < tilt
dft_filter = mesh*tilt_cot - alpha*tilt_cot + alpha
dft_filter = dft_filter.clip(0, 1)

if alpha_inverted:
dft_filter = 1 - dft_filter
return dft_filter
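
The tilt cycle widens from [0, 2) to [0, 4): an exact tilt of 2 now yields the complementary hard mask, and tilts in (2, 4) produce the complement of the corresponding filter from (0, 2) via the alpha inversion above — the "flip filter post" change. A minimal sketch of the two hard-mask cases, assuming create_filter is importable from sd_mecha.merge_methods and using an illustrative 1-D shape:

    import torch
    from sd_mecha.merge_methods import create_filter

    mask_a = create_filter((64,), alpha=0.25, tilt=0.0)  # hard mask: mesh > 1 - alpha
    mask_b = create_filter((64,), alpha=0.25, tilt=2.0)  # complementary mask: mesh < 1 - alpha
    disjoint = torch.all(mask_a + mask_b == 1)  # True whenever no mesh value equals 1 - alpha exactly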


@@ -520,10 +533,8 @@ def rotate(
is_conv = len(a.shape) == 4 and a.shape[-1] != 1
if is_conv:
shape_2d = (-1, functools.reduce(operator.mul, a.shape[2:]))
elif len(a.shape) == 4:
shape_2d = (-1, functools.reduce(operator.mul, a.shape[1:]))
else:
shape_2d = (-1, a.shape[-1])
shape_2d = (a.shape[0], a.shape[1:].numel())

a_neurons = a.reshape(*shape_2d)
b_neurons = b.reshape(*shape_2d)
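
The non-convolutional branch no longer special-cases 4-D tensors: the 2-D view is now built explicitly as one row per output neuron, with every trailing axis folded into the columns. A small sketch of the reshape semantics, with an illustrative 1x1-conv-shaped weight:

    import torch

    a = torch.randn(320, 4, 1, 1)  # illustrative 1x1 convolution kernel
    shape_2d = (a.shape[0], a.shape[1:].numel())  # (320, 4)
    a_neurons = a.reshape(*shape_2d)  # rows: output neurons; columns: flattened inputs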
@@ -598,6 +609,7 @@ def dropout(  # aka n-supermario
delta0: Tensor | LiftFlag[MergeSpace.DELTA],
*deltas: Tensor | LiftFlag[MergeSpace.DELTA],
probability: Hyper = 0.9,
rescale: Hyper = 1.0,
overlap: Hyper = 1.0,
overlap_emphasis: Hyper = 0.0,
seed: Hyper = -1,
@@ -625,7 +637,13 @@
final_delta = torch.zeros_like(delta0)
for mask, delta in zip(masks, deltas):
final_delta[mask] += delta[mask]
return final_delta / masks.sum(0).clamp(1) / (1 - probability)

if probability == 1.0:
rescalar = 1.0
else:
rescalar = (1.0 - probability) ** rescale
rescalar = rescalar if math.isfinite(rescalar) else 1
return final_delta / masks.sum(0).clamp(1) / rescalar
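
The fixed division by 1 - probability becomes (1 - probability) ** rescale, so rescale=1.0 reproduces the original rescaling, rescale=0.0 disables it, and values in between interpolate. A numeric sketch of the rule, with illustrative values:

    import math

    probability, rescale = 0.9, 1.0
    rescalar = 1.0 if probability == 1.0 else (1.0 - probability) ** rescale
    rescalar = rescalar if math.isfinite(rescalar) else 1  # overflow guard for extreme exponents
    # probability=0.9, rescale=1.0 -> rescalar=0.1: surviving deltas are scaled up 10x
    # probability=0.9, rescale=0.0 -> rescalar=1.0: no rescaling (the old no_rescale behavior)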


# Part of TIES w/ DARE
@@ -635,44 +653,39 @@ def dropout(  # aka n-supermario
@convert_to_recipe
def ties_sum_with_dropout(
*deltas: Tensor | LiftFlag[MergeSpace.DELTA],
probability: Hyper = 0.9,
no_rescale: Hyper = 0.0,
rescale: Hyper = 1.0,
k: Hyper = 0.2,
vote_sgn: Hyper = 0.0,
apply_stock: Hyper = 0.0,
cos_eps: Hyper = 1e-6,
apply_median: Hyper = 0.0,
eps: Hyper = 1e-6,
maxiter: Hyper = 100,
ftol: Hyper = 1e-20,
seed: Hyper = -1,
**kwargs,
) -> Tensor | LiftFlag[MergeSpace.DELTA]:
# Set seed
if seed < 0:
seed = None
else:
seed = int(seed)
torch.manual_seed(seed)
if not deltas or probability == 1:
return 0

generator = torch.Generator(deltas[0].device)
if seed is not None and seed >= 0:
generator.manual_seed(round(seed))

# Under "Dropout", delta will be 0 by definition. Multiply it (Hadamard product) will return 0 also.
# $$ \tilde{\delta}^t = (1 - m^t) \odot \delta^t $$
deltas = [delta * torch.bernoulli(torch.full(delta.shape, 1 - probability)) for delta in deltas]
deltas = [delta * torch.bernoulli(torch.full(delta.shape, 1 - probability, device=delta.device, dtype=delta.dtype), generator=generator) for delta in deltas]

# $$ \tilde{\delta}^t = \tau_m = \hat{\tau}_t $$ O(N) in space
deltas = ties_sum_extended.__wrapped__(*deltas, k=k, vote_sgn=vote_sgn, apply_stock=apply_stock, cos_eps=cos_eps, apply_median=apply_median, eps=eps, maxiter=maxiter, ftol=ftol)

if probability == 1.0:
# Corner case
return deltas * 0.0
elif no_rescale <= 0.0:
# Rescale
# $$ \hat{\delta}^t = \tilde{\delta}^t / (1-p) $$
return deltas / (1.0 - probability)
rescalar = 1.0
else:
# No rescale
# $$ \hat{\delta}^t = \tilde{\delta}^t $$
return deltas
rescalar = (1.0 - probability) ** rescale
rescalar = rescalar if math.isfinite(rescalar) else 1
return deltas / rescalar
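
Seeding now goes through a torch.Generator created on the deltas' device instead of the global torch.manual_seed, so masks are reproducible without mutating global RNG state, and a negative seed leaves the generator unseeded — the "generator shenanigans" commit. A minimal sketch of the pattern:

    import torch

    delta = torch.randn(8, 8)  # illustrative delta tensor
    generator = torch.Generator(delta.device)
    generator.manual_seed(42)  # skipped when seed < 0
    keep_mask = torch.bernoulli(
        torch.full(delta.shape, 0.1, device=delta.device, dtype=delta.dtype),  # 1 - probability
        generator=generator,
    )
    dropped_delta = delta * keep_mask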


def overlapping_sets_pmf(n, p, overlap, overlap_emphasis):
@@ -722,7 +735,7 @@ def binomial_coefficient_np(n, k):
@convert_to_recipe
def model_stock_for_tensor(
*deltas: Tensor | LiftFlag[MergeSpace.DELTA],
cos_eps: Hyper = 1e-6,
**kwargs,
) -> Tensor | LiftFlag[MergeSpace.DELTA]:

@@ -746,7 +759,7 @@ def get_model_stock_t(deltas, cos_eps):

# One-liner is all you need. I may make it a running average if it turns out to be memory hungry.
cos_thetas = [cos(deltas[i], deltas[i + 1]) for i, _ in enumerate(deltas) if (i + 1) < n]

# Still a vector.
cos_theta = torch.stack(cos_thetas).mean(dim=0)

@@ -760,8 +773,8 @@ def get_model_stock_t(deltas, cos_eps):
@convert_to_recipe
def geometric_median(
*models: Tensor | SameMergeSpace,
eps: Hyper = 1e-6,
maxiter: Hyper = 100,
ftol: Hyper = 1e-20,
**kwargs,
) -> Tensor | SameMergeSpace:
@@ -782,16 +795,16 @@ def geometric_median_list_of_array(models, eps, maxiter, ftol):
objective_value = geometric_median_objective(median, models, weights)

# Weiszfeld iterations
for _ in range(maxiter):
for _ in range(max(0, round(maxiter))):
prev_obj_value = objective_value
denom = torch.stack([l2distance(p, median) for p in models])
new_weights = weights / torch.clamp(denom, min=eps)
median = weighted_average(models, new_weights)

objective_value = geometric_median_objective(median, models, weights)
if abs(prev_obj_value - objective_value) <= ftol * objective_value:
break

return weighted_average(models, new_weights)
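
For reference, each pass of the loop above is one Weiszfeld iteration: the candidate median is replaced by an average of the models weighted inversely to their distance from it (clamped below by eps), stopping once the objective improves by less than a relative ftol. Assuming weighted_average normalizes by the sum of the weights, the update is:

$$ m_{k+1} = \frac{\sum_i w_i x_i / \max(\lVert x_i - m_k \rVert_2, \epsilon)}{\sum_i w_i / \max(\lVert x_i - m_k \rVert_2, \epsilon)} $$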


9 changes: 4 additions & 5 deletions sd_mecha/merge_methods/svd.py
@@ -4,25 +4,24 @@


def orthogonal_procrustes(a, b, cancel_reflection: bool = False):
if a.shape[0] + 10 < a.shape[1]:
if not cancel_reflection and a.shape[0] + 10 < a.shape[1]:
svd_driver = "gesvdj" if a.is_cuda else None
u, _, v = torch_svd_lowrank(a.T @ b, driver=svd_driver, q=a.shape[0] + 10)
v_t = v.T
del v
else:
svd_driver = "gesvd" if a.is_cuda else None
u, _, v_t = torch.linalg.svd(a.T @ b, driver=svd_driver)

if cancel_reflection:
u[:, -1] /= torch.det(u) * torch.det(v_t)

transform = u @ v_t
if not torch.isfinite(u).all():
raise ValueError(
f"determinant error: {torch.det(transform)}. "
'This can happen when merging on the CPU with the "rotate" method. '
"Consider merging on a cuda device, "
"or try setting alpha to 1 for the problematic blocks. "
"or try setting `alignment` to 1 for the problematic blocks. "
"See this related discussion for more info: "
"https://github.com/s1dlx/meh/pull/50#discussion_r1429469484"
)
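
The reordering above means cancel_reflection now always takes the full-SVD branch: the truncated factors from torch_svd_lowrank are not square, so the determinant correction could not be applied to them anyway. The correction itself is the standard special-orthogonal Procrustes step — flip one column of u whenever det(u) * det(v_t) is -1, making the final transform a pure rotation. A sketch on a full SVD, with illustrative shapes:

    import torch

    a, b = torch.randn(16, 16), torch.randn(16, 16)
    u, _, v_t = torch.linalg.svd(a.T @ b)
    u[:, -1] /= torch.det(u) * torch.det(v_t)  # the product is +/-1 for orthogonal factors
    transform = u @ v_t  # det(transform) == +1: rotation without reflection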
3 changes: 2 additions & 1 deletion sd_mecha/recipe_merger.py
@@ -46,9 +46,10 @@ def merge_and_save(
save_dtype: Optional[torch.dtype] = torch.float16,
threads: Optional[int] = None,
total_buffer_size: int = 2**28,
strict_weight_space: bool = True,
):
recipe = extensions.merge_method.path_to_node(recipe)
if recipe.merge_space != recipe_nodes.MergeSpace.BASE:
if strict_weight_space and recipe.merge_space != recipe_nodes.MergeSpace.BASE:
raise ValueError(f"recipe should be in model merge space, not {str(recipe.merge_space).split('.')[-1]}")
if isinstance(fallback_model, (str, pathlib.Path)):
fallback_model = extensions.merge_method.path_to_node(fallback_model)
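
This flag is the switch behind the "allow to merge in delta space" commits: with strict_weight_space=False, merge_and_save accepts a recipe whose result is still in delta space instead of raising. A hedged usage sketch — RecipeMerger, subtract, and merge_and_save follow the library's public API, while the paths and output name are illustrative:

    import sd_mecha

    # A recipe that stays in delta space: the difference between two models.
    delta_recipe = sd_mecha.subtract("model_a.safetensors", "base.safetensors")

    merger = sd_mecha.RecipeMerger()
    merger.merge_and_save(
        delta_recipe,
        output="model_a_delta",
        strict_weight_space=False,  # without this, a delta-space recipe raises ValueError
    )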
