Skip to content

Commit

Permalink
Merge pull request #3 from Justin-Tan/ans_compression
Browse files Browse the repository at this point in the history
Adds rANS compression support.
  • Loading branch information
Justin-Tan authored Sep 7, 2020
2 parents 56cb383 + 39d29a1 commit 5f9ce2a
Show file tree
Hide file tree
Showing 17 changed files with 2,633 additions and 382 deletions.
133 changes: 61 additions & 72 deletions README.md

Large diffs are not rendered by default.

62 changes: 62 additions & 0 deletions assets/EXAMPLES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
Original | Reconstruction
:-------------------------:|:-------------------------:
![guess](assets/hific/CLIC2020_5_RECON_0.160bpp.png) | ![guess](assets/originals/CLIC2020_5.png)

<details>

<summary>Image 1</summary>

```python
Original: B (11.6 bpp) | HIFIC: A (0.160 bpp). Ratio: 72.5.
```

</details>

A | B
:-------------------------:|:-------------------------:
![guess](assets/originals/CLIC2020_20.png) | ![guess](assets/hific/CLIC2020_20_RECON_0.330bpp.png)

<details>

<summary>Image 2</summary>

```python
Original: A (14.6 bpp) | HIFIC: B (0.330 bpp). Ratio: 44.2
```

</details>

A | B
:-------------------------:|:-------------------------:
![guess](assets/originals/CLIC2020_18.png) | ![guess](assets/hific/CLIC2020_18_RECON_0.209bpp.png)

<details>

<summary>Image 3</summary>

```python
Original: A (12.3 bpp) | HIFIC: B (0.209 bpp). Ratio: 58.9
```

</details>

A | B
:-------------------------:|:-------------------------:
![guess](assets/hific/CLIC2020_19_RECON_0.565bpp.png) | ![guess](assets/originals/CLIC2020_19.png)

<details>

<summary>Image 4</summary>

```python
Original: B (19.9 bpp) | HIFIC: A (0.565 bpp). Ratio: 35.2
```

</details>

| Tables | Are | Cool |
|:------------- |:-------------:| -----:|
| col 3 is | right-aligned | $1600 |
| col 2 is | centered | $12 |
| col 1 is | left-aligned | $42 |
| zebra stripes | are neat | $1 |
125 changes: 125 additions & 0 deletions assets/USAGE_GUIDE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Usage Guide

## Details

