From 52f0fec9a2d052aabe6512f45936db7bc14dbc27 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Tue, 13 Feb 2024 11:44:56 +0100 Subject: [PATCH] histo-auto-encode progress --- .../train-histogram-autoencoder-16-64.toml | 61 +++++++++ .../train-histogram-autoencoder-16x7-64.toml | 61 +++++++++ .../train-histogram-autoencoder-32-64.toml | 61 +++++++++ .../train-histogram-autoencoder-4-16.toml | 61 +++++++++ .../train-histogram-autoencoder-4-32.toml | 61 +++++++++ .../train-histogram-autoencoder-4-64.toml | 61 +++++++++ .../train-histogram-autoencoder-4-8.toml | 61 +++++++++ .../train-histogram-autoencoder-8-64.toml | 61 +++++++++ ... train-histogram-autoencoder-8x6-8x4.toml} | 10 +- .../finetune-histogram-xs-2-layer.toml | 24 ++-- scripts/training/scheduler-local.bash | 4 +- .../fluxion/adapters/color_palette.py | 19 +++ src/refiners/fluxion/adapters/histogram.py | 43 ++++++- .../adapters/histogram_auto_encoder.py | 13 +- src/refiners/fluxion/layers/norm.py | 2 +- src/refiners/fluxion/utils.py | 3 + .../trainers/abstract_color_trainer.py | 29 ++++- .../training_utils/trainers/color_palette.py | 20 +-- .../training_utils/trainers/histogram.py | 25 +++- .../trainers/histogram_auto_encoder.py | 116 +++++++++++++++--- .../training_utils/trainers/trainer.py | 2 + tests/adapters/test_histogram.py | 29 ++++- 22 files changed, 753 insertions(+), 74 deletions(-) create mode 100644 configs/histogram-auto-encoder/train-histogram-autoencoder-16-64.toml create mode 100644 configs/histogram-auto-encoder/train-histogram-autoencoder-16x7-64.toml create mode 100644 configs/histogram-auto-encoder/train-histogram-autoencoder-32-64.toml create mode 100644 configs/histogram-auto-encoder/train-histogram-autoencoder-4-16.toml create mode 100644 configs/histogram-auto-encoder/train-histogram-autoencoder-4-32.toml create mode 100644 configs/histogram-auto-encoder/train-histogram-autoencoder-4-64.toml create mode 100644 configs/histogram-auto-encoder/train-histogram-autoencoder-4-8.toml create mode 100644 configs/histogram-auto-encoder/train-histogram-autoencoder-8-64.toml rename configs/histogram-auto-encoder/{train-histogram-autoencoder.toml => train-histogram-autoencoder-8x6-8x4.toml} (84%) diff --git a/configs/histogram-auto-encoder/train-histogram-autoencoder-16-64.toml b/configs/histogram-auto-encoder/train-histogram-autoencoder-16-64.toml new file mode 100644 index 000000000..3efcdec64 --- /dev/null +++ b/configs/histogram-auto-encoder/train-histogram-autoencoder-16-64.toml @@ -0,0 +1,61 @@ +script = "finetune-ldm-color-palette.py" # not used for now +[wandb] +mode = "online" # "online", "offline", "disabled" +entity = "piercus" +project = "histo-autoencoder" +name = "16-64" +tags = ["autoencoder"] + +[histogram_auto_encoder] +latent_dim = 64 +resnet_sizes = [16, 16, 16, 16, 16, 16, 16] +n_down_samples = 6 +color_bits = 6 + +[models] +histogram_auto_encoder = {train = true} + +[training] +duration = "1:epoch" +seed = 0 +gpu_index = 1 +batch_size = 8 +gradient_accumulation = "1:step" +clip_grad_norm = 1.0 +# clip_grad_value = 1.0 +evaluation_interval = "50:step" +evaluation_seed = 1 +num_workers = 8 + +[optimizer] +optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" +learning_rate = 5e-3 +betas = [0.9, 0.999] +eps = 1e-8 +weight_decay = 1e-2 + +[scheduler] +scheduler_type = "ConstantLR" +update_interval = "1:step" + +[dropout] +dropout_probability = 0.2 +use_gyro_dropout = false + +[dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[10:]" + +[eval_dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[:10]" + +[checkpointing] +save_interval = "100:step" +use_wandb = true diff --git a/configs/histogram-auto-encoder/train-histogram-autoencoder-16x7-64.toml b/configs/histogram-auto-encoder/train-histogram-autoencoder-16x7-64.toml new file mode 100644 index 000000000..a4922fb67 --- /dev/null +++ b/configs/histogram-auto-encoder/train-histogram-autoencoder-16x7-64.toml @@ -0,0 +1,61 @@ +script = "finetune-ldm-color-palette.py" # not used for now +[wandb] +mode = "online" # "online", "offline", "disabled" +entity = "piercus" +project = "histo-autoencoder" +name = "16x7-32" +tags = ["autoencoder"] + +[histogram_auto_encoder] +latent_dim = 32 +resnet_sizes = [4, 4, 8, 8, 16, 16, 16] +n_down_samples = 6 +color_bits = 6 + +[models] +histogram_auto_encoder = {train = true} + +[training] +duration = "10:epoch" +seed = 0 +gpu_index = 1 +batch_size = 8 +gradient_accumulation = "1:step" +clip_grad_norm = 1.0 +# clip_grad_value = 1.0 +evaluation_interval = "50:step" +evaluation_seed = 1 +num_workers = 8 + +[optimizer] +optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" +learning_rate = 5e-3 +betas = [0.9, 0.999] +eps = 1e-8 +weight_decay = 1e-2 + +[scheduler] +scheduler_type = "ConstantLR" +update_interval = "1:step" + +[dropout] +dropout_probability = 0.2 +use_gyro_dropout = false + +[dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[10:]" + +[eval_dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[:10]" + +[checkpointing] +save_interval = "100:step" +use_wandb = true diff --git a/configs/histogram-auto-encoder/train-histogram-autoencoder-32-64.toml b/configs/histogram-auto-encoder/train-histogram-autoencoder-32-64.toml new file mode 100644 index 000000000..1ce4223a2 --- /dev/null +++ b/configs/histogram-auto-encoder/train-histogram-autoencoder-32-64.toml @@ -0,0 +1,61 @@ +script = "finetune-ldm-color-palette.py" # not used for now +[wandb] +mode = "online" # "online", "offline", "disabled" +entity = "piercus" +project = "histo-autoencoder" +name = "32-64" +tags = ["autoencoder"] + +[histogram_auto_encoder] +latent_dim = 64 +resnet_sizes = [32, 32, 32, 32, 32, 32, 32] +n_down_samples = 6 +color_bits = 6 + +[models] +histogram_auto_encoder = {train = true} + +[training] +duration = "1:epoch" +seed = 0 +gpu_index = 1 +batch_size = 8 +gradient_accumulation = "1:step" +clip_grad_norm = 1.0 +# clip_grad_value = 1.0 +evaluation_interval = "50:step" +evaluation_seed = 1 +num_workers = 8 + +[optimizer] +optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" +learning_rate = 5e-3 +betas = [0.9, 0.999] +eps = 1e-8 +weight_decay = 1e-2 + +[scheduler] +scheduler_type = "ConstantLR" +update_interval = "1:step" + +[dropout] +dropout_probability = 0.2 +use_gyro_dropout = false + +[dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[10:]" + +[eval_dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[:10]" + +[checkpointing] +save_interval = "100:step" +use_wandb = true diff --git a/configs/histogram-auto-encoder/train-histogram-autoencoder-4-16.toml b/configs/histogram-auto-encoder/train-histogram-autoencoder-4-16.toml new file mode 100644 index 000000000..4d7eefd78 --- /dev/null +++ b/configs/histogram-auto-encoder/train-histogram-autoencoder-4-16.toml @@ -0,0 +1,61 @@ +script = "finetune-ldm-color-palette.py" # not used for now +[wandb] +mode = "online" # "online", "offline", "disabled" +entity = "piercus" +project = "histo-autoencoder" +name = "4-16" +tags = ["autoencoder"] + +[histogram_auto_encoder] +latent_dim = 64 +resnet_sizes = [4, 4, 4, 4, 4, 4, 4] +n_down_samples = 6 +color_bits = 6 + +[models] +histogram_auto_encoder = {train = true} + +[training] +duration = "1:epoch" +seed = 0 +gpu_index = 1 +batch_size = 64 +gradient_accumulation = "1:step" +clip_grad_norm = 1.0 +# clip_grad_value = 1.0 +evaluation_interval = "50:step" +evaluation_seed = 1 +num_workers = 8 + +[optimizer] +optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" +learning_rate = 5e-3 +betas = [0.9, 0.999] +eps = 1e-8 +weight_decay = 1e-2 + +[scheduler] +scheduler_type = "ConstantLR" +update_interval = "1:step" + +[dropout] +dropout_probability = 0.2 +use_gyro_dropout = false + +[dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[10:]" + +[eval_dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[:10]" + +[checkpointing] +save_interval = "100:step" +use_wandb = true diff --git a/configs/histogram-auto-encoder/train-histogram-autoencoder-4-32.toml b/configs/histogram-auto-encoder/train-histogram-autoencoder-4-32.toml new file mode 100644 index 000000000..1ce568a77 --- /dev/null +++ b/configs/histogram-auto-encoder/train-histogram-autoencoder-4-32.toml @@ -0,0 +1,61 @@ +script = "finetune-ldm-color-palette.py" # not used for now +[wandb] +mode = "online" # "online", "offline", "disabled" +entity = "piercus" +project = "histo-autoencoder" +name = "4-32" +tags = ["autoencoder"] + +[histogram_auto_encoder] +latent_dim = 64 +resnet_sizes = [4, 4, 4, 4, 4, 4, 4] +n_down_samples = 6 +color_bits = 6 + +[models] +histogram_auto_encoder = {train = true} + +[training] +duration = "1:epoch" +seed = 0 +gpu_index = 1 +batch_size = 64 +gradient_accumulation = "1:step" +clip_grad_norm = 1.0 +# clip_grad_value = 1.0 +evaluation_interval = "50:step" +evaluation_seed = 1 +num_workers = 8 + +[optimizer] +optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" +learning_rate = 5e-3 +betas = [0.9, 0.999] +eps = 1e-8 +weight_decay = 1e-2 + +[scheduler] +scheduler_type = "ConstantLR" +update_interval = "1:step" + +[dropout] +dropout_probability = 0.2 +use_gyro_dropout = false + +[dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[10:]" + +[eval_dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[:10]" + +[checkpointing] +save_interval = "100:step" +use_wandb = true diff --git a/configs/histogram-auto-encoder/train-histogram-autoencoder-4-64.toml b/configs/histogram-auto-encoder/train-histogram-autoencoder-4-64.toml new file mode 100644 index 000000000..2c663d9d7 --- /dev/null +++ b/configs/histogram-auto-encoder/train-histogram-autoencoder-4-64.toml @@ -0,0 +1,61 @@ +script = "finetune-ldm-color-palette.py" # not used for now +[wandb] +mode = "online" # "online", "offline", "disabled" +entity = "piercus" +project = "histo-autoencoder" +name = "4-64" +tags = ["autoencoder"] + +[histogram_auto_encoder] +latent_dim = 64 +resnet_sizes = [4, 4, 4, 4, 4, 4, 4] +n_down_samples = 6 +color_bits = 6 + +[models] +histogram_auto_encoder = {train = true} + +[training] +duration = "1:epoch" +seed = 0 +gpu_index = 1 +batch_size = 64 +gradient_accumulation = "1:step" +clip_grad_norm = 1.0 +# clip_grad_value = 1.0 +evaluation_interval = "50:step" +evaluation_seed = 1 +num_workers = 8 + +[optimizer] +optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" +learning_rate = 5e-3 +betas = [0.9, 0.999] +eps = 1e-8 +weight_decay = 1e-2 + +[scheduler] +scheduler_type = "ConstantLR" +update_interval = "1:step" + +[dropout] +dropout_probability = 0.2 +use_gyro_dropout = false + +[dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[10:]" + +[eval_dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[:10]" + +[checkpointing] +save_interval = "100:step" +use_wandb = true diff --git a/configs/histogram-auto-encoder/train-histogram-autoencoder-4-8.toml b/configs/histogram-auto-encoder/train-histogram-autoencoder-4-8.toml new file mode 100644 index 000000000..aa435ec7c --- /dev/null +++ b/configs/histogram-auto-encoder/train-histogram-autoencoder-4-8.toml @@ -0,0 +1,61 @@ +script = "finetune-ldm-color-palette.py" # not used for now +[wandb] +mode = "online" # "online", "offline", "disabled" +entity = "piercus" +project = "histo-autoencoder" +name = "4-8" +tags = ["autoencoder"] + +[histogram_auto_encoder] +latent_dim = 64 +resnet_sizes = [4, 4, 4, 4, 4, 4, 4] +n_down_samples = 6 +color_bits = 6 + +[models] +histogram_auto_encoder = {train = true} + +[training] +duration = "1:epoch" +seed = 0 +gpu_index = 1 +batch_size = 64 +gradient_accumulation = "1:step" +clip_grad_norm = 1.0 +# clip_grad_value = 1.0 +evaluation_interval = "50:step" +evaluation_seed = 1 +num_workers = 8 + +[optimizer] +optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" +learning_rate = 5e-3 +betas = [0.9, 0.999] +eps = 1e-8 +weight_decay = 1e-2 + +[scheduler] +scheduler_type = "ConstantLR" +update_interval = "1:step" + +[dropout] +dropout_probability = 0.2 +use_gyro_dropout = false + +[dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[10:]" + +[eval_dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[:10]" + +[checkpointing] +save_interval = "100:step" +use_wandb = true diff --git a/configs/histogram-auto-encoder/train-histogram-autoencoder-8-64.toml b/configs/histogram-auto-encoder/train-histogram-autoencoder-8-64.toml new file mode 100644 index 000000000..86c42433c --- /dev/null +++ b/configs/histogram-auto-encoder/train-histogram-autoencoder-8-64.toml @@ -0,0 +1,61 @@ +script = "finetune-ldm-color-palette.py" # not used for now +[wandb] +mode = "online" # "online", "offline", "disabled" +entity = "piercus" +project = "histo-autoencoder" +name = "8-64" +tags = ["autoencoder"] + +[histogram_auto_encoder] +latent_dim = 64 +resnet_sizes = [8, 8, 8, 8, 8, 8, 8] +n_down_samples = 6 +color_bits = 6 + +[models] +histogram_auto_encoder = {train = true} + +[training] +duration = "1:epoch" +seed = 0 +gpu_index = 1 +batch_size = 8 +gradient_accumulation = "1:step" +clip_grad_norm = 1.0 +# clip_grad_value = 1.0 +evaluation_interval = "50:step" +evaluation_seed = 1 +num_workers = 8 + +[optimizer] +optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" +learning_rate = 5e-3 +betas = [0.9, 0.999] +eps = 1e-8 +weight_decay = 1e-2 + +[scheduler] +scheduler_type = "ConstantLR" +update_interval = "1:step" + +[dropout] +dropout_probability = 0.2 +use_gyro_dropout = false + +[dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[10:]" + +[eval_dataset] +hf_repo = "1aurent/unsplash-lite-palette" +revision = "main" +resize_image_max_size = 512 +caption_key = "ai_description" +split = "train[:10]" + +[checkpointing] +save_interval = "100:step" +use_wandb = true diff --git a/configs/histogram-auto-encoder/train-histogram-autoencoder.toml b/configs/histogram-auto-encoder/train-histogram-autoencoder-8x6-8x4.toml similarity index 84% rename from configs/histogram-auto-encoder/train-histogram-autoencoder.toml rename to configs/histogram-auto-encoder/train-histogram-autoencoder-8x6-8x4.toml index b4cf3b54e..44204e8b9 100644 --- a/configs/histogram-auto-encoder/train-histogram-autoencoder.toml +++ b/configs/histogram-auto-encoder/train-histogram-autoencoder-8x6-8x4.toml @@ -3,23 +3,23 @@ script = "finetune-ldm-color-palette.py" # not used for now mode = "online" # "online", "offline", "disabled" entity = "piercus" project = "histo-autoencoder" -name = "load-ckpt-5e-3" +name = "8x6-8x4" tags = ["autoencoder"] [histogram_auto_encoder] latent_dim = 8 -resnet_sizes = [4, 4, 4, 4, 4, 4] +resnet_sizes = [8, 8, 8, 8, 8, 8] n_down_samples = 5 color_bits = 6 [models] -histogram_auto_encoder = {train = true, checkpoint= './tmp/histogram-auto-encoder-step100.safetensors'} +histogram_auto_encoder = {train = true} [training] -duration = "5:epoch" +duration = "1:epoch" seed = 0 gpu_index = 1 -batch_size = 64 +batch_size = 32 gradient_accumulation = "1:step" clip_grad_norm = 1.0 # clip_grad_value = 1.0 diff --git a/configs/scheduled-local-histogram/finetune-histogram-xs-2-layer.toml b/configs/scheduled-local-histogram/finetune-histogram-xs-2-layer.toml index fd7afe01c..e169ac199 100644 --- a/configs/scheduled-local-histogram/finetune-histogram-xs-2-layer.toml +++ b/configs/scheduled-local-histogram/finetune-histogram-xs-2-layer.toml @@ -1,9 +1,9 @@ script = "finetune-ldm-color-palette.py" # not used for now [wandb] -mode = "disabled" # "online", "offline", "disabled" +mode = "offline" # "online", "offline", "disabled" entity = "piercus" project = "histogram" -name="xs" +name="local-histo-palette-eval" tags = ["l4", "local"] [models] @@ -61,7 +61,7 @@ hf_repo = "1aurent/unsplash-lite-palette" revision = "main" resize_image_max_size = 512 caption_key = "ai_description" -split = "train[4:]" +split = "train[200:]" #random_crop = false [eval_dataset] @@ -69,7 +69,7 @@ hf_repo = "1aurent/unsplash-lite-palette" revision = "main" resize_image_max_size = 512 caption_key = "ai_description" -split = "train[0:4]" +split = "train[0:200]" [checkpointing] #save_folder = "ckpts" @@ -78,15 +78,11 @@ use_wandb = true [evaluation] num_inference_steps = 30 -db_indexes = [0]#, 1, 2, 3] -batch_size = 1 +db_indexes = [15] +batch_size = 4 prompts = [ - "A Bustling City Street" -#, -# "A cute cat", -# "An oil painting", -# "A photography of a beautiful woman", -# "A pair of shoes", -# "A group of working people" +# "A Bustling City Street", + "A cute cat", +# "A photography of a beautiful woman" ] -condition_scale = 7.5 +condition_scale = 7.5 \ No newline at end of file diff --git a/scripts/training/scheduler-local.bash b/scripts/training/scheduler-local.bash index 8c0ae51f9..05773bf52 100644 --- a/scripts/training/scheduler-local.bash +++ b/scripts/training/scheduler-local.bash @@ -1,9 +1,9 @@ #!/bin/bash # Path to the directory containing the config files -config_dir="./configs/scheduled-local" +config_dir="./configs/histogram-auto-encoder" prefix="" -script="./scripts/training/finetune-ldm-color-palette.py" +script="./scripts/training/train-histogram-autoencoder.py" # Log file path log_file="./tmp/schedule-log.txt" diff --git a/src/refiners/fluxion/adapters/color_palette.py b/src/refiners/fluxion/adapters/color_palette.py index 2659b8714..f5911253a 100644 --- a/src/refiners/fluxion/adapters/color_palette.py +++ b/src/refiners/fluxion/adapters/color_palette.py @@ -218,6 +218,9 @@ def __call__(self, image: Image.Image, size: int | None = None) -> ColorPalette: image_np = np.array(image) pixels = image_np.reshape(-1, 3) + return self.from_pixels(pixels, size) + def from_pixels(self, pixels: np.ndarray, size: int | None = None) -> ColorPalette: + print("pixels.shape", pixels.shape) kmeans = KMeans(n_clusters=size).fit(pixels) # type: ignore counts = np.unique(kmeans.labels_, return_counts=True)[1] # type: ignore palette : ColorPalette = [] @@ -233,6 +236,22 @@ def __call__(self, image: Image.Image, size: int | None = None) -> ColorPalette: palette.append(color_cluster) sorted_palette = sorted(palette, key=lambda x: x[1], reverse=True) return sorted_palette + + def from_histogram(self, histogram: Tensor, color_bits: int, size: int | None = None, num: int = 1) -> ColorPalette: + if histogram.dim() != 4: + raise Exception('histogram must be 4 dimensions') + cube_size = 2 ** color_bits + color_factor = 256 / cube_size + pixels : list[np.ndarray] = [] + for histo in histogram.split(1): + for r in range(cube_size): + for g in range(cube_size): + for b in range(cube_size): + for i in range(int(histo[0, r, g, b]* num)): + pixels.append(np.array([r*color_factor, g*color_factor, b*color_factor])) + + return self.from_pixels(np.array(pixels), size) + def distance(self, a: ColorPalette, b: ColorPalette) -> float: #TO DO raise NotImplementedError diff --git a/src/refiners/fluxion/adapters/histogram.py b/src/refiners/fluxion/adapters/histogram.py index 4339967bd..587b96318 100644 --- a/src/refiners/fluxion/adapters/histogram.py +++ b/src/refiners/fluxion/adapters/histogram.py @@ -1,9 +1,9 @@ from typing import Any, List, TypeVar from refiners.foundationals.dinov2.vit import FeedForward -from torch import Tensor, sort, flatten, cat, device as Device, dtype as DType, histogramdd, histogram, nn, stack, zeros_like, float32 +from torch import ge, isnan, min, max, log, Tensor, sort, flatten, cat, device as Device, dtype as DType, histogramdd, histogram, nn, stack, zeros_like, float32 from torch.nn import init, L1Loss -from torch.nn.functional import mse_loss as _mse_loss +from torch.nn.functional import mse_loss as _mse_loss, kl_div as _kl_div import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import Adapter @@ -113,9 +113,46 @@ def __init__( self.color_bits = color_bits super().__init__(fl.Lambda(func=self.kl_div)) - def kl_div(self, x: Tensor, y: Tensor) -> Tensor: + def mse(self, x: Tensor, y: Tensor) -> Tensor: return _mse_loss(x, y) + + def correlation(self, x: Tensor, y: Tensor) -> Tensor: + n = (2 ** self.color_bits) ** 3 + + centered_x = x - 1/n + centered_y = y - 1/n + + denom = ((centered_x*centered_x).sum() * (centered_y*centered_y).sum()).sqrt() + return (centered_x*centered_y).sum()/denom + + def chi_square(self, x: Tensor, y: Tensor) -> Tensor: + return (2*((x - y)**2)/(x + y)).sum()/x.shape[0] + def intersection(self, x: Tensor, y: Tensor) -> Tensor: + return min(stack([x,y]), dim=0)[0].sum()/x.shape[0] + + def hellinger(self, x: Tensor, y: Tensor) -> Tensor: + x = x.reshape(x.shape[0], -1) + y = y.reshape(x.shape[0], -1) + + base = x.sqrt() - y.sqrt() + dist = (base * base).sum(dim = 1).sqrt() + return dist.mean() + + def kl_div(self, actual: Tensor, expected: Tensor) -> Tensor: + # TODO: connect it to the pre-softmax logits + # to reduce log calculation needs + return _kl_div(actual.log(), expected) + + def metrics(self, x: Tensor, y: Tensor) -> dict[str, Tensor]: + return { + "mse": self.mse(x, y), + "correlation": self.correlation(x, y), + "chi_square": self.chi_square(x, y), + "intersection": self.intersection(x, y), + "hellinger": self.hellinger(x, y), + "kl_div": self.kl_div(x, y) + } class HistogramExtractor(fl.Chain): def __init__( diff --git a/src/refiners/fluxion/adapters/histogram_auto_encoder.py b/src/refiners/fluxion/adapters/histogram_auto_encoder.py index e84ea0e19..08967d4fa 100644 --- a/src/refiners/fluxion/adapters/histogram_auto_encoder.py +++ b/src/refiners/fluxion/adapters/histogram_auto_encoder.py @@ -1,7 +1,9 @@ from PIL import Image from refiners.fluxion.layers.converter import Converter +from refiners.fluxion.utils import summarize_tensor from torch import Tensor, device as Device, dtype as DType, zeros_like, cat -from refiners.fluxion.layers.basics import Unsqueeze, Squeeze +from torch.nn import Softmax +from refiners.fluxion.layers.basics import Reshape, Unsqueeze, Squeeze from refiners.foundationals.latent_diffusion.auto_encoder import Encoder, Decoder from refiners.fluxion.layers import Chain @@ -26,6 +28,7 @@ def __init__( self.n_down_samples = n_down_samples self.latent_dim = latent_dim self.color_bits = color_bits + cube_size = 2 ** color_bits super().__init__( Chain( @@ -49,13 +52,16 @@ def __init__( num_groups = num_groups, resnet_sizes = resnet_sizes, output_channels = histogram_dim, - n_up_samples=n_down_samples, + n_up_samples=n_down_samples, latent_dim = latent_dim, device=device, dtype=dtype ), Squeeze(dim=1) - ) + ), + Reshape(cube_size*cube_size*cube_size), + Softmax(dim=1), + Reshape(cube_size,cube_size,cube_size) ) def encode(self, x: Tensor) -> Tensor: @@ -89,7 +95,6 @@ def embedding_dim(self) -> int: embedding_dim = color_size**3 / self.compression_rate return int(embedding_dim) - def unconditionnal_embedding_like(self, x: Tensor) -> Tensor: numel: int = x.numel() if numel == 0: diff --git a/src/refiners/fluxion/layers/norm.py b/src/refiners/fluxion/layers/norm.py index bc7f0dd38..5ea3a5acb 100644 --- a/src/refiners/fluxion/layers/norm.py +++ b/src/refiners/fluxion/layers/norm.py @@ -5,11 +5,11 @@ InstanceNorm2d as _InstanceNorm2d, LayerNorm as _LayerNorm, Parameter as TorchParameter, + Softmax as _Softmax ) from refiners.fluxion.layers.module import Module, WeightedModule - class LayerNorm(_LayerNorm, WeightedModule): """Layer Normalization layer. diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 9bfaae51e..d81a7da2c 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -142,6 +142,9 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp """ if image.mode == "P": image = image.convert("RGB") + + if image.mode != "RGB": + raise Exception("not an RGB image") image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype) diff --git a/src/refiners/training_utils/trainers/abstract_color_trainer.py b/src/refiners/training_utils/trainers/abstract_color_trainer.py index 7a5076702..0665bd27a 100644 --- a/src/refiners/training_utils/trainers/abstract_color_trainer.py +++ b/src/refiners/training_utils/trainers/abstract_color_trainer.py @@ -6,13 +6,14 @@ from refiners.training_utils.metrics.color_palette import AbstractColorPrompt, AbstractColorResults from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from torch import Tensor, randn, tensor - +from refiners.fluxion.adapters.color_palette import ColorPaletteExtractor, ColorPalette +import numpy as np from torch.utils.data import DataLoader from refiners.fluxion.adapters.histogram import ( HistogramExtractor ) - +from PIL import Image from refiners.foundationals.latent_diffusion import ( DPMSolver, StableDiffusion_1, @@ -36,6 +37,8 @@ from torch.utils.data import Dataset + + class ColorTrainerEvaluationConfig(TestDiffusionBaseConfig): db_indexes: list[int] batch_size: int = 1 @@ -119,7 +122,27 @@ def eval_dataloader(self) -> DataLoader[PromptType]: @cached_property def unconditionnal_text_embedding(self) -> Tensor: return self.text_encoder([""]) - + + @cached_property + def color_palette_extractor(self) -> ColorPaletteExtractor: + return ColorPaletteExtractor( + size=self.config.color_palette.max_colors, + weighted_palette=self.config.color_palette.weighted_palette + ) + + def draw_palette(self, palette: ColorPalette, width: int, height: int) -> Image.Image: + palette_img = Image.new(mode="RGB", size=(width, height)) + + # sort the palette by weight + current_x = 0 + for (color, weight) in palette: + box_width = int(weight*width) + color_box = Image.fromarray(np.full((height, box_width, 3), color, dtype=np.uint8)) # type: ignore + palette_img.paste(color_box, box=(current_x, 0)) + current_x+=box_width + + return palette_img + @scoped_seed(5) def compute_batch_evaluation(self, batch: PromptType, same_seed: bool = True) -> ResultType: batch_size = len(batch.source_prompts) diff --git a/src/refiners/training_utils/trainers/color_palette.py b/src/refiners/training_utils/trainers/color_palette.py index e9323ffc8..a26115b89 100644 --- a/src/refiners/training_utils/trainers/color_palette.py +++ b/src/refiners/training_utils/trainers/color_palette.py @@ -86,12 +86,7 @@ def color_palette_encoder(self) -> ColorPaletteEncoder: device=self.device, ) return encoder - @cached_property - def color_palette_extractor(self) -> ColorPaletteExtractor: - return ColorPaletteExtractor( - size=self.config.color_palette.max_colors, - weighted_palette=self.config.color_palette.weighted_palette - ) + @cached_property def color_palette_adapter(self) -> SD1ColorPaletteAdapter[Any]: @@ -220,19 +215,6 @@ def eval_set_adapter_values(self, batch: BatchColorPalettePrompt) -> None: ) ) - def draw_palette(self, palette: ColorPalette, width: int, palette_img_size: int) -> Image.Image: - palette_img = Image.new(mode="RGB", size=(width, palette_img_size)) - - # sort the palette by weight - current_x = 0 - for (color, weight) in palette: - box_width = int(weight*width) - color_box = Image.fromarray(np.full((palette_img_size, box_width, 3), color, dtype=np.uint8)) # type: ignore - palette_img.paste(color_box, box=(current_x, 0)) - current_x+=box_width - - return palette_img - def draw_cover_image(self, batch: BatchColorPaletteResults) -> Image.Image: (batch_size, _, height, width) = batch.result_images.shape diff --git a/src/refiners/training_utils/trainers/histogram.py b/src/refiners/training_utils/trainers/histogram.py index 9914d6256..198cb6d14 100644 --- a/src/refiners/training_utils/trainers/histogram.py +++ b/src/refiners/training_utils/trainers/histogram.py @@ -123,9 +123,9 @@ def load_dataset(self) -> ColorPaletteDataset: def batch_metrics(self, results: BatchHistogramResults, prefix: str = "histogram-img") -> None: - self.log({f"{prefix}/mse": self.histogram_distance( - results.source_histograms.to(device=self.device), - results.result_histograms.to(device=self.device) + self.log({f"{prefix}/loss": self.histogram_distance( + results.result_histograms.to(device=self.device), + results.source_histograms.to(device=self.device) )}) @@ -358,15 +358,30 @@ def draw_cover_image(self, batch: BatchHistogramResults) -> Image.Image: for i in range(batch_size): join_canvas_image.paste(source_images[i].resize((width//2, height//2)), box=(0, height *i)) + source_image_palette = self.draw_palette( + self.color_palette_extractor.from_histogram(source_histograms[i], color_bits= self.config.histogram_auto_encoder.color_bits, size=len(batch.source_palettes[i])), + width//2, + height//16 + ) + join_canvas_image.paste(source_image_palette, box=(0, height *i + height//2)) + + res_image_palette = self.draw_palette( + self.color_palette_extractor.from_histogram(results_histograms[i], color_bits= self.config.histogram_auto_encoder.color_bits, size=len(batch.source_palettes[i])), + width//2, + height//16 + ) + + join_canvas_image.paste(res_image_palette, box=(0, height *i + (15*height)//16)) + for (color_id, color_name) in enumerate(colors): image_curve = self.draw_curves( res_histo_channels[color_id][i].cpu().tolist(), # type: ignore src_histo_channels[color_id][i].cpu().tolist(), # type: ignore color_name, width//2, - height//6 + height//8 ) - join_canvas_image.paste(image_curve, box=(0, height *i + height//2 + color_id*height//6)) + join_canvas_image.paste(image_curve, box=(0, height *i + height//2 + ((1+2*color_id)*height)//16)) return join_canvas_image diff --git a/src/refiners/training_utils/trainers/histogram_auto_encoder.py b/src/refiners/training_utils/trainers/histogram_auto_encoder.py index 2007e4a34..896718361 100644 --- a/src/refiners/training_utils/trainers/histogram_auto_encoder.py +++ b/src/refiners/training_utils/trainers/histogram_auto_encoder.py @@ -1,25 +1,29 @@ from functools import cached_property +from tkinter import W from typing import Any from loguru import logger +from refiners.training_utils.wandb import WandbLoggable -from torch import Tensor +from torch import Tensor, isnan from refiners.fluxion.adapters.histogram_auto_encoder import HistogramAutoEncoder -from torch.nn.functional import mse_loss import refiners.fluxion.layers as fl from refiners.fluxion.adapters.histogram import ( HistogramDistance, - HistogramExtractor + HistogramExtractor, + histogram_to_histo_channels ) from torch.utils.data import DataLoader from refiners.fluxion.utils import save_to_safetensors from refiners.training_utils.callback import Callback, GradientNormLayerLogging from refiners.training_utils.config import BaseConfig -from refiners.training_utils.datasets.color_palette import ColorPaletteDataset, TextEmbeddingColorPaletteLatentsBatch +from refiners.training_utils.datasets.color_palette import ColorDatasetConfig, ColorPaletteDataset, TextEmbeddingColorPaletteLatentsBatch from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig from refiners.training_utils.trainers.trainer import Trainer from pydantic import BaseModel +from refiners.fluxion.adapters.color_palette import ColorPaletteExtractor, ColorPalette +from PIL import Image, ImageDraw class HistogramAutoEncoderConfig(BaseModel): latent_dim: int @@ -28,9 +32,9 @@ class HistogramAutoEncoderConfig(BaseModel): color_bits: int class TrainHistogramAutoEncoderConfig(BaseConfig): - dataset: HuggingfaceDatasetConfig + dataset: ColorDatasetConfig histogram_auto_encoder: HistogramAutoEncoderConfig - eval_dataset: HuggingfaceDatasetConfig + eval_dataset: ColorDatasetConfig class HistogramAutoEncoderTrainer( Trainer[TrainHistogramAutoEncoderConfig, TextEmbeddingColorPaletteLatentsBatch] @@ -53,6 +57,26 @@ def load_eval_dataset(self) -> ColorPaletteDataset: config=self.config.eval_dataset ) + def draw_palette(self, palette: ColorPalette, width: int, height: int) -> Image.Image: + palette_img = Image.new(mode="RGB", size=(width, height)) + + # sort the palette by weight + current_x = 0 + for (color, weight) in palette: + box_width = int(weight*width) + color_box = Image.fromarray(np.full((height, box_width, 3), color, dtype=np.uint8)) # type: ignore + palette_img.paste(color_box, box=(current_x, 0)) + current_x+=box_width + + return palette_img + + @cached_property + def color_palette_extractor(self) -> ColorPaletteExtractor: + return ColorPaletteExtractor( + size=8, + weighted_palette=True + ) + @cached_property def histogram_auto_encoder(self) -> HistogramAutoEncoder: assert self.config.models["histogram_auto_encoder"] is not None, "The config must contain a histogram entry." @@ -60,7 +84,8 @@ def histogram_auto_encoder(self) -> HistogramAutoEncoder: latent_dim=self.config.histogram_auto_encoder.latent_dim, resnet_sizes=self.config.histogram_auto_encoder.resnet_sizes, n_down_samples=self.config.histogram_auto_encoder.n_down_samples, - device=self.device + device=self.device, + color_bits=self.config.histogram_auto_encoder.color_bits ) logger.info(f"Building autoencoder with compression rate {autoencoder.compression_rate}") return autoencoder @@ -80,11 +105,14 @@ def load_models(self) -> dict[str, fl.Module]: def compute_loss(self, batch: TextEmbeddingColorPaletteLatentsBatch) -> Tensor: - actual = self.histogram_extractor.images_to_histograms([item.image for item in batch], device = self.device, dtype = self.dtype) + expected = self.histogram_extractor.images_to_histograms([item.image for item in batch], device = self.device, dtype = self.dtype) - expected = self.histogram_auto_encoder(actual) + actual = self.histogram_auto_encoder(expected) + + if isnan(actual).any(): + raise ValueError("The autoencoder produced NaNs.") - loss = mse_loss(actual, expected) + loss = self.histogram_distance(actual, expected) return loss @@ -97,7 +125,6 @@ def histogram_distance(self) -> HistogramDistance: return HistogramDistance(color_bits=self.config.histogram_auto_encoder.color_bits) @cached_property - def eval_dataloader(self) -> DataLoader[TextEmbeddingColorPaletteLatentsBatch]: collate_fn = getattr(self.eval_dataset, "collate_fn", None) @@ -109,11 +136,72 @@ def eval_dataloader(self) -> DataLoader[TextEmbeddingColorPaletteLatentsBatch]: num_workers=self.config.training.num_workers ) + def draw_curves(self, res_histo: list[float], src_histo: list[float], color: str, width: int, height: int) -> Image.Image: + histo_img = Image.new(mode="RGB", size=(width, height)) + + draw = ImageDraw.Draw(histo_img) + + if len(res_histo) != len(src_histo): + raise ValueError("The histograms must have the same length.") + + ratio = width/len(res_histo) + semi_height = height//2 + + scale_ratio = 5 + + draw.line([ + (i*ratio, (1-res_histo[i]*scale_ratio)*semi_height + semi_height) for i in range(len(res_histo)) + ], fill=color, width=4) + + draw.line([ + (i*ratio, (1-src_histo[i]*scale_ratio)*semi_height) for i in range(len(src_histo)) + ], fill=color, width=1) + + return histo_img + def compute_evaluation_metrics(self, batch: TextEmbeddingColorPaletteLatentsBatch) -> Tensor: + + expected = self.histogram_extractor.images_to_histograms([item.image for item in batch], device = self.device, dtype = self.dtype) + + actual = self.histogram_auto_encoder(expected) + + metrics = self.histogram_distance.metrics(actual, expected) + log_dict : dict[str, WandbLoggable] = {} + for (key, value) in metrics.items(): + log_dict[f"eval/{key}"] = value + self.log(log_dict) + + + images : dict[str, WandbLoggable] = {} - eval_loss = self.compute_loss(batch) - self.log({f"eval/loss": eval_loss.item()}) - return eval_loss + res_histo_channels = histogram_to_histo_channels(actual) + src_histo_channels = histogram_to_histo_channels(expected) + + batch_size = expected.shape[0] + + histo_h = 30 + width = 256 + + colors = ["red", "green", "blue"] + + joint_canvas = Image.new(mode="RGB", size=(256, 3*histo_h * batch_size)) + draw = ImageDraw.Draw(joint_canvas) + for i in range(batch_size): + draw.line([(0, histo_h * 3 * i), (width, histo_h * 3 * i)], fill="white", width=5) + for (color_id, color_name) in enumerate(colors): + image_curve = self.draw_curves( + res_histo_channels[color_id][i].cpu().tolist(), # type: ignore + src_histo_channels[color_id][i].cpu().tolist(), # type: ignore + color_name, + width, + histo_h + ) + + joint_canvas.paste(image_curve, box=(0, histo_h * (3 * i + color_id))) + + images[f"color_curves/eval"] = joint_canvas + self.log(data=images) + # images = [item.image for item in batch] # [red_gt, green_gt, blue_gt] = images_to_histo_channels(images) diff --git a/src/refiners/training_utils/trainers/trainer.py b/src/refiners/training_utils/trainers/trainer.py index 29a84d0ef..485e36d3e 100644 --- a/src/refiners/training_utils/trainers/trainer.py +++ b/src/refiners/training_utils/trainers/trainer.py @@ -651,3 +651,5 @@ def evaluate(self) -> None: def _call_callbacks(self, event_name: str) -> None: for callback in self.callbacks: getattr(callback, event_name)(self) + + diff --git a/tests/adapters/test_histogram.py b/tests/adapters/test_histogram.py index 360d05b6a..a259748ae 100644 --- a/tests/adapters/test_histogram.py +++ b/tests/adapters/test_histogram.py @@ -2,7 +2,8 @@ from refiners.fluxion.adapters.histogram import HistogramDistance, HistogramEncoder, HistogramExtractor, ColorLoss, histogram_to_histo_channels, sorted_channels_to_histo_channels, tensor_to_sorted_channels from refiners.fluxion.utils import image_to_tensor, tensor_to_image - +from PIL import Image +import numpy as np def test_histogram_extractor() -> None: color_bits = 3 @@ -25,7 +26,12 @@ def test_histogram_extractor() -> None: histogram_white = extractor(img_white) assert abs(histogram_white[0, -1, -1, -1] - 1.0) < 1e-4, "histogram_white should be 1.0 at -1,-1,-1,-1" assert abs(histogram_white.sum() - 1.0) < 1e-4, "histogram sum should equal 1.0" - + + imarray = np.random.rand(256,256,3) * 255 + image = Image.fromarray(imarray.astype('uint8')).convert('RGB') + histogram_img = extractor.images_to_histograms([image, image]) + assert abs(histogram_img.sum() - 2.0) < 1e-5, "histogram sum should equal 1.0" + def test_images_histogram_extractor() -> None: color_bits = 3 @@ -47,7 +53,7 @@ def test_histogram_distance() -> None: distance = HistogramDistance() color_bits = 2 color_size = 2**color_bits - batch_size = 2 + batch_size = 10 histo1 = torch.rand((batch_size, color_size, color_size, color_size)) sum1 = histo1.sum() @@ -58,7 +64,22 @@ def test_histogram_distance() -> None: histo2 = histo2 / sum2 dist_same = distance(histo1, histo1) - assert dist_same == 0.0, "distance between himself should be 0.0" + assert abs(dist_same) < 1e-4, "distance between himself should be 0.0" + + dist_diff = distance(histo1, histo2) + assert dist_diff >= 0.0, "distance should more than 0.0" + + dist_bhattacharyya = distance.bhattacharyya(histo1, histo2) + assert dist_bhattacharyya >= 0.0, "distance bhattacharyya should be more than 0" + + dist_kl_div = distance.kl_div(histo1, histo2) + assert dist_kl_div >= 0.0, "distance kl div himself should more than 0.0" + + dist_intersection_same = distance.intersection(histo1, histo1) + assert abs(dist_intersection_same - 1.0) < 1e-6, "distance intersection should be 1" + + dist_correlation_same = distance.correlation(histo1, histo1) + assert dist_correlation_same > 0.0, "distance correlation should be more than 0" def test_histogram_encoder() -> None: