Skip to content

Commit

Permalink
Fix numerics, add some features to VAE runner, add cpu scheduling opt…
Browse files Browse the repository at this point in the history
…ions
  • Loading branch information
monorimet committed Jun 19, 2024
1 parent fd2a2ba commit b1f20f1
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,12 @@ def is_valid_file(arg):
action="store_true",
help="Just compile attention reproducer for mmdit.",
)
p.add_argument(
"--vae_input_path",
type=str,
default=None,
help="Path to input latents for VAE inference numerics validation.",
)


##############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def export_mmdit_model(
torch.empty(hidden_states_shape, dtype=dtype),
torch.empty(encoder_hidden_states_shape, dtype=dtype),
torch.empty(pooled_projections_shape, dtype=dtype),
torch.empty(1, dtype=dtype),
torch.empty(init_batch_dim, dtype=dtype),
]

decomp_list = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
(batch_size, args.max_length * 2, 4096), dtype=dtype
)
pooled_projections = torch.randn((batch_size, 2048), dtype=dtype)
timestep = torch.tensor([0], dtype=dtype)
timestep = torch.tensor([0, 0], dtype=dtype)

turbine_output = run_mmdit_turbine(
hidden_states,
Expand All @@ -180,6 +180,7 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
timestep,
args,
)
np.save("torch_mmdit_output.npy", torch_output.astype(np.float16))
print("torch OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)

print("\n(torch (comfy) image latents to iree image latents): ")
Expand Down
109 changes: 79 additions & 30 deletions models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from turbine_models.custom_models.sd_inference import utils
from turbine_models.model_runner import vmfbRunner
from transformers import CLIPTokenizer
from diffusers import FlowMatchEulerDiscreteScheduler

from PIL import Image
import os
Expand Down Expand Up @@ -426,10 +427,16 @@ def load_pipeline(
unet_loaded = time.time()
print("\n[LOG] MMDiT loaded in ", unet_loaded - load_start, "sec")

runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
self.devices["mmdit"]["driver"],
vmfbs["scheduler"],
)
if not self.cpu_scheduling:
runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
self.devices["mmdit"]["driver"],
vmfbs["scheduler"],
)
else:
print("Using torch CPU scheduler.")
runners["scheduler"] = FlowMatchEulerDiscreteScheduler.from_pretrained(
self.hf_model_name, subfolder="scheduler"
)

sched_loaded = time.time()
print("\n[LOG] Scheduler loaded in ", sched_loaded - unet_loaded, "sec")
Expand Down Expand Up @@ -502,11 +509,12 @@ def generate_images(
)
)

guidance_scale = ireert.asdevicearray(
self.runners["pipe"].config.device,
np.asarray([guidance_scale]),
dtype=iree_dtype,
)
if not self.cpu_scheduling:
guidance_scale = ireert.asdevicearray(
self.runners["pipe"].config.device,
np.asarray([guidance_scale]),
dtype=iree_dtype,
)

tokenize_start = time.time()
text_input_ids_dict = self.tokenizer.tokenize_with_weights(prompt)
Expand Down Expand Up @@ -540,12 +548,23 @@ def generate_images(
"clip"
].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs)
encode_prompts_end = time.time()
if self.cpu_scheduling:
timesteps, num_inference_steps = sd3_schedulers.retrieve_timesteps(
self.runners["scheduler"],
num_inference_steps=self.num_inference_steps,
timesteps=None,
)
steps = num_inference_steps


