
Update DirectML (#667)
* torch 2

* update patches
NullSenseStudio authored Jul 2, 2023
1 parent 52960e9 commit cca6960
Showing 9 changed files with 30 additions and 47 deletions.
generator_process/actions/control_net.py (4 changes: 2 additions & 2 deletions)

@@ -465,7 +465,7 @@ def __call__(
 batch_size = len(prompt) if isinstance(prompt, list) else 1
 generator = []
 for _ in range(batch_size):
-gen = torch.Generator(device="cpu" if device in ("mps", "privateuseone") else device) # MPS and DML do not support the `Generator` API
+gen = torch.Generator(device="cpu" if device in ("mps", "dml") else device) # MPS and DML do not support the `Generator` API
 generator.append(gen.manual_seed(random.randrange(0, np.iinfo(np.uint32).max) if seed is None else seed))
 if batch_size == 1:
 # Some schedulers don't handle a list of generators: https://github.com/huggingface/diffusers/issues/1909
@@ -510,7 +510,7 @@ def __call__(
 _configure_model_padding(pipe.vae, seamless_axes)

 # Inference
-with (torch.inference_mode() if device not in ('mps', "privateuseone") else nullcontext()), \
+with (torch.inference_mode() if device not in ('mps', "dml") else nullcontext()), \
 (torch.autocast(device) if optimizations.can_use("amp", device) else nullcontext()):
 yield from pipe(
 prompt=prompt,
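
Every pipeline touched by this commit converges on the same seeding pattern noted in the comment above: MPS and DirectML cannot host a torch.Generator, so seeds are drawn from a CPU generator instead. A minimal standalone sketch of that pattern (the helper name and signature are illustrative, not from the repository):

import random

import numpy as np
import torch

def make_generators(device: str, batch_size: int, seed: int | None = None) -> list[torch.Generator]:
    # "dml" is the DirectML backend name after torch.utils.rename_privateuse1_backend("dml");
    # MPS and DML do not support the Generator API, so the generator stays on the CPU.
    generators = []
    for _ in range(batch_size):
        gen = torch.Generator(device="cpu" if device in ("mps", "dml") else device)
        gen.manual_seed(random.randrange(0, np.iinfo(np.uint32).max) if seed is None else seed)
        generators.append(gen)
    return generators
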
generator_process/actions/depth_to_image.py (4 changes: 2 additions & 2 deletions)

@@ -340,7 +340,7 @@ def __call__(
 batch_size = len(prompt) if isinstance(prompt, list) else 1
 generator = []
 for _ in range(batch_size):
-gen = torch.Generator(device="cpu" if device in ("mps", "privateuseone") else device) # MPS and DML do not support the `Generator` API
+gen = torch.Generator(device="cpu" if device in ("mps", "dml") else device) # MPS and DML do not support the `Generator` API
 generator.append(gen.manual_seed(random.randrange(0, np.iinfo(np.uint32).max) if seed is None else seed))
 if batch_size == 1:
 # Some schedulers don't handle a list of generators: https://github.com/huggingface/diffusers/issues/1909
@@ -371,7 +371,7 @@ def __call__(
 _configure_model_padding(pipe.vae, seamless_axes)

 # Inference
-with torch.inference_mode() if device not in ('mps', "privateuseone") else nullcontext():
+with torch.inference_mode() if device not in ('mps', "dml") else nullcontext():
 yield from pipe(
 prompt=prompt,
 depth_image=depth_image,
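
The inference block above shows the other recurring pattern the rename touches: torch.inference_mode() is skipped on MPS and DirectML and replaced with a no-op context. A hedged sketch of that selection as a small helper (the function name is illustrative, not part of the commit):

from contextlib import nullcontext

import torch

def inference_context(device: str):
    # inference_mode() disables autograd and version tracking for extra speed,
    # but the pipelines fall back to nullcontext() on "mps" and "dml".
    return torch.inference_mode() if device not in ("mps", "dml") else nullcontext()

# usage:
# with inference_context(device):
#     ...run the pipeline...
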
generator_process/actions/image_to_image.py (8 changes: 3 additions & 5 deletions)

@@ -135,8 +135,6 @@ def __call__(
 # TODO: Add UI to enable this
 # 10. Run safety checker
 # image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
-
-image = self.image_processor.postprocess(image, output_type=output_type)

 # Offload last model to CPU
 if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
@@ -145,7 +143,7 @@ def __call__(
 # NOTE: Modified to yield the decoded image as a numpy array.
 yield ImageGenerationResult(
 [np.asarray(ImageOps.flip(image).convert('RGBA'), dtype=np.float32) / 255.
-for i, image in enumerate(image)],
+for i, image in enumerate(self.numpy_to_pil(image))],
 [gen.initial_seed() for gen in generator] if isinstance(generator, list) else [generator.initial_seed()],
 num_inference_steps,
 True
@@ -166,7 +164,7 @@ def __call__(
 batch_size = len(prompt) if isinstance(prompt, list) else 1
 generator = []
 for _ in range(batch_size):
-gen = torch.Generator(device="cpu" if device in ("mps", "privateuseone") else device) # MPS and DML do not support the `Generator` API
+gen = torch.Generator(device="cpu" if device in ("mps", "dml") else device) # MPS and DML do not support the `Generator` API
 generator.append(gen.manual_seed(random.randrange(0, np.iinfo(np.uint32).max) if seed is None else seed))
 if batch_size == 1:
 # Some schedulers don't handle a list of generators: https://github.com/huggingface/diffusers/issues/1909
@@ -190,7 +188,7 @@ def __call__(
 _configure_model_padding(pipe.vae, seamless_axes)

 # Inference
-with torch.inference_mode() if device not in ('mps', "privateuseone") else nullcontext():
+with torch.inference_mode() if device not in ('mps', "dml") else nullcontext():
 yield from pipe(
 prompt=prompt,
 image=[init_image] * batch_size,
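
In the second hunk the pipeline now converts its decoded numpy output with numpy_to_pil before flipping it and converting to RGBA floats. A rough standalone sketch of that conversion, assuming `images` is the decoded output as a float array in [0, 1] with shape (N, H, W, 3) (the helper name is illustrative, not repository code):

import numpy as np
from PIL import Image, ImageOps

def to_rgba_float(images: np.ndarray) -> list[np.ndarray]:
    results = []
    for arr in (images * 255).round().astype("uint8"):
        pil = Image.fromarray(arr)  # roughly what StableDiffusionPipeline.numpy_to_pil does per image
        # Flip vertically (Blender stores pixels bottom-up), add an alpha channel,
        # and scale back to float32 in [0, 1].
        results.append(np.asarray(ImageOps.flip(pil).convert("RGBA"), dtype=np.float32) / 255.0)
    return results
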
generator_process/actions/inpaint.py (4 changes: 2 additions & 2 deletions)

@@ -207,7 +207,7 @@ def __call__(
 batch_size = len(prompt) if isinstance(prompt, list) else 1
 generator = []
 for _ in range(batch_size):
-gen = torch.Generator(device="cpu" if device in ("mps", "privateuseone") else device) # MPS and DML do not support the `Generator` API
+gen = torch.Generator(device="cpu" if device in ("mps", "dml") else device) # MPS and DML do not support the `Generator` API
 generator.append(gen.manual_seed(random.randrange(0, np.iinfo(np.uint32).max) if seed is None else seed))
 if batch_size == 1:
 # Some schedulers don't handle a list of generators: https://github.com/huggingface/diffusers/issues/1909
@@ -223,7 +223,7 @@ def __call__(
 _configure_model_padding(pipe.vae, seamless_axes)

 # Inference
-with torch.inference_mode() if device not in ('mps', "privateuseone") else nullcontext():
+with torch.inference_mode() if device not in ('mps', "dml") else nullcontext():
 match inpaint_mask_src:
 case 'alpha':
 mask_image = ImageOps.invert(init_image.getchannel('A'))
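
The context lines at the end of this hunk show the 'alpha' mask source: the init image's alpha channel is inverted so transparent areas become the region to inpaint. A tiny hedged illustration with a placeholder image (not repository code):

from PIL import Image, ImageOps

# Placeholder RGBA image: fully transparent, i.e. everything should be inpainted.
init_image = Image.new("RGBA", (64, 64), (255, 255, 255, 0))
# Alpha 0 (transparent) inverts to 255 (white), which marks the pixels to inpaint.
mask_image = ImageOps.invert(init_image.getchannel("A"))
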
generator_process/actions/prompt_to_image.py (20 changes: 10 additions & 10 deletions)

@@ -151,10 +151,10 @@ class Optimizations:
 cudnn_benchmark: Annotated[bool, "cuda"] = False
 tf32: Annotated[bool, "cuda"] = False
 amp: Annotated[bool, "cuda"] = False
-half_precision: Annotated[bool, {"cuda", "privateuseone"}] = True
-cpu_offload: Annotated[str, {"cuda", "privateuseone"}] = "off"
+half_precision: Annotated[bool, {"cuda", "dml"}] = True
+cpu_offload: Annotated[str, {"cuda", "dml"}] = "off"
 channels_last_memory_format: bool = False
-sdp_attention: Annotated[bool, {"cpu", "cuda", "mps"}] = True
+sdp_attention: bool = True
 batch_size: int = 1
 vae_slicing: bool = True
 vae_tiling: str = "off"
@@ -169,7 +169,7 @@ def infer_device() -> str:
 if sys.platform == "darwin":
 return "mps"
 elif Pipeline.directml_available():
-return "privateuseone"
+return "dml"
 else:
 return "cuda"

@@ -277,7 +277,7 @@ def apply(self, pipeline, device):
 except: pass

 from .. import directml_patches
-if device == "privateuseone":
+if device == "dml":
 directml_patches.enable(pipeline)
 else:
 directml_patches.disable(pipeline)
@@ -380,8 +380,8 @@ def choose_device(self) -> str:
 if Pipeline.directml_available():
 import torch_directml
 if torch_directml.is_available():
-# can be named better when torch.utils.rename_privateuse1_backend() is released
-return "privateuseone"
+torch.utils.rename_privateuse1_backend("dml")
+return "dml"
 return "cpu"

 def approximate_decoded_latents(latents):
@@ -600,7 +600,7 @@ def __call__(
 batch_size = len(prompt) if isinstance(prompt, list) else 1
 generator = []
 for _ in range(batch_size):
-gen = torch.Generator(device="cpu" if device in ("mps", "privateuseone") else device) # MPS and DML do not support the `Generator` API
+gen = torch.Generator(device="cpu" if device in ("mps", "dml") else device) # MPS and DML do not support the `Generator` API
 generator.append(gen.manual_seed(random.randrange(0, np.iinfo(np.uint32).max) if seed is None else seed))
 if batch_size == 1:
 # Some schedulers don't handle a list of generators: https://github.com/huggingface/diffusers/issues/1909
@@ -611,7 +611,7 @@ def __call__(
 _configure_model_padding(pipe.vae, seamless_axes)

 # Inference
-with torch.inference_mode() if device not in ('mps', "privateuseone") else nullcontext():
+with torch.inference_mode() if device not in ('mps', "dml") else nullcontext():
 yield from pipe(
 prompt=prompt,
 height=height,
@@ -672,7 +672,7 @@ def _conv_forward_asymmetric(self, input, weight, bias):
 """
 Patch for Conv2d._conv_forward that supports asymmetric padding
 """
-if input.device.type == "privateuseone":
+if input.device.type == "dml":
 # DML pad() will wrongly fill the tensor in constant mode with the supplied value
 # (default 0) when padding on both ends of a dimension, can't split to two calls.
 working = nn.functional.pad(input, self._reversed_padding_repeated_twice, mode='circular')
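
The choose_device hunk is the core of the rename: with torch 2, torch.utils.rename_privateuse1_backend("dml") renames PyTorch's private-use backend so torch-directml tensors report device.type == "dml", and the plain string "dml" works everywhere the pipelines previously spelled out "privateuseone". A hedged, self-contained sketch of that device selection (the fallback order here is illustrative and not identical to the repository's):

import sys

import torch

def choose_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if sys.platform == "darwin" and torch.backends.mps.is_available():
        return "mps"
    try:
        import torch_directml
        if torch_directml.is_available():
            # Rename the "privateuseone" dispatch key so "dml" is a valid device string.
            torch.utils.rename_privateuse1_backend("dml")
            return "dml"
    except ImportError:
        pass
    return "cpu"
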
generator_process/actions/upscale.py (2 changes: 1 addition & 1 deletion)

@@ -55,7 +55,7 @@ def upscale(
 pipe = pipe.to(device)
 pipe = optimizations.apply(pipe, device)

-generator = torch.Generator(device="cpu" if device in ("mps", "privateuseone") else device) # MPS and DML do not support the `Generator` API
+generator = torch.Generator(device="cpu" if device in ("mps", "dml") else device) # MPS and DML do not support the `Generator` API
 if seed is None:
 seed = random.randrange(0, np.iinfo(np.uint32).max)

generator_process/directml_patches.py (31 changes: 8 additions & 23 deletions)

@@ -6,18 +6,8 @@
 active_dml_patches: list | None = None


-def tensor_ensure_device(self, other, *, pre_patch):
-"""Fix for operations where one tensor is DML and the other is CPU."""
-if isinstance(other, Tensor) and self.device != other.device:
-if self.device.type != "cpu":
-other = other.to(self.device)
-else:
-self = self.to(other.device)
-return pre_patch(self, other)
-
-
 def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None, pre_patch):
-if input.device.type == "privateuseone" and beta == 0:
+if input.device.type == "dml" and beta == 0:
 if out is not None:
 torch.bmm(batch1, batch2, out=out)
 out *= alpha
@@ -27,7 +17,7 @@ def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None, pre_patch):


 def pad(input, pad, mode="constant", value=None, *, pre_patch):
-if input.device.type == "privateuseone" and mode == "constant":
+if input.device.type == "dml" and mode == "constant":
 pad_dims = torch.tensor(pad, dtype=torch.int32).view(-1, 2).flip(0)
 both_ends = False
 for pre, post in pad_dims:
@@ -49,7 +39,7 @@ def pad(input, pad, mode="constant", value=None, *, pre_patch):


 def getitem(self, key, *, pre_patch):
-if isinstance(key, Tensor) and "privateuseone" in [self.device.type, key.device.type] and key.numel() == 1:
+if isinstance(key, Tensor) and "dml" in [self.device.type, key.device.type] and key.numel() == 1:
 return pre_patch(self, int(key))
 return pre_patch(self, key)

@@ -72,15 +62,8 @@ def dml_patch_method(object, name, patched):

 # Not all places where the patches have an effect are necessarily listed.

-# PNDMScheduler.step()
-dml_patch_method(Tensor, "__mul__", tensor_ensure_device)
-# PNDMScheduler.step()
-dml_patch_method(Tensor, "__sub__", tensor_ensure_device)
-# DDIMScheduler.step() last timestep in image_to_image
-dml_patch_method(Tensor, "__truediv__", tensor_ensure_device)
-
-# CrossAttention.get_attention_scores()
-# AttentionBlock.forward()
+# diffusers.models.attention_processor.Attention.get_attention_scores()
+# diffusers.models.attention.AttentionBlock.forward()
 # Diffusers implementation gives torch.empty() tensors with beta=0 to baddbmm(), which may contain NaNs.
 # DML implementation doesn't properly ignore input argument with beta=0 and causes NaN propagation.
 dml_patch(torch, "baddbmm", baddbmm)
@@ -105,7 +88,9 @@ def nan_check(key, x):
 nan_check(i, v)
 for k, v in kwargs.items():
 nan_check(k, v)
-return original(*args, **kwargs)
+r = original(*args, **kwargs)
+nan_check("return", r)
+return r
 module.forward = func.__get__(module)

 # only enable when testing
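
The surviving patches all follow the same shape: keep a reference to the original callable, special-case DML tensors, and delegate back otherwise. A hedged sketch of that pattern around the baddbmm workaround described in the comments above (helper names are illustrative, not the module's actual dml_patch/dml_patch_method):

import functools

import torch

def patch_function(namespace, name, patched):
    # Capture the original so the replacement can delegate to it via `pre_patch`.
    original = getattr(namespace, name)
    setattr(namespace, name, functools.partial(patched, pre_patch=original))

def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None, pre_patch):
    # With beta == 0 the `input` buffer should be ignored, but the DML kernel reads it,
    # so NaNs from torch.empty() would propagate; compute alpha * (batch1 @ batch2) instead.
    if input.device.type == "dml" and beta == 0:
        if out is not None:
            torch.bmm(batch1, batch2, out=out)
            out *= alpha
            return out
        return torch.bmm(batch1, batch2) * alpha
    return pre_patch(input, batch1, batch2, beta=beta, alpha=alpha, out=out)

# usage (illustrative): patch_function(torch, "baddbmm", baddbmm)
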
generator_process/models/upscale_tiler.py (2 changes: 1 addition & 1 deletion)

@@ -236,7 +236,7 @@ def _conv_forward_asymmetric(self, input, weight, bias):
 """
 Patch for Conv2d._conv_forward that supports asymmetric padding
 """
-if input.device.type == "privateuseone":
+if input.device.type == "dml":
 # DML pad() will wrongly fill the tensor in constant mode with the supplied value
 # (default 0) when padding on both ends of a dimension, can't split to two calls.
 working = nn.functional.pad(input, self._reversed_padding_repeated_twice, mode='circular')
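
For context on the line above: _reversed_padding_repeated_twice is Conv2d's per-side padding in F.pad order, and circular mode wraps the tensor's edges, which is what makes seamless tiling possible. A small standalone illustration (not repository code):

import torch
import torch.nn.functional as F

x = torch.arange(9, dtype=torch.float32).reshape(1, 1, 3, 3)
# Pad one pixel on each side of the last two dimensions, wrapping around the edges.
padded = F.pad(x, (1, 1, 1, 1), mode="circular")
print(padded.shape)  # torch.Size([1, 1, 5, 5])
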
requirements/win-dml.txt (2 changes: 1 addition & 1 deletion)

@@ -3,8 +3,8 @@ transformers
 accelerate
 huggingface_hub

-torch>=1.13
 torch-directml
+torch>=2.0

 # Original SD checkpoint conversion
 pytorch-lightning
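
With the requirement bumped to torch 2, a quick way to confirm the DirectML backend is usable (a hedged check, assuming torch-directml was installed from this requirements file):

import torch
import torch_directml

torch.utils.rename_privateuse1_backend("dml")
device = torch_directml.device()  # default DirectML adapter
x = torch.ones(2, 2, device=device) + 1
# After the rename, x.device.type should report "dml".
print(x.device, x.cpu())
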
