From ad43ee9be844442e05548cd448d7f7b18c41a705 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 17 Jun 2024 15:01:29 +0100 Subject: [PATCH] Fix layerdiffuse for diffusers 0.29.0 --- lib_layerdiffusion/models.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/lib_layerdiffusion/models.py b/lib_layerdiffusion/models.py index 8496be8..c9e6d3c 100644 --- a/lib_layerdiffusion/models.py +++ b/lib_layerdiffusion/models.py @@ -7,20 +7,23 @@ from typing import Optional, Tuple from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block +import importlib.metadata +from packaging.version import parse +diffusers_version = importlib.metadata.version('diffusers') -def check_diffusers_version(): - import diffusers - from packaging.version import parse - - assert parse(diffusers.__version__) >= parse( - "0.25.0" - ), "diffusers>=0.25.0 requirement not satisfied. Please install correct diffusers version." - +def check_diffusers_version(min_version="0.25.0"): + assert parse(diffusers_version) >= parse( + min_version + ), f"diffusers>={min_version} requirement not satisfied. Please install correct diffusers version." check_diffusers_version() +if parse(diffusers_version) >= parse("0.29.0"): + from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block +else: + from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block + def zero_module(module): """