haydn-jones/DiffJPEG-JAX

DiffJPEG: A JAX Implementation

This is a JAX implementation of the differentiable JPEG compression algorithm. It is based on the PyTorch implementation and adopts some of the modifications found in this repository, which improve quality at high compression rates.

Requirements

  • JAX

Installation

Install with pip:

pip install diffjpeg_jax

Usage

Unlike the PyTorch version, this implementation is not tied to any particular ML library; it is simply a function. Input images should have values in the range [0, 255] and shape (H, W, C).

from diffjpeg_jax import diff_jpeg

img = ... # (H, W, C)
jpeg = diff_jpeg(img, quality=75)
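
Because the compression is differentiable, it can sit inside a loss and be differentiated with jax.grad. The following is a minimal sketch, not part of the package: `reconstruction_loss` and `placeholder_codec` are hypothetical names, and the placeholder only mimics the call shape `codec(img, quality=...)` so the snippet runs without the package installed. It assumes a floating-point input image.

```python
import jax
import jax.numpy as jnp


def reconstruction_loss(img, codec, quality=75):
    """Mean squared error between an image and its compressed version."""
    compressed = codec(img, quality=quality)
    return jnp.mean((compressed - img) ** 2)


def placeholder_codec(img, quality=75):
    """Hypothetical stand-in with the same call shape as diff_jpeg.

    This is NOT JPEG compression; it only scales the image so the
    example is runnable on its own.
    """
    return img * (quality / 100.0)


# Differentiate with respect to the image only (argument 0).
loss_grad = jax.grad(reconstruction_loss, argnums=0)

img = jnp.full((16, 16, 3), 128.0)  # float image in [0, 255], shape (H, W, C)
grad = loss_grad(img, placeholder_codec)  # same shape as img
```

With the package installed, passing `diff_jpeg` in place of `placeholder_codec` should follow the same pattern.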

Note: The function is not wrapped in jax.jit, so apply it yourself if you want compilation. For batch processing, use jax.vmap.
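
One way to combine the two is sketched below. `make_batched_codec` and `placeholder_codec` are hypothetical names used only for illustration; the placeholder stands in for `diff_jpeg` so the snippet runs on its own. Binding `quality` with functools.partial keeps it an ordinary Python value rather than a traced argument, which is the safe choice when it is not known whether the function branches on it.

```python
from functools import partial

import jax
import jax.numpy as jnp


def make_batched_codec(codec, quality=75):
    """Return a jitted function mapping (N, H, W, C) to (N, H, W, C).

    `codec` is any function called as codec(img, quality=...) on a single
    (H, W, C) image, for example diffjpeg_jax.diff_jpeg.
    """
    # Fix quality up front so jit and vmap only see the image argument.
    single = partial(codec, quality=quality)
    # vmap maps over a new leading batch axis; jit compiles the result.
    return jax.jit(jax.vmap(single))


def placeholder_codec(img, quality=75):
    """Hypothetical stand-in with the same call shape as diff_jpeg (not JPEG)."""
    return img * (quality / 100.0)


key = jax.random.PRNGKey(0)
batch = jax.random.uniform(key, (4, 64, 64, 3), minval=0.0, maxval=255.0)

batched = make_batched_codec(placeholder_codec, quality=75)
out = batched(batch)  # shape (4, 64, 64, 3)
```

With the package installed, `make_batched_codec(diff_jpeg, quality=75)` would be the corresponding call.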
