Skip to content

Commit

Permalink
allow huggingface path
Browse files Browse the repository at this point in the history
  • Loading branch information
laksjdjf authored Apr 18, 2024
1 parent 683d977 commit 0839831
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions tools/create_control_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ def svd_extract(pretrained, finetuned, rank, device="cuda", dtype=torch.float16)
if __name__ == "__main__":
import argparse
from safetensors.torch import load_file, save_file
from diffusers import UNet2DConditionModel, ControlNetModel
import os
parser = argparse.ArgumentParser()
parser.add_argument("--controlnet", "-c", type=str, required=True)
parser.add_argument("--unet", "-u", type=str, required=True)
Expand All @@ -184,9 +186,16 @@ def svd_extract(pretrained, finetuned, rank, device="cuda", dtype=torch.float16)
parser.add_argument("--dtype", type=str, default="torch.float16")
args = parser.parse_args()

controlnet = load_file(args.controlnet)
if os.path.exists(args.controlnet) and os.path.isfile(args.controlnet):
controlnet = load_file(args.controlnet)
else:
controlnet = ControlNetModel.from_pretrained(args.controlnet).state_dict()
print("controlnet loaded")
unet = load_file(args.unet)

if os.path.exists(args.unet) and os.path.isfile(args.unet):
unet = load_file(args.unet)
else:
unet = UNet2DConditionModel.from_pretrained(args.unet, subfolder="unet").state_dict()
print("unet loaded")

control_lora = create_lora(controlnet, unet, rank=args.rank, device=args.device, dtype=eval(args.dtype))
Expand Down

0 comments on commit 0839831

Please sign in to comment.