# generate_img.py
import numpy as np
import random
from functools import partial
from PIL import Image
from tqdm.notebook import trange
from IPython.display import display  # display() is used below to show ranked images
import jax
import jax.numpy as jnp
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel
from flax.jax_utils import replicate
from flax.training.common_utils import shard, shard_prng_key
# Model references
# dalle-mega (default; can be a wandb artifact, 🤗 Hub repo, local folder, or google bucket)
DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"
DALLE_COMMIT_ID = None
# Smaller alternative if dalle-mega runs out of memory:
# DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
# VQGAN model
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
# CLIP model
CLIP_REPO = "openai/clip-vit-large-patch14"
CLIP_COMMIT_ID = None
# Check how many devices are available
print(f"Available devices: {jax.local_device_count()}")
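# Note (added for clarity): the pmapped functions below run once per device,
# so each generation step yields jax.local_device_count() images.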
# Load the DALL·E model (whichever DALLE_MODEL points to)
model, params = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
# Load VQGAN
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
# Load CLIP
clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)
clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)
# Model parameters are replicated on each device for faster inference.
params = replicate(params)
vqgan_params = replicate(vqgan.params)
clip_params = replicate(clip.params)
# Model functions are compiled and parallelized to take advantage of multiple devices.
# model inference
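# Note (added for clarity): jax.pmap compiles this function and runs it in
# parallel across devices; argnums 3-6 (top_k, top_p, temperature,
# condition_scale) are marked static, i.e. treated as compile-time constants
# rather than traced arrays.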
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )
# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)
# score images
@partial(jax.pmap, axis_name="batch")
def p_clip(inputs, params):
    logits = clip(params=params, **inputs).logits_per_image
    return logits
# A separate PRNG key is passed to each device so every device generates different samples.
# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)
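# Illustrative note (not part of the original flow): shard_prng_key splits a
# single key into one key per device, e.g.
#   shard_prng_key(key).shape == (jax.local_device_count(), 2)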
# Text prompt
# Prompts must be processed (tokenized) before being passed to the model.
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
# Define prompt
prompt = "sunset over a lake in the mountains"
tokenized_prompt = processor([prompt])
# Replicate on each device
tokenized_prompt = replicate(tokenized_prompt)
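# Note (added for clarity): replicate adds a leading device axis so each
# device receives its own copy of the tokenized prompt under pmap.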
# Generate images
# Images are generated with the DALL·E model and decoded with the VQGAN.
# number of predictions
n_predictions = 16
# Sampling can be customized via top_k/top_p/temperature (None keeps the defaults)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 3.0
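# Note (added for clarity): cond_scale > 1 applies dalle-mini's
# "super conditioning", reweighting generation toward the prompt; higher
# values follow the prompt more closely at some cost in diversity.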
# generate images
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key for each step
    key, subkey = jax.random.split(key)
    # generate images (one per device)
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove the BOS token
    encoded_images = encoded_images.sequences[..., 1:]
    # decode the image tokens with the VQGAN
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for img in decoded_images:
        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))
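# Optional extension (hypothetical filenames, not in the original script):
# persist the raw generations before ranking.
# for n, generated in enumerate(images):
#     generated.save(f"generation_{n:03d}.png")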
# Score images against the prompt with CLIP.
clip_inputs = clip_processor(
    text=[prompt] * jax.device_count(),
    images=images,
    return_tensors="np",
    padding="max_length",
    max_length=77,
    truncation=True,
).data
logits = p_clip(shard(clip_inputs), clip_params)
logits = logits.squeeze().flatten()
# Let's display images ranked by CLIP score.
print(f"Prompt: {prompt}\n")
for idx in logits.argsort()[::-1]:
    display(images[idx])
    print(f"Score: {logits[idx]:.2f}\n")