Skip to content

Commit

Permalink
add sd3.5 config (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
ljleb authored Oct 28, 2024
1 parent ef4e793 commit 450532e
Show file tree
Hide file tree
Showing 6 changed files with 2,162 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sd-mecha"
version = "0.0.26"
version = "0.0.27"
description = "State dict recipe merger"
readme = "README.md"
authors = [{ name = "ljleb" }]
Expand Down
1 change: 1 addition & 0 deletions sd_mecha/builtin_model_archs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
register_model_arch(models_dir / "sd1_ldm.yaml", "sd1")
register_model_arch(models_dir / "sdxl_sgm.yaml", "sdxl")
register_model_arch(models_dir / "sd3_sgm.yaml", "sd3")
register_model_arch(models_dir / "sd35.yaml", "sd35")
register_model_arch(models_dir / "flux_flux.yaml", "flux")
5 changes: 1 addition & 4 deletions sd_mecha/extensions/model_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ def _create_header(fake_state_dict):
if k != "__metadata__"
}

def _create_fake_tensor(*shape, dtype):
return torch.empty(shape, dtype=dtype)

if self.strict_suffixes:
for sd_key in state_dict:
if sd_key == "__metadata__":
Expand All @@ -68,7 +65,7 @@ def _create_fake_tensor(*shape, dtype):

with FakeTensorMode():
fake_state_dict = {
k: _create_fake_tensor(*h["shape"], dtype=DTYPE_MAPPING[h["dtype"]][0])
k: torch.empty(tuple(h["shape"]), dtype=DTYPE_MAPPING[h["dtype"]][0])
for k, h in state_dict.header.items()
if k != "__metadata__" and (self.key_suffixes is None or k.endswith(self.key_suffixes))
}
Expand Down
1 change: 0 additions & 1 deletion sd_mecha/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from sd_mecha.extensions.model_arch import ModelArch
from sd_mecha.hypers import get_hyper
from sd_mecha.recipe_nodes import RecipeNode, ModelRecipeNode, ParameterRecipeNode, MergeRecipeNode, DepthRecipeVisitor, RecipeVisitor
from sd_mecha.streaming import InSafetensorsDict


@dataclasses.dataclass
Expand Down
64 changes: 64 additions & 0 deletions sd_mecha/models/sd35.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
passthrough:
- first_stage_model

merge:
unet:
prefix: model.diffusion_model
blocks:
in0:
- joint_blocks.0
- pos_embed
- t_embedder
- x_embedder
- y_embedder
- context_embedder
in37:
- joint_blocks.37
- pos_embed
- t_embedder
- x_embedder
- y_embedder
- final_layer
in*:
- joint_blocks.*
- pos_embed
- t_embedder
- x_embedder
- y_embedder

txt:
prefix: text_encoders.clip_l.transformer
blocks:
in0:
- text_model.encoder.layers.0
- text_model.embeddings
in11:
- text_model.encoder.layers.11
- text_model.final_layer_norm
in*: text_model.encoder.layers.*

txt2:
prefix: text_encoders.clip_g.transformer
blocks:
in0:
- text_model.encoder.layers.0
- text_model.embeddings
in31:
- text_model.encoder.layers.31
- text_model.final_layer_norm
- text_projection
in*: text_model.encoder.layers.*

t5xxl:
prefix: text_encoders.t5xxl.transformer
blocks:
in0:
- encoder.block.0
- encoder.embed_tokens
- shared
in23:
- encoder.block.23
- encoder.final_layer_norm
in*: encoder.block.*

keys: sd35_keys.txt
Loading

0 comments on commit 450532e

Please sign in to comment.