-
Notifications
You must be signed in to change notification settings - Fork 0
/
illustrator.py
155 lines (128 loc) · 7.93 KB
/
illustrator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import dataclasses
import yaml
import argparse
import os
from typing import List
from typing import Dict
import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
from utils import write_story_pages, parse_line_to_story_page
from utils import StoryPage, CustomCharacter
from customization.textual_inversion import TextualInversionTrainer
from customization.dreambooth import DreamBoothTrainer
class Illustrator(object):
    """Generate one illustration per story page with a Stable Diffusion pipeline.

    The pipeline and its Euler scheduler are loaded from ``config["model_id"]``.
    Depending on ``config["custom_type"]`` a DreamBooth or textual-inversion
    trainer is attached so that custom characters can be trained and then
    substituted into the page prompts.

    Attributes:
        device: Requested device, or ``"cpu"`` when CUDA is unavailable.
        inference_steps: Number of denoising steps taken from the config.
        scheduler: ``EulerDiscreteScheduler`` loaded from the model.
        pipe: ``StableDiffusionPipeline`` used for generation.
        trainer: Customization trainer, or ``None`` when no supported
            ``custom_type`` is configured.
        prompt_template: ``%``-format template ``"<prefix>, %s, <suffix>"``.
        negative_prompt: Negative prompt passed to every pipeline call.
    """

    def __init__(self, device="cuda:0", **config):
        """Load the pipeline and optional customization trainer.

        Args:
            device: Torch device string; ignored in favour of ``"cpu"`` when
                CUDA is not available.
            **config: Illustrator settings. Required keys: ``model_id``,
                ``inference_steps``, ``prefix``, ``suffix``,
                ``negative_prompt``. Optional keys: ``custom_type``
                (``"dreambooth"`` or ``"textual-inversion"``) together with
                ``custom_args`` (keyword arguments forwarded to the trainer).

        Raises:
            KeyError: If a required config key is missing.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.inference_steps = config["inference_steps"]

        model_id = config["model_id"]
        self.scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
        self.scheduler.set_timesteps(self.inference_steps, self.device)
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id, scheduler=self.scheduler, torch_dtype=torch.float32
        )

        self.trainer = None
        custom_type = config.get("custom_type")
        if custom_type == "dreambooth":
            self.trainer = DreamBoothTrainer(self.pipe, **config["custom_args"], device=self.device)
        elif custom_type == "textual-inversion":
            self.trainer = TextualInversionTrainer(self.pipe, **config["custom_args"], device=self.device)

        # The template is applied with the ``%`` operator, so any literal "%"
        # in the user-supplied prefix/suffix must be escaped; otherwise it is
        # parsed as a conversion specifier and formatting fails.
        prefix = str(config["prefix"]).replace("%", "%%")
        suffix = str(config["suffix"]).replace("%", "%%")
        self.prompt_template = f"{prefix}, %s, {suffix}"
        self.negative_prompt = config["negative_prompt"]

    def customize(self, custom_characters: List[CustomCharacter]):
        """Train the attached trainer on the given custom characters.

        Prints a notice and does nothing when no trainer was configured.

        Args:
            custom_characters: Characters handed to ``self.trainer.train``.
        """
        if self.trainer is None:
            print("No customization applied since trainer not initialized.")
            return
        self.trainer.train(custom_characters)

    def generate(self, pages: List[StoryPage], custom_characters: List[CustomCharacter]=None) -> List[StoryPage]:
        """Render an image for every page and return illustrated copies.

        When a trainer is attached and ``custom_characters`` is given, each
        occurrence of a character's ``orig_name`` in the page prompt is
        replaced by ``"<placeholder token> <orig_object>"`` before the prompt
        is wrapped in the prefix/suffix template. Each full prompt is printed.

        Args:
            pages: Story pages; each must expose a ``prompt`` attribute.
            custom_characters: Optional characters to substitute into prompts.

        Returns:
            New pages created with ``dataclasses.replace(page, image=...)``,
            in the same order as ``pages``; the inputs are not modified.
        """
        illustrated_pages = []
        self.pipe = self.pipe.to(self.device, torch_dtype=torch.float32)
        substitute_characters = self.trainer is not None and custom_characters is not None

        for page in pages:
            prompt = page.prompt
            if substitute_characters:
                for character in custom_characters:
                    placeholder_custom_token = self.trainer.get_placeholder_token(character.custom_name)
                    new_token = f"{placeholder_custom_token} {character.orig_object}"
                    prompt = prompt.replace(character.orig_name, new_token)

            full_prompt = self.prompt_template % prompt
            print(full_prompt)
            # The pipeline re-sets the scheduler timesteps on every call from
            # its own num_inference_steps argument, so the configured step
            # count has to be passed here to take effect.
            image = self.pipe(
                full_prompt,
                negative_prompt=self.negative_prompt,
                num_inference_steps=self.inference_steps,
            ).images[0]
            illustrated_pages.append(dataclasses.replace(page, image=image))

        return illustrated_pages
# Command-line entry point: read a story file, optionally train a custom
# character, generate one illustration per page, and write the result under
# output/<title>.
if __name__ == "__main__":
    """
    # use trained customization model to generate images
    python illustrator.py --orig_name Goldilocks --prompts_path output/goldilocks_and_the_three_bears/story.txt \
    --prefix "mdjrny-v4 kids story illustration" \
    --suffix "drawn by Rebecca Sugar, bright engaging children's illustration, digital painting, big eyes, beautiful shading, beautiful colors, amazon kdp, happy, interesting, 2D" \
    --skip_training --device "cuda:1"
    python illustrator.py --orig_name Jack --prompts_path output/jack_and_the_beanstalk/01/story.txt \
    --prefix "mdjrny-v4 kids story illustration" \
    --suffix "drawn by Rebecca Sugar, bright engaging children's illustration, digital painting, big eyes, beautiful shading, beautiful colors, amazon kdp, happy, interesting, 2D" \
    --model_id "runs/dreambooth-model" --device "cuda:1"
    # illustration with NO customization
    python illustrator.py --orig_name Jack --prompts_path output/jack_and_the_beanstalk/story.txt \
    --prefix "mdjrny-v4 kids illustration showcasing the story of 'Jack and the Beanstalk'" \
    --suffix "drawn by Rebecca Sugar, bright engaging children's illustration, digital painting, big eyes, beautiful shading, beautiful colors, amazon kdp, happy, interesting, 2D" \
    --config_path "config/sd2.yml"
    python illustrator.py --orig_name Goldilocks --prompts_path output/goldilocks_and_the_three_bears/story.txt \
    --prefix "mdjrny-v4 kids story illustration" \
    --suffix "drawn by Rebecca Sugar, bright engaging children's illustration, digital painting, big eyes, beautiful shading, beautiful colors, amazon kdp, happy, interesting, 2D" \
    python illustrator.py --orig_name "Little Red Riding Hood" --prompts_path output/little_red_riding_hood/story.txt \
    --custom_name Aspen --custom_img_dir sample_images/aspen_512 \
    --prefix "mdjrny-v4 kids story illustration" \
    --suffix "drawn by Rebecca Sugar, bright engaging children's illustration, digital painting, big eyes, beautiful shading, beautiful colors, amazon kdp, happy, interesting, 2D" \
    --suffix "rich colors, highly detailed, sharp focus, cinematic lighting, by Atey Ghailan and Beatrix Potter"
    --prefix "Rebecca Sugar style kids book illustration showcasing the story Jack and the Beanstalk" \
    python illustrator.py --orig_name Goldilocks --prompts_path output/goldilocks_and_the_three_bears/story.txt \
    --custom_name Simon --custom_img_dir sample_images/simon_512 \
    --prefix "mdjrny-v4 kids story illustration" \
    --suffix "drawn by Rebecca Sugar, bright engaging children's illustration, digital painting, big eyes, beautiful shading, beautiful colors, amazon kdp, happy, interesting, 2D" \
    --config_path config/dreambooth-sd1-5.yml --device cuda:1
    """
    parser = argparse.ArgumentParser()
    # required args
    parser.add_argument("--orig_name", type=str, required=True)  # choose a main character from story
    parser.add_argument("--prompts_path", type=str, required=True)  # text file with story
    # prompt eng
    parser.add_argument("--prefix", type=str, required=False, default=None)  # prefix for text2image model
    parser.add_argument("--suffix", type=str, required=False, default=None)  # suffix for text2image model
    parser.add_argument("--device", type=str, required=False, default="cuda:0")  # torch device for the pipeline
    # custom args
    parser.add_argument("--orig_object", type=str, required=False, default="boy")
    parser.add_argument("--custom_name", type=str, required=False, default=None)
    parser.add_argument("--custom_img_dir", type=str, required=False, default=None)
    parser.add_argument("--config_path", type=str, required=False, default="config/openjourney.yml")
    parser.add_argument("--skip_training", required=False, action="store_true")
    args = parser.parse_args()

    # read from text file into story pages: first line is the title,
    # every following line is parsed into one story page
    with open(args.prompts_path, 'r') as prompts_file:
        lines = [line.strip() for line in prompts_file.readlines()]
    if not lines:
        # Without this check an empty file fails with a bare IndexError below.
        parser.error(f"prompts file {args.prompts_path} is empty; expected a title line followed by story lines")
    title = lines[0]
    pages = [parse_line_to_story_page(line) for line in lines[1:]]
    print(pages)

    # output dir, derived from the story title
    title_tag = title.lower().replace(" ", "_")
    output_dir = f"output/{title_tag}"
    os.makedirs(output_dir, exist_ok=True)

    with open(args.config_path) as config_file:
        config: Dict = yaml.safe_load(config_file)
    print(f"Running config from {args.config_path}...")

    illustrator_config = config["illustrator"]
    # overwrite prefix and suffix from default if given
    if args.prefix is not None:
        illustrator_config["prefix"] = args.prefix
    if args.suffix is not None:
        illustrator_config["suffix"] = args.suffix
    # when training is skipped, load the already-customized model instead
    if args.skip_training:
        custom_args = illustrator_config.get("custom_args") or {}
        if "custom_model_dir" not in custom_args:
            # Fail with an actionable message instead of a bare KeyError.
            parser.error(
                f"--skip_training requires illustrator.custom_args.custom_model_dir in {args.config_path}"
            )
        illustrator_config["model_id"] = custom_args["custom_model_dir"]

    illustrator = Illustrator(**illustrator_config, device=args.device)

    characters = [
        CustomCharacter(
            orig_name=args.orig_name,
            orig_object=args.orig_object,
            custom_name=args.custom_name,
            custom_img_dir=args.custom_img_dir,
        )
    ]
    print("Custom characters: ", characters)

    if not args.skip_training:
        illustrator.customize(characters)

    illustrated_pages = illustrator.generate(pages, custom_characters=characters)
    write_story_pages(title, illustrated_pages, output_dir=output_dir)