From ef18ecd788bda43667eceb73853386fe4b37b9fd Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Wed, 4 Sep 2024 17:37:20 -0400
Subject: [PATCH] add support for CLIPTextModel single file loading

---
 .../clip_text_model/config.json               | 25 ++++++++++
 invokeai/backend/model_manager/config.py      | 12 +++++
 .../model_manager/load/model_loaders/flux.py  | 47 ++++++++++++++++++-
 invokeai/backend/model_manager/probe.py       | 19 ++++++--
 4 files changed, 97 insertions(+), 6 deletions(-)
 create mode 100644 invokeai/backend/assets/model_base_conf_files/clip_text_model/config.json

diff --git a/invokeai/backend/assets/model_base_conf_files/clip_text_model/config.json b/invokeai/backend/assets/model_base_conf_files/clip_text_model/config.json
new file mode 100644
index 00000000000..5ebd923ce51
--- /dev/null
+++ b/invokeai/backend/assets/model_base_conf_files/clip_text_model/config.json
@@ -0,0 +1,25 @@
+{
+  "_name_or_path": "openai/clip-vit-large-patch14",
+  "architectures": [
+    "CLIPTextModel"
+  ],
+  "attention_dropout": 0.0,
+  "bos_token_id": 0,
+  "dropout": 0.0,
+  "eos_token_id": 2,
+  "hidden_act": "quick_gelu",
+  "hidden_size": 768,
+  "initializer_factor": 1.0,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-05,
+  "max_position_embeddings": 77,
+  "model_type": "clip_text_model",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 1,
+  "projection_dim": 768,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.43.3",
+  "vocab_size": 49408
+}
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 66e54d82f3a..caa361e7f1a 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -403,6 +403,17 @@ def get_tag() -> Tag:
         return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
 
 
+class CLIPEmbedCheckpointConfig(CheckpointConfigBase):
+    """Model config for CLIP Embedding checkpoints."""
+
+    type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
+    format: Literal[ModelFormat.Checkpoint]
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Checkpoint.value}")
+
+
 class CLIPVisionDiffusersConfig(DiffusersConfigBase):
     """Model config for CLIPVision."""
 
@@ -478,6 +489,7 @@ def get_model_discriminator_value(v: Any) -> str:
         Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
         Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
         Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
+        Annotated[CLIPEmbedCheckpointConfig, CLIPEmbedCheckpointConfig.get_tag()],
     ],
     Discriminator(get_model_discriminator_value),
 ]
diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index dcceda5ad21..75e34bbaad9 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -7,8 +7,17 @@
 import accelerate
 import torch
 from safetensors.torch import load_file
-from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
+from transformers import (
+    AutoConfig,
+    AutoModelForTextEncoding,
+    CLIPTextConfig,
+    CLIPTextModel,
+    CLIPTokenizer,
+    T5EncoderModel,
+    T5Tokenizer,
+)
 
+import invokeai.backend.assets.model_base_conf_files as model_conf_files
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
@@ -23,6 +32,7 @@
 )
 from invokeai.backend.model_manager.config import (
     CheckpointConfigBase,
+    CLIPEmbedCheckpointConfig,
     CLIPEmbedDiffusersConfig,
     MainBnbQuantized4bCheckpointConfig,
     MainCheckpointConfig,
@@ -69,7 +79,7 @@ def _load_model(
 
 
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
-class ClipCheckpointModel(ModelLoader):
+class ClipDiffusersModel(ModelLoader):
     """Class to load main models."""
 
     def _load_model(
@@ -91,6 +101,39 @@ def _load_model(
         )
 
 
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Checkpoint)
+class ClipCheckpointModel(ModelLoader):
+    """Class to load main models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, CLIPEmbedCheckpointConfig):
+            raise ValueError("Only CLIPEmbedCheckpointConfig models are currently supported here.")
+
+        match submodel_type:
+            case SubModelType.Tokenizer:
+                # Clip embedding checkpoints don't have an integrated tokenizer, so we cheat and fetch it into the HuggingFace cache
+                # TODO: Fix this ugly workaround
+                return CLIPTokenizer.from_pretrained(
+                    "InvokeAI/clip-vit-large-patch14-text-encoder", subfolder="bfloat16/tokenizer"
+                )
+            case SubModelType.TextEncoder:
+                config_json = CLIPTextConfig.from_json_file(Path(model_conf_files.__path__[0], config.config_path))
+                model = CLIPTextModel(config_json)
+                state_dict = load_file(config.path)
+                new_dict = {key: value for (key, value) in state_dict.items() if key.startswith("text_model.")}
+                model.load_state_dict(new_dict)
+                model.eval()
+                return model
+
+        raise ValueError(
+            f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
+        )
+
+
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
 class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
     """Class to load main models."""
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index c4f51b464cb..c58c7119116 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -8,7 +8,6 @@
 import torch
 from picklescan.scanner import scan_file_path
 
-import invokeai.backend.util.logging as logger
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
 from invokeai.backend.model_manager.config import (
@@ -27,6 +26,7 @@
 )
 from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
+from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 CkptType = Dict[str | int, Any]
@@ -180,7 +180,9 @@ def probe(
         fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
 
         # additional fields needed for main and controlnet models
-        if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
+        if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE, ModelType.CLIPEmbed] and fields[
+            "format"
+        ] in [
             ModelFormat.Checkpoint,
             ModelFormat.BnbQuantizednf4b,
         ]:
@@ -203,7 +205,6 @@ def probe(
                 fields["base"] == BaseModelType.StableDiffusion2
                 and fields["prediction_type"] == SchedulerPredictionType.VPrediction
             )
-
         model_info = ModelConfigFactory.make_config(fields)  # , key=fields.get("key", None))
         return model_info
 
@@ -252,6 +253,8 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
                 return ModelType.IPAdapter
             elif key in {"emb_params", "string_to_param"}:
                 return ModelType.TextualInversion
+            elif key.startswith(("text_model.embeddings", "text_model.encoder")):
+                return ModelType.CLIPEmbed
 
         # diffusers-ti
         if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
@@ -388,6 +391,8 @@ def _get_checkpoint_config_path(
                 if base_type is BaseModelType.StableDiffusionXL
                 else "stable-diffusion/v2-inference.yaml"
             )
+        elif model_type is ModelType.CLIPEmbed:
+            return Path("clip_text_model", "config.json")
         else:
             raise InvalidModelConfigException(
                 f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
@@ -650,6 +655,11 @@ def get_base_type(self) -> BaseModelType:
         raise NotImplementedError()
 
 
+class CLIPEmbedCheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        return BaseModelType.Any
+
+
 class T2IAdapterCheckpointProbe(CheckpointProbeBase):
     def get_base_type(self) -> BaseModelType:
         raise NotImplementedError()
@@ -807,7 +817,7 @@ def get_base_type(self) -> BaseModelType:
         if (self.model_path / "unet" / "config.json").exists():
            return super().get_base_type()
         else:
-            logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"')
+            InvokeAILogger.get_logger().warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"')
             return BaseModelType.StableDiffusion1
 
     def get_format(self) -> ModelFormat:
@@ -941,6 +951,7 @@ def get_base_type(self) -> BaseModelType:
 ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
+ModelProbe.register_probe("checkpoint", ModelType.CLIPEmbed, CLIPEmbedCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
 