for i in range(batch_count):
unet_start = time.time()
sample, steps, timesteps = self.runners["scheduler"].initialize(samples[i])
if not self.cpu_scheduling:
latents, steps, timesteps = self.runners["scheduler"].initialize(samples[i])
else:
latents = torch.tensor(samples[i].to_host(), dtype=self.torch_dtype)
iree_inputs = [
sample,
latents,
ireert.asdevicearray(
self.runners["pipe"].config.device, prompt_embeds, dtype=iree_dtype
),
Expand All @@ -560,41 +579,71 @@ def generate_images(
# print(f"step {s}")
if self.cpu_scheduling:
step_index = s
t = timesteps[s]
if self.do_classifier_free_guidance:
latent_model_input = torch.cat([latents] * 2)
timestep = ireert.asdevicearray(
self.runners["pipe"].config.device,
t.expand(latent_model_input.shape[0]),
dtype=iree_dtype,
)
latent_model_input = ireert.asdevicearray(
self.runners["pipe"].config.device,
latent_model_input,
dtype=iree_dtype,
)
else:
step_index = ireert.asdevicearray(
self.runners["scheduler"].runner.config.device,
torch.tensor([s]),
"int64",
)
latents, t = self.runners["scheduler"].prep(
sample,
step_index,
timesteps,
)
latent_model_input, timestep = self.runners["scheduler"].prep(
latents,
step_index,
timesteps,
)
t = ireert.asdevicearray(
self.runners["scheduler"].runner.config.device,
timestep.to_host()[0]
)
noise_pred = self.runners["pipe"].ctx.modules.compiled_mmdit[
"run_forward"
](
latents,
latent_model_input,
iree_inputs[1],
iree_inputs[2],
t,
)
sample = self.runners["scheduler"].step(
noise_pred,
t,
sample,
guidance_scale,
step_index,
timestep,
)
if isinstance(sample, torch.Tensor):
if not self.cpu_scheduling:
latents = self.runners["scheduler"].step(
noise_pred,
t,
latents,
guidance_scale,
step_index,
)
else:
noise_pred = torch.tensor(noise_pred.to_host(), dtype=self.torch_dtype)
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.runners["scheduler"].step(
noise_pred,
t,
latents,
return_dict=False,
)[0]

if isinstance(latents, torch.Tensor):
latents = latents.type(self.vae_dtype)
latents = ireert.asdevicearray(
self.runners["vae"].config.device,
sample,
dtype=self.vae_dtype,
latents,
)
else:
vae_numpy_dtype = np.float32 if self.vae_precision == "fp32" else np.float16
latents = sample.astype(vae_numpy_dtype)
latents = latents.astype(vae_numpy_dtype)

vae_start = time.time()
vae_out = self.runners["vae"].ctx.modules.compiled_vae["decode"](latents)
Expand Down Expand Up @@ -791,10 +840,10 @@ def run_diffusers_cpu(
cpu_scheduling=args.cpu_scheduling,
vae_precision=args.vae_precision,
)
vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights)
if args.cpu_scheduling:
vmfbs.pop("scheduler")
weights.pop("scheduler")
vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights)
if args.npu_delegate_path:
extra_device_args = {"npu_delegate_path": args.npu_delegate_path}
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import inspect
from typing import List

import torch
from typing import Any, Callable, Dict, List, Optional, Union
from shark_turbine.aot import *
import shark_turbine.ops.iree as ops
from iree.compiler.ir import Context
Expand Down Expand Up @@ -75,11 +77,12 @@ def initialize(self, sample):

def prepare_model_input(self, sample, t, timesteps):
t = timesteps[t]
t = t.expand(sample.shape[0])

if self.do_classifier_free_guidance:
latent_model_input = torch.cat([sample] * 2)
else:
latent_model_input = sample
t = t.expand(sample.shape[0])
return latent_model_input.type(self.dtype), t.type(self.dtype)

def step(self, noise_pred, t, sample, guidance_scale, i):
Expand Down Expand Up @@ -146,6 +149,42 @@ def step(self, noise_pred, t, latents, guidance_scale, i):
return_dict=False,
)[0]

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Only used for cpu scheduling.
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps

@torch.no_grad()
def export_scheduler_model(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
)

def decode(self, inp):
inp = (inp / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(inp, return_dict=False)[0]
image = image.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,17 @@ def imagearray_from_vae_out(image):
if __name__ == "__main__":
from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
import numpy as np
from PIL import Image

dtype = torch.float16 if args.precision == "fp16" else torch.float32
if args.vae_variant == "decode":
example_input = torch.rand(
args.batch_size, 16, args.height // 8, args.width // 8, dtype=dtype
)
if args.vae_input_path:
example_input = np.load(args.vae_input_path)
if example_input.shape[0] == 2:
example_input = np.split(example_input, 2)[0]
elif args.vae_variant == "encode":
example_input = torch.rand(
args.batch_size, 3, args.height, args.width, dtype=dtype
Expand All @@ -74,13 +79,16 @@ def imagearray_from_vae_out(image):
from turbine_models.custom_models.sd_inference import utils

torch_output = run_torch_vae(
args.hf_model_name, args.vae_variant, example_input.float()
args.hf_model_name, args.vae_variant, torch.tensor(example_input).float()
)
print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
if args.vae_input_path:
out_image_torch = Image.fromarray(torch_output)
out_image_torch.save("vae_test_output_torch.png")
out_image_turbine = Image.fromarray(turbine_results)
out_image_turbine.save("vae_test_output_turbine.png")
# Allow a small amount of wiggle room for rounding errors (1)

np.testing.assert_allclose(
turbine_results, torch_output, rtol=1, atol=1
)

# TODO: Figure out why we occasionally segfault without unlinking output variables
turbine_results = None

0 comments on commit b1f20f1

Please sign in to comment.