From 08398314e8ed8b1a480ec1d59db07e4dbd5a49c9 Mon Sep 17 00:00:00 2001 From: laksjdjf Date: Thu, 18 Apr 2024 16:22:06 +0900 Subject: [PATCH] allow huggingface path --- tools/create_control_lora.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tools/create_control_lora.py b/tools/create_control_lora.py index 4e8bc89..a3e2a5c 100644 --- a/tools/create_control_lora.py +++ b/tools/create_control_lora.py @@ -175,6 +175,8 @@ def svd_extract(pretrained, finetuned, rank, device="cuda", dtype=torch.float16) if __name__ == "__main__": import argparse from safetensors.torch import load_file, save_file + from diffusers import UNet2DConditionModel, ControlNetModel + import os parser = argparse.ArgumentParser() parser.add_argument("--controlnet", "-c", type=str, required=True) parser.add_argument("--unet", "-u", type=str, required=True) @@ -184,9 +186,16 @@ def svd_extract(pretrained, finetuned, rank, device="cuda", dtype=torch.float16) parser.add_argument("--dtype", type=str, default="torch.float16") args = parser.parse_args() - controlnet = load_file(args.controlnet) + if os.path.exists(args.controlnet) and os.path.isfile(args.controlnet): + controlnet = load_file(args.controlnet) + else: + controlnet = ControlNetModel.from_pretrained(args.controlnet).state_dict() print("controlnet loaded") - unet = load_file(args.unet) + + if os.path.exists(args.unet) and os.path.isfile(args.unet): + unet = load_file(args.unet) + else: + unet = UNet2DConditionModel.from_pretrained(args.unet, subfolder="unet").state_dict() print("unet loaded") control_lora = create_lora(controlnet, unet, rank=args.rank, device=args.device, dtype=eval(args.dtype))