This is a JAX implementation of the differentiable JPEG compression algorithm, based on the PyTorch implementation and incorporating some of the modifications found in this repository to improve quality at high compression rates.
Requirements:
- JAX

The package can be installed with pip:
pip install diffjpeg_jax
Unlike the PyTorch version, this implementation is ML-library agnostic, so it is implemented simply as a function. Inputs should be in the range [0, 255] and in the format (H, W, C).
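If your images come from a pipeline that uses channels-first floats in [0, 1] (common in PyTorch-style data loaders), they need rescaling and transposing before being passed in. The helper below, `to_diffjpeg_input`, is a hypothetical name used only for illustration; it is a minimal sketch that assumes a (C, H, W) float input and produces the (H, W, C), [0, 255] layout described above.

```python
import jax.numpy as jnp


def to_diffjpeg_input(img_chw):
    """Hypothetical helper: (C, H, W) float in [0, 1] -> (H, W, C) float in [0, 255].

    Assumes a channels-first floating-point input; values outside [0, 1]
    are clipped before rescaling.
    """
    img_hwc = jnp.transpose(img_chw, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    return jnp.clip(img_hwc, 0.0, 1.0) * 255.0   # rescale to [0, 255]
```

Keeping the result as a float array (rather than casting to uint8) matters if you want gradients to flow through the conversion.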
from diffjpeg_jax import diff_jpeg
img = ... # (H, W, C)
jpeg = diff_jpeg(img, quality=75)
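Since the point of the library is differentiability, one way to use it is to take gradients through the compression step with `jax.grad`. The sketch below is an assumption-laden illustration, not part of the package: `jpeg_reconstruction_grad` is a hypothetical helper, and it assumes the JPEG function (e.g. `diff_jpeg`) returns an array with the same shape as its input. The JPEG function is passed in as an argument so the pattern can be checked with any stand-in.

```python
import jax
import jax.numpy as jnp


def jpeg_reconstruction_grad(jpeg_fn, img, quality=75):
    """Hypothetical helper: gradient of MSE(jpeg_fn(img), img) w.r.t. img.

    Assumes jpeg_fn behaves like diff_jpeg(img, quality=...), taking an
    (H, W, C) array in [0, 255] and returning an array of the same shape.
    img must be a floating-point array, since jax.grad rejects integer inputs.
    """
    def loss(x):
        recon = jpeg_fn(x, quality=quality)  # quality stays a Python constant
        return jnp.mean((recon - x) ** 2)    # mean squared reconstruction error

    return jax.grad(loss)(img)
```

With the package installed, usage would look like `jpeg_reconstruction_grad(diff_jpeg, img.astype(jnp.float32), quality=75)`.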
Note: diff_jpeg is not wrapped in jax.jit, so apply jax.jit yourself if you want JIT compilation. For batch processing, use jax.vmap.
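A minimal sketch of combining the two, assuming images are stacked along a leading batch axis as (N, H, W, C). `make_batched_jpeg` is a hypothetical helper; `quality` is bound with `functools.partial` so it stays a plain Python value instead of a traced argument, which avoids problems if the implementation uses `quality` in Python control flow (whether it does is not stated here).

```python
import functools

import jax


def make_batched_jpeg(jpeg_fn, quality=75):
    """Hypothetical helper: jit-compiled, vmapped version of jpeg_fn.

    The returned function maps an (N, H, W, C) batch through
    jpeg_fn(img, quality=quality) one image at a time.
    """
    # Bind quality as a Python constant so neither vmap nor jit traces it.
    single = functools.partial(jpeg_fn, quality=quality)
    # vmap maps over axis 0 of the batch; jit compiles once per input shape/dtype.
    return jax.jit(jax.vmap(single))
```

For example, `batched = make_batched_jpeg(diff_jpeg, quality=75)` followed by `out = batched(imgs)`. Because `quality` is baked in, each distinct quality value needs its own call to `make_batched_jpeg`.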