Perceptual Compression (PerCo)

This repository provides a PyTorch implementation of PerCo based on:

Careil et al., "Towards Image Compression with Perfect Realism at Ultra-Low Bitrates", ICLR 2024.

Different from the original work, we use Stable Diffusion v2.1 (Rombach et al., CVPR 2022) as the latent diffusion model and hence refer to our work as PerCo (SD). This differentiates it from the official work, which builds on a proprietary, not publicly available, pre-trained variant of GLIDE (Nichol et al., ICML 2022).

Under active development.

Updates

06/16/2024

  1. Finetuned whole U-Net (not just linear layers)
  2. Slightly improved results (limited to 50k optimization steps)
  3. Released pre-trained models
  4. Ablation studies: experimented with LoRA and FSQ (no improvements achieved)

05/29/2024

  1. Switched back to official hyper-encoder design, resolved training instabilities
  2. Significantly improved results (limited to 50k optimization steps)

05/24/2024

  1. Initial release of this project

Visual Impressions

Visual comparison on the Kodak dataset at our lowest bit-rate (0.0019 bpp). Column 1: ground truth. Columns 2-5: a set of reconstructions reflecting the uncertainty about the original image source.

kodim13. Global conditioning: "a river runs through a rocky forest with mountains in the background".
kodim22. Global conditioning: "a red barn with a pond in the background".
kodim23. Global conditioning: "two parrots standing next to each other with leaves in the background".

More visual results can be found here.

Quantitative Performance

In this section we quantitatively compare the performance of PerCo (SD v2.1) to the officially reported numbers. All our models were trained with a reduced budget of 50k optimization steps. Note that performance is upper-bounded by the LDM auto-encoder, denoted as SD v2.1 auto-encoder.

We generally obtain highly competitive results in terms of perception (FID, KID), especially at the ultra-low bit-rates, but at the cost of lower image fidelity (MS-SSIM, LPIPS). Note that PerCo (official) was trained for 5 epochs (9M training samples / batch size 160 × 5 epochs = 281,250 optimization steps), whereas our 50k steps correspond to roughly 18% of that schedule. Also note that we have not yet considered LPIPS as an auxiliary loss, which is known to improve performance at higher bit-rates.

We will continue our experiments and hope to release more powerful variants at a later stage.

*PerCo (official) vs. PerCo (SD v2.1)*
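For orientation, the metrics above can be reproduced with off-the-shelf tooling. The sketch below uses torchmetrics rather than this repository's evaluation code (which builds on the stable LPIPS implementation from NeuralCompression); the random tensors are mere placeholders for ground-truth images and reconstructions:

```python
import torch
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# placeholders for the 24 Kodak images and their reconstructions (uint8, NCHW)
real = torch.randint(0, 256, (24, 3, 512, 768), dtype=torch.uint8)
fake = torch.randint(0, 256, (24, 3, 512, 768), dtype=torch.uint8)

# perception: FID/KID consume uint8 images by default
fid = FrechetInceptionDistance(feature=2048)
kid = KernelInceptionDistance(subset_size=24)
for metric in (fid, kid):
    metric.update(real, real=True)
    metric.update(fake, real=False)

# fidelity: MS-SSIM/LPIPS operate on floats in [0, 1]
real_f, fake_f = real.float() / 255.0, fake.float() / 255.0
ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True)

print(f"FID: {fid.compute():.2f}, KID: {kid.compute()[0]:.4f}")
print(f"MS-SSIM: {ms_ssim(fake_f, real_f):.4f}, LPIPS: {lpips(fake_f, real_f):.4f}")
```

Note that FID/KID take raw uint8 images (torchmetrics resizes them internally to the Inception input), while MS-SSIM and LPIPS expect floats in [0, 1].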

Install

```bash
$ git clone https://github.com/Nikolai10/PerCo.git
```

Please follow our Installation Guide with Docker.

Training/ Inference/ Evaluation

Please have a look at the example notebook for more information.

We use the OpenImagesV6 training dataset by default, similar to MS-ILLM. Please familiarize yourself with the data loading mechanisms (see _openimages_v6.py) and adjust the file paths and training settings in config.py accordingly. Corrupted images must be excluded; see _INVALID_IMAGE_NAMES for details, and the sketch below.
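For illustration, excluding corrupted files boils down to filtering the file list before it reaches the dataset. The class below is a hypothetical, minimal stand-in for the repository's loader; the exclusion set would be taken from _openimages_v6.py:

```python
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

_INVALID_IMAGE_NAMES = frozenset()  # fill with the names listed in _openimages_v6.py


class OpenImagesV6Folder(Dataset):
    """Minimal image-folder dataset that skips known-corrupted files."""

    def __init__(self, root, transform=None):
        self.paths = sorted(p for p in Path(root).glob("*.jpg")
                            if p.name not in _INVALID_IMAGE_NAMES)
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        return self.transform(img) if self.transform else img
```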

We also provide a simplified Google Colab demo that integrates any tfds dataset (e.g. CLIC 2020), with no data engineering tasks involved: open tutorial.
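Conceptually, the wrapper needed for this is small. A minimal sketch, assuming a recent tensorflow-datasets release with the NumPy-only data_source API (the actual implementation lives in tfds_interface.py and may differ):

```python
import tensorflow_datasets as tfds
from torch.utils.data import Dataset


class TfdsImageDataset(Dataset):
    """Expose a tfds dataset to PyTorch without a TensorFlow runtime."""

    def __init__(self, name="clic", split="train", transform=None):
        self.source = tfds.data_source(name, split=split)  # NumPy-only access
        self.transform = transform

    def __len__(self):
        return len(self.source)

    def __getitem__(self, idx):
        img = self.source[idx]["image"]  # HWC uint8 NumPy array
        return self.transform(img) if self.transform else img
```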

TODOs

  • Compression functionality
    • adopt script logic presented in MS2020
    • provide decompression functionality as custom HuggingFace pipeline
    • add zlib compression functionality (captions); see the zlib sketch after this list
    • add entropy coding functionality (hyper-encoder)
    • use DDIM scheduler for inference (20/5 denoising steps); see the scheduler sketch after this list
  • Provide evaluation code/ compare quantitatively to PerCo (official)
  • Training pipeline
    • use train_text_to_image.py as starting point
    • integrate tfds to make use of Open Images v4 (1.7M images)
    • integrate full OpenImagesV6 (9M images) based on NeuralCompression
    • obtain captions dynamically at runtime; see the BLIP 2 sketch after this list
    • adjust conditioning logic (z_l, z_g)
    • optimizer AdamW
      • 5 epochs, on 512x512 crops (for now: limited to 50k iterations)
      • peak learning rate 1e-4 -> we use 1e-5
      • weight decay 0.01
      • bs = 160 (w/o LPIPS), bs = 40 (w/ LPIPS)
      • linear warm-up 10k
      • train hyper-encoder + fine-tune all layers of the U-Net (initially: linear layers only)
      • exchange traditional noise prediction objective with v-prediction
      • add LPIPS loss for target rates > 0.05bpp
    • add classifier-free guidance (drop text-conditioning in 10% of iterations)
    • override validation logic (add validation images)
  • BLIP 2
  • Hyper-encoder
    • request hyper-encoder design from authors
    • integrate improved VQ-VAE functionality (Yu et al. ICLR 2022)
    • wrap into (ModelMixin, ConfigMixin) to make use of convenient loading/ saving
  • U-Net
    • extend the kernel of the first conv layer
    • initialize newly created variables randomly; see the conv_in sketch after this list
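Regarding the zlib item above: captions are plain text, so the global conditioning can be transmitted losslessly with a few lines of standard-library code:

```python
import zlib

caption = "a red barn with a pond in the background"
blob = zlib.compress(caption.encode("utf-8"), level=9)
print(f"caption overhead: {len(blob) * 8} bits")

# receiver side
assert zlib.decompress(blob).decode("utf-8") == caption
```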
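Regarding the DDIM item: swapping schedulers follows the usual diffusers pattern, shown here with the stock Stable Diffusion pipeline for illustration (PerCo decodes with 20 denoising steps, or 5 at the lowest bit-rates):

```python
from diffusers import DDIMScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
# reuse the pipeline's scheduler config, but step with DDIM
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
image = pipe("a red barn with a pond in the background",
             num_inference_steps=20).images[0]
```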
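Regarding runtime captioning: a minimal BLIP 2 example with HuggingFace Transformers (the checkpoint choice here is illustrative, not necessarily the one used in this project):

```python
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

image = Image.open("kodim13.png").convert("RGB")
inputs = processor(images=image, return_tensors="pt")
ids = model.generate(**inputs, max_new_tokens=32)
caption = processor.decode(ids[0], skip_special_tokens=True)
print(caption)
```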
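Regarding the U-Net items: a common recipe, and only a sketch of what unet_2d_perco.py may do, is to widen conv_in so the concatenated hyper-encoder features fit, keep the pretrained weights, and randomly initialize only the new input channels:

```python
import torch
from torch import nn
from diffusers import UNet2DConditionModel


def expand_conv_in(unet: UNet2DConditionModel, extra_channels: int) -> None:
    """Widen the first conv layer to accept extra (local-feature) channels."""
    old = unet.conv_in
    new = nn.Conv2d(
        old.in_channels + extra_channels, old.out_channels,
        kernel_size=old.kernel_size, stride=old.stride, padding=old.padding,
    )
    with torch.no_grad():
        new.weight[:, :old.in_channels] = old.weight              # keep pretrained kernel
        nn.init.normal_(new.weight[:, old.in_channels:], std=0.02)  # new inputs: random init
        new.bias.copy_(old.bias)
    unet.conv_in = new
    unet.register_to_config(in_channels=old.in_channels + extra_channels)
```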

Note:

  • we have not adjusted the finetuning grid to 50 timesteps as described in the paper.
  • we use Stable Diffusion v2.1 as LDM due to its native shift from epsilon- to v-prediction. In general, however, this project also supports SD 1.x variants with minor adjustments:

    ```python
    from helpers import update_scheduler
    from pipeline_sd_perco import StableDiffusionPipelinePerco  # custom pipeline from src/

    pipe = StableDiffusionPipelinePerco.from_pretrained(...)
    # add this line if you are using v-prediction
    update_scheduler(pipe)
    ```

Pre-trained Models

Pre-trained models corresponding to 0.1250bpp, 0.0313bpp and 0.0019bpp can be downloaded here.

All models were trained on a DGX H100 (8 GPUs; effective batch size 8 × 20 = 160) using the following command:

```bash
# note that prediction_type must equal config.py prediction_type
!accelerate launch --multi_gpu --num_processes=8 /tf/notebooks/PerCo/src/train_sd_perco.py \
--pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1" \
--validation_image "/tf/notebooks/PerCo/res/eval/kodim13.png" "/tf/notebooks/PerCo/res/eval/kodim23.png" \
--allow_tf32 \
--dataloader_num_workers=12 \
--resolution=512 --center_crop --random_flip \
--train_batch_size=20 \
--gradient_accumulation_steps=1 \
--num_train_epochs=5 \
--max_train_steps 50000 \
--validation_steps 500 \
--prediction_type="v_prediction" \
--checkpointing_steps 500 \
--learning_rate=1e-05 \
--adam_weight_decay=1e-2 \
--max_grad_norm=1 \
--lr_scheduler="constant" \
--lr_warmup_steps=10000 \
--checkpoints_total_limit=2 \
--output_dir="/tf/notebooks/PerCo/res/cmvl_2024"
```

If you find better hyper-parameters, please share them with the community.

Directions for Improvement

  • Investigate scalar quantizer + hyper-decoder (similar to Agustsson et al. ICCV 2019)
  • The authors only considered controlling the bit-rate via an upper bound (i.e., a uniform coding scheme); incorporating a powerful entropy model will likely exceed the reported performance (see the sketch below).
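To make the last point concrete: under uniform coding every quantized index costs log2(V) bits regardless of its probability, while an entropy model pays -log2 p per symbol. A toy comparison of the two rate computations (hypothetical numbers):

```python
import math


def uniform_rate_bits(num_symbols: int, codebook_size: int) -> float:
    # upper bound: every index costs log2(V) bits
    return num_symbols * math.log2(codebook_size)


def modeled_rate_bits(probs: list[float]) -> float:
    # entropy coding: each symbol costs -log2 p(symbol)
    return -sum(math.log2(p) for p in probs)


# e.g. a 64x96 grid of indices with a 256-entry codebook
print(uniform_rate_bits(64 * 96, 256))       # 49152.0 bits
print(modeled_rate_bits([0.5, 0.25, 0.25]))  # 5.0 bits for three symbols
```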

File Structure

```
docker                                              # Docker functionality + dependencies
    ├── install.txt                                 
notebooks                                           # jupyter-notebooks
    ├── FilterMSCOCO.ipynb                          # How to obtain MS-COCO 30k        
    ├── PerceptualCompression.ipynb                 # How to train and eval PerCo                 
res
    ├── cmvl_2024/                                  # saved model, checkpoints, log files
    ├── data/                                       # evaluation data (must be downloaded separately)
       ├── Kodak/                                   # Kodak dataset (https://r0k.us/graphics/kodak/, 24 images)
       ├── Kodak_gen/                               # Kodak reconstructions
       ├── MSCOCO30k/                               # MS-COCO 30k dataset (see ./notebooks/FilterMSCOCO.ipynb)
       ├── MSCOCO30k_gen/                           # MS-COCO 30k reconstructions
    ├── doc/                                        # additional resources
    ├── eval/                                       # sample images + reconstructions
src
    ├── diffusers/                                  # local copy of https://github.com/huggingface/diffusers (v0.27.0)
    ├── compression_utils.py                        # CLI tools for PerCo compression/ decompression
    ├── config.py                                   # PerCo global configuration (training + inference)
    ├── helpers.py                                  # helper functionality
    ├── hyper_encoder_v2.py                         # hyper-encoder + quantization (based on HiFiC)
    ├── hyper_encoder.py                            # hyper-encoder + quantization (based on ELIC)
    ├── lpips_stable.py                             # stable LPIPS implementation based on MS-ILLM/ NeuralCompression
    ├── openimages_v6.py                            # minimalistic dataloader for OpenImagesV6
    ├── pipeline_sd_perco.py                        # custom HuggingFace Pipeline which bundles image generation (=decompression)
    ├── tfds_interface.py                           # simple PyTorch wrapper to make use of tfds
    ├── train_sd_perco.py                           # PerCo training functionality
    ├── unet_2d_perco.py                            # extended UNet2DConditionModel which accepts local features from the hyper-encoder
```

Acknowledgment

This project is based on/ takes inspiration from:

  • Diffusers, a library for state-of-the-art pretrained diffusion models for generative AI provided by HuggingFace (Stable Diffusion).
  • Transformers, a library for state-of-the-art machine learning models provided by HuggingFace (BLIP 2).
  • Vector Quantization - PyTorch, a vector quantization library (improved VQ-VAE).
  • TensorFlow Datasets (TFDS), a collection of ready-to-use datasets (we make use of the TensorFlow-less NumPy-only data loading to access open_images_v4).
  • CompressAI, a PyTorch library and evaluation platform for end-to-end compression research.
  • TensorFlow Compression (TFC), a TF library dedicated to data compression (we adopt the convenient script logic for compression/ decompression).
  • NeuralCompression, a Python repository dedicated to research of neural networks that compress data (we make use of the stable LPIPS implementation + HiFiC-based building blocks).
  • torchac: Fast Arithmetic Coding for PyTorch (we use torchac to compress the hyper-latent).

We thank the authors for providing us with the official evaluation points as well as helpful insights.

License

Apache License 2.0