This repository defines a model for learnable image compression capable of compressing images of arbitrary size and resolution based on the paper ["High-Fidelity Generative Image Compression" (HIFIC) by Mentzer et. al.](https://hific.github.io/). There are three main components to this model, as described in the original paper:

1. An autoencoding architecture defining a nonlinear transform to latent space. This is used in place of the linear transforms used by traditional image codecs.
2. A hierarchical (two-level in this case) entropy model over the quantized latent representation enabling lossless compression through standard entropy coding.
3. A generator-discriminator component that encourages the decoder/generator component to yield realistic reconstructions.

The model is then trained end-to-end by optimization of a modified rate-distortion Lagrangian. Loosely, the model can be thought of as 'amortizing' the storage requirements for an generic input image through training a learnable compression/decompression scheme. The method is further described in the original paper [[0](https://arxiv.org/abs/2006.09965)]. The model is capable of yielding perceptually similar reconstructions to the input that tend to be more visually pleasing than standard image codecs which operate at comparable or higher bitrates.

This repository also includes a partial port of the [Tensorflow Compression library](https://github.com/tensorflow/compression) for general tools for neural image compression.

## Training

* Download a large (> 100,000) dataset of diverse color images. We found that using 1-2 training divisions of the [OpenImages](https://storage.googleapis.com/openimages/web/index.html) dataset was able to produce satisfactory results on arbitrary images. [Fabian Mentzer's L3C Repo](https://github.com/fab-jul/L3C-PyTorch/) provides utility functions for downloading and preprocessing OpenImages (the trained models did not use this exact split). Add the dataset path under the `DatasetPaths` class in `default_config.py`. Check default config/command line arguments:

```bash
vim default_config.py
python3 train.py -h
```

* For best results, as described in the paper, train an initial base model using the rate-distortion loss only, together with the hyperprior model, e.g. to target low bitrates:

```bash
# Train initial autoencoding model
python3 train.py --model_type compression --regime low --n_steps 1e6
```

* Then use the checkpoint of the trained base model to 'warmstart' the GAN architecture. Training the generator and discriminator from scratch was found to result in unstable training, but YMMV.

```bash
# Train using full generator-discriminator loss
python3 train.py --model_type compression_gan --regime low --n_steps 1e6 --warmstart --ckpt path/to/base/checkpoint
```

* Training after the warmstart for 2e5 steps using a batch size of 16 was sufficient to get reasonable results at sub-0.2 `bpp` per validation image, on average, using the default config in the `low` regime. You can change regimes to `med` or `high` to tradeoff perceptual quality for increased bitrate.

* Perceptual distortion metrics and `bpp` tend to decrease with a pareto-like distribution over training, so model quality can probably be significantly improved by training for an extremely large number of steps.

* If you get out-of-memory errors, try, in decreasing order of priority:
* Decreasing the batch size (default 16).
* Decreasing the number of channels of the latent representation (`latent_channels`, default 220). You may be able to reduce this quite aggressively as the network is highly over-parameterized - many values of the latent representation are near-deterministic.
* Reducing the number of residual blocks in the generator (`n_residual_blocks`, default 7, the original paper used 9).
* Training on smaller crops (`crop_size`, default `256 x 256`).

* Logs for each experiment, including image reconstructions, are automatically created and periodically saved under `experiments/` with the appropriate name/timestamp. Metrics can be visualized via `tensorboard`:

```bash
tensorboard --logdir experiments/my_experiment/tensorboard --port 2401
```

Some sample logs for a couple of models can be found below:

* [Low bitrate regime (warmstart)](https://tensorboard.dev/experiment/xJV4hjbxRFy3TzrdYl7MXA/).
* [Low bitrate regime (full GAN loss)](https://tensorboard.dev/experiment/ETa0JIeOS0ONNZuNkIdrQw/).
* [High bitrate regime (full GAN loss)](https://tensorboard.dev/experiment/hAf1NYrqSVieKoDOcNpoGw/).

## Compression

* `compress.py` will compress generic images under some specified entropy model. This performs a forward pass through the model to obtain the compressed representation, optionally coding the representation using a vectorized rANS entropy coder. As the model architecture is fully convolutional, compression will work with images of arbitrary size/resolution (subject to memory constraints).

* For message transmission, separate entropy models over the latents and hyperlatents must be instantiated and shared between sender and receiver.
* The sender computes the bottleneck tensor and calls the `compress()` method in `src/model.py` to obtain the compressed representation for transmission.
* The receiver calls the `decompress()` method in `src/model.py` to obtain the quantized bottleneck tensor, which is then passed through the generator to obtain the reconstruction.

* The compression scheme in hierarchial in the sense that two 'levels' of information representing the latent and hyperlatent variables must be compressed and stored in the message, together with the shape of the encoded data.

```bash
# Check arguments
python3 compress.py -h

python3 compress.py -i path/to/image/dir -ckpt path/to/trained/model --reconstruct
```

* Optionally, reconstructions from the compressed format can be generated by passing the `--reconstruct` flag. Decoding without executing the rANS coder takes around 2-3 seconds for ~megapixel images on GPU, but this can definitely be optimized. As the hyperprior entropy model involves a series of matrix multiplications, decoding is significantly faster on GPU.

* Executing the rANS coder is quite slow currently and represents a performance bottleneck. Passing the `--vectorize` flag is much faster, but incurs a constant-bit overhead. The batch size needs to be quite large to make this overhead negligible, suitable for e.g. video frames but not so good for general images. Working on a fix.

## Pretrained Models

* Pretrained models using the OpenImages dataset can be found below. The examples at the end of this readme were produced using the `HIFIC-med` model. Each model was trained for around `2e5` warmup steps and `2e5` steps with the full generative loss. Note the original paper trained for `1e6` steps in each mode, so you can probably get better performance by training from scratch yourself.

* To use a pretrained model, download the selected model (~2 GB) and point the `-ckpt` argument in the command above to the corresponding path. If you want to finetune this model, e.g. on some domain-specific dataset, use the following options for each respective model (you will probably need to adapt the learning rate and rate-penalty schedule yourself):

| Target bitrate (bpp) | Weights | Training Instructions |
| ----------- | -------------------------------- | ---------------------- |
| 0.14 | [`HIFIC-low`](https://drive.google.com/open?id=1hfFTkZbs_VOBmXQ-M4bYEPejrD76lAY9) | <pre lang=bash>`python3 train.py --model_type compression_gan --regime low --warmstart -ckpt path/to/trained/model -nrb 9 -norm`</pre> |
| 0.30 | [`HIFIC-med`](https://drive.google.com/open?id=1QNoX0AGKTBkthMJGPfQI0dT0_tnysYUb) | <pre lang=bash>`python3 train.py --model_type compression_gan --regime med --warmstart -ckpt path/to/trained/model --likelihood_type logistic`</pre> |
| 0.45 | [`HIFIC-high`](https://drive.google.com/open?id=1BFYpvhVIA_Ek2QsHBbKnaBE8wn1GhFyA) | <pre lang=bash>`python3 train.py --model_type compression_gan --regime high --warmstart -ckpt path/to/trained/model -nrb 9 -norm`</pre> |

## Extensibility

* Network architectures can be modified by changing the respective files under `src/network`.
* The entropy model for both latents and hyperlatents can be changed by modifying `src/network/hyperprior`. For reference, there is an implementation of a discrete-logistic latent mixture model instead of the default latent mean-scale Gaussian model.
* The exact compression algorithm used can be replaced with any entropy coder that makes use of indexed probability tables.

## Notes

* The reported `bpp` is the theoretical bitrate required to losslessly store the quantized latent representation of an image. Comparing this (not the size of the reconstruction) against the original size of the image will give you an idea of the reduction in memory footprint.
* The total size of the model (using the original architecture) is around 737 MB. Forward pass time should scale sublinearly provided everything fits in memory. A complete forward pass using a batch of 10 `256 x 256` images takes around 45s on a 2.8 GHz Intel Core i7.
* You may get an OOM error when compressing images which are too large (`>~ 4000 x 4000` on a typical consumer GPU). It's possible to get around this by splitting the input into distinct crops whose forward pass will fit in memory. We're working on a fix to automatically support this.
* Compression of >~ megapixel images takes around 8 GB of RAM.

## Contributing

Feel free to submit any questions/corrections/suggestions/bugs as issues. Pull requests are welcome. Thanks to Grace for helping refactor my code.

### References

The following additional papers were useful to understand implementation details.

0. Fabian Mentzer, George Toderici, Michael Tschannen, Eirikur Agustsson. High-Fidelity Generative Image Compression. [arXiv:2006.09965 (2020)](https://arxiv.org/abs/2006.09965).
1. Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston. Variational image compression with a scale hyperprior. [arXiv:1802.01436 (2018)](https://arxiv.org/abs/1802.01436).
2. David Minnen, Johannes Ballé, George Toderici. Joint Autoregressive and Hierarchical Priors for Learned Image Compression. [arXiv 1809.02736 (2018)](https://arxiv.org/abs/1809.02736).
3. Johannes Ballé, Valero Laparra, Eero P. Simoncelli. End-to-end optimization of nonlinear transform codes for perceptual quality. [arXiv 1607.05006 (2016)](https://arxiv.org/abs/1607.05006).
4. Fabian Mentzer, Eirikur Agustsson, Michael Tschannen, Radu Timofte, Luc Van Gool. Practical Full Resolution Learned Lossless Image Compression. [arXiv 1811.12817 (2018)](https://arxiv.org/abs/1811.12817).

## TODO (priority descending)

* Include `torchac` support for entropy coding.
* Implement universal code for overflow values.
* Investigate bit overhead in vectorized rANS implementation.
* Rewrite rANS implementation for speed.
27 changes: 21 additions & 6 deletions compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def make_deterministic(seed=42):
def compress_batch(args):

# Reproducibility
make_deterministic()
# make_deterministic()
perceptual_loss_fn = ps.PerceptualLoss(model='net-lin', net='alex', use_gpu=torch.cuda.is_available())

# Load model
Expand All @@ -41,20 +41,26 @@ def compress_batch(args):
loaded_args, model, _ = utils.load_model(args.ckpt_path, logger, device, model_mode=ModelModes.EVALUATION,
current_args_d=None, prediction=True, strict=False)

# Override current arguments with recorded
dictify = lambda x: dict((n, getattr(x, n)) for n in dir(x) if not (n.startswith('__') or 'logger' in n))
loaded_args_d, args_d = dictify(loaded_args), dictify(args)
loaded_args_d.update(args_d)
args = utils.Struct(**loaded_args_d)
logger.info(loaded_args_d)

# Build probability tables
model.Hyperprior.hyperprior_entropy_model.build_tables()


eval_loader = datasets.get_dataloaders('evaluation', root=args.image_dir, batch_size=args.batch_size,
logger=logger, shuffle=False, normalize=args.normalize_input_image)

n, N = 0, len(eval_loader.dataset)
input_filenames_total = list()
output_filenames_total = list()
bpp_total, q_bpp_total, LPIPS_total = torch.Tensor(N), torch.Tensor(N), torch.Tensor(N)

utils.makedirs(args.output_dir)

start_time = time.time()

with torch.no_grad():
Expand All @@ -63,7 +69,14 @@ def compress_batch(args):
data = data.to(device, dtype=torch.float)
B = data.size(0)

reconstruction, q_bpp, n_bpp = model(data, writeout=False)
if args.reconstruct is True:
# Reconstruction without compression
reconstruction, q_bpp = model(data, writeout=False)
else:
# Perform entropy coding
compressed_output = model.compress(data)
reconstruction = model.decompress(compressed_output)
q_bpp = compressed_output.total_bpp

if args.normalize_input_image is True:
# [-1., 1.] -> [0., 1.]
Expand All @@ -77,13 +90,14 @@ def compress_batch(args):
if B > 1:
q_bpp_per_im = float(q_bpp.cpu().numpy()[subidx])
else:
q_bpp_per_im = float(q_bpp.item())
q_bpp_per_im = float(q_bpp.item()) if type(q_bpp) == torch.Tensor else float(q_bpp)

fname = os.path.join(args.output_dir, "{}_RECON_{:.3f}bpp.png".format(filenames[subidx], q_bpp_per_im))
torchvision.utils.save_image(reconstruction[subidx], fname, normalize=True)
output_filenames_total.append(fname)

bpp_total[n:n + B] = bpp.data
q_bpp_total[n:n + B] = q_bpp.data
q_bpp_total[n:n + B] = q_bpp.data if type(q_bpp) == torch.Tensor else q_bpp
LPIPS_total[n:n + B] = perceptual_loss.data
n += B

Expand Down Expand Up @@ -116,6 +130,7 @@ def main(**kwargs):
help="Path to directory to store output images")
parser.add_argument('-bs', '--batch_size', type=int, default=1,
help="Loader batch size. Set to 1 if images in directory are different sizes.")
parser.add_argument("-rc", "--reconstruct", help="Reconstruct input image without compression.", action="store_true")
args = parser.parse_args()

input_images = glob.glob(os.path.join(args.image_dir, '*.jpg'))
Expand All @@ -125,7 +140,7 @@ def main(**kwargs):

print('Input images')
pprint(input_images)
# Launch training

compress_batch(args)

if __name__ == '__main__':
Expand Down
8 changes: 1 addition & 7 deletions default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ModelTypes(object):
class ModelModes(object):
TRAINING = 'training'
VALIDATION = 'validation'
EVALUATION = 'evaluation'
EVALUATION = 'evaluation' # actual entropy coding

class Datasets(object):
OPENIMAGES = 'openimages'
Expand All @@ -29,12 +29,6 @@ class DatasetPaths(object):
class directories(object):
experiments = 'experiments'

class checkpoints(object):
low_rate1 = 'experiments/norm_low_rate_openimages_compression_2020_08_19_16_13/checkpoints/norm_low_rate_openimages_compression_2020_08_19_16_13_epoch2_idx168720_2020_08_21_04:00.pt'
low_rate_nrb9 = 'experiments/low_rate9_norm_openimages_compression_2020_08_19_16_59/checkpoints/low_rate9_norm_openimages_compression_2020_08_19_16_59_epoch4_idx237436_2020_08_22_00:21.pt'
# python3 train.py -n low_rate_gan_v1_norm -mt compression_gan -bs 8 -norm --regime low -steps 1e6 --warmstart -ckpt
# experiments/norm_low_rate_openimages_compression_2020_08_19_16_13/checkpoints/norm_low_rate_openimages_compression_2020_08_19_16_13_epoch2_idx168720_2020_08_21_04:00.pt

class args(object):
"""
Shared config
Expand Down
Loading

0 comments on commit 5f9ce2a

Please sign in to comment.