Skip to content

Commit

Permalink
Merge branch 'main' into imporve-doc
Browse files Browse the repository at this point in the history
  • Loading branch information
doombeaker authored Apr 8, 2024
2 parents c26420d + 954f759 commit 950eb6b
Show file tree
Hide file tree
Showing 39 changed files with 1,447 additions and 669 deletions.
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ The Full Introduction of OneDiff:
- [Features](#features)
- [Acceleration for State-of-the-art models](#acceleration-for-state-of-the-art-models)
- [Acceleration for production environment](#acceleration-for-production-environment)
- [OneDiff Quality Evalution](#onediff-quality-evalution)
- [OneDiff Enterprise Edition](#onediff-enterprise-edition)
- [OneDiff Quality Benchmark](#onefiff-quality-benchmark)
- [Installation](#installation)
- [Release](#release)
<!-- tocstop -->
Expand Down Expand Up @@ -120,7 +120,12 @@ Compile and save the compiled result offline, then load it online for serving
- [Save and Load the compiled graph](https://github.com/siliconflow/onediff/blob/main/onediff_diffusers_extensions/examples/text_to_image_sdxl_save_load.py)
- [Change device of the compiled graph to do multi-process serving](https://github.com/siliconflow/onediff/blob/main/onediff_diffusers_extensions/examples/text_to_image_sdxl_mp_load.py)
- Compile at one device(such as device 0), then use the compiled result to other device(such as device 1~7).
- This is for special scene and is in the Enterprise Edition.
- This is for special scenes and is in the Enterprise Edition.

### OneDiff Quality Evalution

We also maintain a repository for benchmarking the quality of generation after acceleration using OneDiff:
[OneDiffGenMetrics](https://github.com/siliconflow/OneDiffGenMetrics)

### OneDiff Enterprise Edition
If you need **Enterprise-level Support** for your system or business, you can
Expand All @@ -137,11 +142,6 @@ OneDiff Enterprise Edition can be **subscripted for one month and one GPU** and
| Technical Support for deployment | High priority support | Community |
| Get the experimental features | Yes | |

### OneDiff Quality Benchmark

We also maintain a repository for benchmarking the quality of generation after compilation acceleration using OneDiff:
[OneDiffGenMetrics](https://github.com/siliconflow/OneDiffGenMetrics)

## Installation
### OS and GPU support
- Linux
Expand Down
9 changes: 7 additions & 2 deletions onediff_comfy_nodes/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,24 @@
COMFYUI_SPEEDUP_ROOT = Path(os.path.abspath(__file__)).parents[0]
INFER_COMPILER_REGISTRY = Path(COMFYUI_SPEEDUP_ROOT) / "infer_compiler_registry"
os.environ["COMFYUI_ROOT"] = str(COMFYUI_ROOT)
custom_nodes_path = os.path.join(COMFYUI_ROOT, "custom_nodes")
sys.path.insert(0, str(COMFYUI_ROOT))
sys.path.insert(0, str(COMFYUI_SPEEDUP_ROOT))
sys.path.insert(0, str(INFER_COMPILER_REGISTRY))
if custom_nodes_path not in sys.path:
sys.path.append(custom_nodes_path)

import register_comfy # load plugins

from onediff.infer_compiler.utils import is_community_version

if is_community_version():
_USE_UNET_INT8 = False

if _USE_UNET_INT8:
import register_onediff_quant # load plugins

from folder_paths import folder_names_and_paths, supported_pt_extensions, models_dir
from folder_paths import (folder_names_and_paths, models_dir,
supported_pt_extensions)

unet_int8_model_dir = Path(models_dir) / "unet_int8"
unet_int8_model_dir.mkdir(parents=True, exist_ok=True)
Expand Down
29 changes: 7 additions & 22 deletions onediff_comfy_nodes/_nodes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from functools import partial
from onediff.infer_compiler.transform import torch2oflow
from onediff.infer_compiler.with_oneflow_compile import oneflow_compile
from ._config import _USE_UNET_INT8, ONEDIFF_QUANTIZED_OPTIMIZED_MODELS
from onediff.infer_compiler.utils import set_boolean_env_var
from onediff.optimization.quant_optimizer import quantize_model
from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler.with_oneflow_compile import DeployableModule
from onediff.infer_compiler.deployable_module import DeployableModule

import os
import re
Expand Down Expand Up @@ -39,8 +38,11 @@
nodes_hijacker.hijack()
from .modules.hijack_samplers import samplers_hijack
from .modules.hijack_animatediff import animatediff_hijacker
from .modules.hijack_ipadapter_plus import ipadapter_plus_hijacker

samplers_hijack.hijack()
animatediff_hijacker.hijack()
ipadapter_plus_hijacker.hijack()


__all__ = [
Expand Down Expand Up @@ -71,8 +73,6 @@ def INPUT_TYPES(s):
CATEGORY = "OneDiff"

def speedup(self, model, static_mode):
from onediff.infer_compiler import oneflow_compile

use_graph = static_mode == "enable"

offload_device = model_management.unet_offload_device()
Expand Down Expand Up @@ -141,8 +141,6 @@ def INPUT_TYPES(s):
CATEGORY = "OneDiff"

def load_graph(self, model, graph):
from onediff.infer_compiler.with_oneflow_compile import DeployableModule

diffusion_model = model.model.diffusion_model

load_graph(diffusion_model, graph, "cuda", subfolder="unet")
Expand Down Expand Up @@ -183,8 +181,6 @@ def INPUT_TYPES(s):
CATEGORY = "OneDiff"

def speedup(self, model, static_mode):
from onediff.infer_compiler import oneflow_compile

# To avoid overflow issues while maintaining performance,
# refer to: https://github.com/siliconflow/onediff/blob/09a94df1c1a9c93ec8681e79d24bcb39ff6f227b/examples/image_to_video.py#L112
set_boolean_env_var(
Expand Down Expand Up @@ -213,8 +209,6 @@ def INPUT_TYPES(s):
CATEGORY = "OneDiff"

def speedup(self, vae, static_mode):
from onediff.infer_compiler import oneflow_compile

use_graph = static_mode == "enable"

new_vae = copy.deepcopy(
Expand Down Expand Up @@ -246,8 +240,6 @@ def INPUT_TYPES(s):
CATEGORY = "OneDiff"

def load_graph(self, vae, graph):
from onediff.infer_compiler.with_oneflow_compile import DeployableModule

vae_model = vae.first_stage_model
device = model_management.vae_offload_device()
load_graph(vae_model, graph, device, subfolder="vae")
Expand All @@ -271,8 +263,6 @@ def INPUT_TYPES(s):
OUTPUT_NODE = True

def save_graph(self, images, vae, filename_prefix):
from onediff.infer_compiler.with_oneflow_compile import DeployableModule

vae_model = vae.first_stage_model
vae_device = model_management.vae_offload_device()
save_graph(vae_model, filename_prefix, vae_device, subfolder="vae")
Expand Down Expand Up @@ -718,10 +708,7 @@ class BatchSizePatcher:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"latent_image": ("LATENT", ),
},
"required": {"model": ("MODEL",), "latent_image": ("LATENT",),},
}

RETURN_TYPES = ("MODEL",)
Expand All @@ -736,7 +723,7 @@ def set_cache_filename(self, model, latent_image):
file_dir = os.path.dirname(file_path)
file_name = os.path.basename(file_path)
names = file_name.split("_")
key , is_replace = "bs=", False
key, is_replace = "bs=", False
for i, name in enumerate(names):
if key in name:
names[i] = f"{key}{batch_size}"
Expand All @@ -746,15 +733,13 @@ def set_cache_filename(self, model, latent_image):

new_file_name = "_".join(names)
new_file_path = os.path.join(file_dir, new_file_name)

diff_model.set_graph_file(new_file_path)
else:
print(f"Warning: model is not a {DeployableModule}")
return (model,)




if _USE_UNET_INT8:
from .utils.quant_ksampler_tools import (
KSampleQuantumBase,
Expand Down
12 changes: 7 additions & 5 deletions onediff_comfy_nodes/modules/hijack_animatediff/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@
from ..sd_hijack_utils import Hijacker
from onediff.infer_compiler.transform import transform_mgr
from onediff.infer_compiler.import_tools import DynamicModuleLoader
from onediff.infer_compiler.utils.log_utils import logger
COMFYUI_ROOT = os.getenv("COMFYUI_ROOT")
pkg_name = "ComfyUI-AnimateDiff-Evolved"
animatediff_root = os.path.join(COMFYUI_ROOT, "custom_nodes", pkg_name)
load_animatediff_package = True
try:
animatediff_pt = DynamicModuleLoader.from_path(animatediff_root)
animatediff_of = transform_mgr.transform_package(pkg_name)
comfy_of = transform_mgr.transform_package("comfy")
if os.path.exists(animatediff_root):
animatediff_pt = DynamicModuleLoader.from_path(animatediff_root)
animatediff_of = transform_mgr.transform_package(pkg_name)
comfy_of = transform_mgr.transform_package("comfy")
else:
load_animatediff_package = False
except Exception as e:
logger.warning(f"Failed to load {pkg_name} from {animatediff_root} due to {e}")
print(f"Warning: Failed to load {pkg_name} from {animatediff_root} due to {e}")
load_animatediff_package = False

animatediff_hijacker = Hijacker()
53 changes: 34 additions & 19 deletions onediff_comfy_nodes/modules/hijack_animatediff/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,35 @@
from einops import rearrange
from oneflow.nn.functional import group_norm
import oneflow as flow
from onediff.infer_compiler.with_oneflow_compile import DeployableModule
from onediff.infer_compiler.deployable_module import DeployableModule
from onediff.infer_compiler.transform import register
from ._config import animatediff_pt, animatediff_hijacker, animatediff_of, comfy_of

FunctionInjectionHolder = animatediff_pt.animatediff.sampling.FunctionInjectionHolder


def cast_bias_weight(s, input):
bias = None
# non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
non_blocking = False
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
bias = s.bias.to(
device=input.device, dtype=input.dtype, non_blocking=non_blocking
)
weight = s.weight.to(
device=input.device, dtype=input.dtype, non_blocking=non_blocking
)
return weight, bias



def groupnorm_mm_factory(params, manual_cast=False):
def groupnorm_mm_forward(self, input):
# axes_factor normalizes batch based on total conds and unconds passed in batch;
# the conds and unconds per batch can change based on VRAM optimizations that may kick in
if not params.is_using_sliding_context():
batched_conds = input.size(0)//params.full_length
batched_conds = input.size(0) // params.full_length
else:
batched_conds = input.size(0)//params.context_options.context_length
batched_conds = input.size(0) // params.context_options.context_length

# input = rearrange(input, "(b f) c h w -> b c f h w", b=batched_conds)
input = input.unflatten(0, (batched_conds, -1)).permute(0, 2, 1, 3, 4)
Expand All @@ -40,8 +44,8 @@ def groupnorm_mm_forward(self, input):
# input = rearrange(input, "b c f h w -> (b f) c h w", b=batched_conds)
input = input.permute(0, 2, 1, 3, 4).flatten(0, 1)
return input
return groupnorm_mm_forward

return groupnorm_mm_forward


# ComfyUI/custom_nodes/ComfyUI-AnimateDiff-Evolved/animatediff/motion_module_ad.py
Expand All @@ -50,9 +54,11 @@ def groupnorm_mm_forward(self, input):
# ComfyUI/custom_nodes/ComfyUI-AnimateDiff-Evolved/animatediff/utils_model.py
ModelTypeSD = animatediff_pt.animatediff.utils_model.ModelTypeSD


class Handles:
def __init__(self):
self.handles = []

def add(self, obj, key, value):
org_attr = getattr(obj, key, None)
setattr(obj, key, value)
Expand All @@ -63,22 +69,35 @@ def restore(self):
handle()
self.handles = []


handles = Handles()


def inject_functions(orig_func, self, model, params):

ret = orig_func(self, model, params)

if model.motion_models is not None:
# only apply groupnorm hack if not [v3 or ([not Hotshot] and SD1.5 and v2 and apply_v2_properly)]
info = model.motion_models[0].model.mm_info
if not (info.mm_version == AnimateDiffVersion.V3 or
(info.mm_format not in [AnimateDiffFormat.HOTSHOTXL] and info.sd_type == ModelTypeSD.SD1_5 and info.mm_version == AnimateDiffVersion.V2 and params.apply_v2_properly)):

if not (
info.mm_version == AnimateDiffVersion.V3
or (
info.mm_format not in [AnimateDiffFormat.HOTSHOTXL]
and info.sd_type == ModelTypeSD.SD1_5
and info.mm_version == AnimateDiffVersion.V2
and params.apply_v2_properly
)
):

handles.add(flow.nn.GroupNorm, "forward", groupnorm_mm_factory(params))
# comfy_of.ops.manual_cast.GroupNorm.forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True)
handles.add(comfy_of.ops.manual_cast.GroupNorm, "forward_comfy_cast_weights", groupnorm_mm_factory(params, manual_cast=True))

handles.add(
comfy_of.ops.manual_cast.GroupNorm,
"forward_comfy_cast_weights",
groupnorm_mm_factory(params, manual_cast=True),
)

del info
return ret

Expand All @@ -98,13 +117,9 @@ def cond_func(orig_func, self, model, *args, **kwargs):


animatediff_hijacker.register(
FunctionInjectionHolder.inject_functions,
inject_functions,
cond_func,
FunctionInjectionHolder.inject_functions, inject_functions, cond_func,
)

animatediff_hijacker.register(
FunctionInjectionHolder.restore_functions,
restore_functions,
cond_func,
FunctionInjectionHolder.restore_functions, restore_functions, cond_func,
)
Loading

0 comments on commit 950eb6b

Please sign in to comment.