-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
code and model for the paper EVC: Towards Real-Time Neural Image Comp…
…ression with Mask Decay in ICLR 2023.
- Loading branch information
1 parent
dcf4e40
commit bb57b1d
Showing
39 changed files
with
3,493 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
[flake8] | ||
max-line-length = 120 | ||
exclude= | ||
**/build/ | ||
**/checkpoints/download.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# EVC: Towards Real-Time Neural Image Compression with Mask Decay | ||
|
||
This is the official Pytorch implementation for [EVC: Towards Real-Time Neural Image Compression with Mask Decay](https://openreview.net/forum?id=XUxad2Gj40n), in ICLR 2023. | ||
|
||
## Prerequisites | ||
|
||
Environment: | ||
```bash | ||
conda create -n $YOUR_PY38_ENV_NAME python=3.8 | ||
conda activate $YOUR_PY38_ENV_NAME | ||
conda install pytorch=1.11.0 torchvision=0.12.0 torchaudio=0.11.0 cudatoolkit=11.3 -c pytorch | ||
pip install -r requirements.txt | ||
``` | ||
|
||
If you want to test the model with writing bitstream, please build this project. | ||
|
||
### On Windows | ||
```bash | ||
cd src | ||
mkdir build | ||
cd build | ||
conda activate $YOUR_PY38_ENV_NAME | ||
cmake ../cpp -G "Visual Studio 16 2019" -A x64 | ||
cmake --build . --config Release | ||
``` | ||
|
||
### On Linux | ||
```bash | ||
sudo apt-get install cmake g++ | ||
cd src | ||
mkdir build | ||
cd build | ||
conda activate $YOUR_PY38_ENV_NAME | ||
cmake ../cpp -DCMAKE_BUILD_TYPE=Release | ||
make -j | ||
``` | ||
|
||
## Test Dataset | ||
|
||
Please download the Kodak dataset from [http://r0k.us/graphics/kodak/](http://r0k.us/graphics/kodak/). | ||
|
||
The folder is organized as | ||
``` | ||
.../kodak/ | ||
- kodim01.png | ||
- kodim02.png | ||
- ... | ||
``` | ||
|
||
Please modify the `root_path` in `test_cfg/local_kodak.json` to match the data location. | ||
|
||
# Pretrained Models | ||
|
||
* Download our [pretrained models](https://1drv.ms/u/s!AozfVVwtWWYoiUhZLZDx7vJjHK1C?e=qETpA1) and put them into ./checkpoints folder. | ||
* Or run the script in ./checkpoints directly to download the model. | ||
|
||
| i_frame_model | i_frame_model_path | | ||
| :---: | :----: | | ||
| EVC_LL | EVC_LL.pth.tar | | ||
| EVC_ML | EVC_ML_MD.pth.tar | | ||
| EVC_SL | EVC_SL_MD.pth.tar | | ||
| EVC_LM | EVC_LM_MD.pth.tar | | ||
| EVC_LS | EVC_LS_MD.pth.tar | | ||
| EVC_MM | EVC_MM_MD.pth.tar | | ||
| EVC_SS | EVC_SS_MD.pth.tar | | ||
| Scale_EVC_SL | Scale_EVC_SL_MDRRL.pth.tar | | ||
| Scale_EVC_SS | Scale_EVC_SS_MDRRL.pth.tar | | ||
|
||
# Test the Models | ||
|
||
We use `test_image_codec.py` to test our models. Please check this file and make sure you have downloaded models in `checkpoints`. Then, just run | ||
``` | ||
python run_test_image.py | ||
``` | ||
|
||
If you want to test with actual bitstream writing, set `--write_stream 1` in `command_line`. | ||
|
||
## CPU performance scaling | ||
|
||
Note that the arithmetic coding runs on the CPU, please make sure your CPU runs at high performance while writing the actual bitstream. Otherwise, the arithmetic coding may take a long time. | ||
|
||
Check the CPU frequency by | ||
``` | ||
grep -E '^model name|^cpu MHz' /proc/cpuinfo | ||
``` | ||
|
||
Run the following command to maximum CPU frequency | ||
``` | ||
echo performance | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor | ||
``` | ||
|
||
Run the following command to recover the default frequency | ||
``` | ||
echo ondemand | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor | ||
``` | ||
|
||
# Architecture | ||
|
||
The framework: | ||
|
||
data:image/s3,"s3://crabby-images/23080/230805fe5801b44b835de4b76aae741e550ae88c" alt="framework" | ||
|
||
The encoder and decoder: | ||
|
||
data:image/s3,"s3://crabby-images/ac459/ac459af6ba19c7ec4af1e83e41c35199460869c8" alt="enc_dec" | ||
|
||
The hyperprior: | ||
|
||
data:image/s3,"s3://crabby-images/e550b/e550b606b357d19ee2c9ad57a575b9701a78ca85" alt="hyperprior" | ||
|
||
The dual prior: | ||
|
||
data:image/s3,"s3://crabby-images/c4401/c4401d1fd5173660c55d670047b1ab4be596ff16" alt="dualprior" | ||
|
||
|
||
# R-D Curves | ||
|
||
data:image/s3,"s3://crabby-images/7a965/7a965f5996ab45a61d5c4b459418bf5bca53daaf" alt="PSNR RD Curve" | ||
|
||
# Citation | ||
|
||
If you find the work useful for your research, please cite: | ||
``` | ||
@inproceedings{wang2023EVC, | ||
title={EVC: Towards Real-Time Neural Image Compression with Mask Decay}, | ||
author={Wang, Guo-Hua and Li, Jiahao and Li, Bin and Lu, Yan}, | ||
booktitle={International Conference on Learning Representations}, | ||
year={2023} | ||
} | ||
``` | ||
|
||
# Acknowledgement | ||
|
||
* CompressAI: [https://github.com/InterDigitalInc/CompressAI](https://github.com/InterDigitalInc/CompressAI) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import urllib.request | ||
|
||
|
||
def download_one(url, target): | ||
urllib.request.urlretrieve(url, target) | ||
|
||
|
||
def main(): | ||
urls = { | ||
'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211227&authkey=AD8e586WrFlT6IE': 'EVC_LL.pth.tar', | ||
'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211225&authkey=AOOYBdkfEmZ9rTo': 'EVC_LM_MD.pth.tar', | ||
'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211226&authkey=ADp_pN4gvxbHMrw': 'EVC_LS_MD.pth.tar', | ||
'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211228&authkey=AHCLXyxrm3UdXxU': 'EVC_ML_MD.pth.tar', | ||
'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211229&authkey=AGT8gpE50lHHixI': 'EVC_MM_MD.pth.tar', | ||
'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211230&authkey=ABwOafGhqBQcT9I': 'EVC_SL_MD.pth.tar', | ||
'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211231&authkey=ANrIn85RgtBH2wM': 'EVC_SS_MD.pth.tar', | ||
'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211233&authkey=AC8tZbxQdbJDXCU': 'Scale_EVC_SL_MDRRL.pth.tar', | ||
'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211232&authkey=AAy8Q8QMM0dUxKg': 'Scale_EVC_SS_MDRRL.pth.tar', | ||
} | ||
for url in urls: | ||
target = urls[url] | ||
print("downloading", target) | ||
download_one(url, target) | ||
print("downloaded", target) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
import torch | ||
import torch.nn as nn | ||
from src.models import build_model | ||
from ptflops import get_model_complexity_info | ||
|
||
|
||
class IntraCodec(nn.Module): | ||
def __init__(self, model): | ||
super().__init__() | ||
self.model = model | ||
|
||
def forward(self, x): | ||
with torch.no_grad(): | ||
result = self.model.forward(x, q_scale=1.0) | ||
return result | ||
|
||
|
||
def print_model(): | ||
net = build_model("EVC_SS") | ||
Codec = IntraCodec | ||
model = Codec(net) | ||
img_size = (3, 1920, 1088) | ||
|
||
macs, params = get_model_complexity_info(model, img_size, as_strings=True, | ||
print_per_layer_stat=True, verbose=True) | ||
print('{:<30} {:<8}'.format('Computational complexity: ', macs)) | ||
print('{:<30} {:<8}'.format('Number of parameters: ', params)) | ||
|
||
print(f" macs {macs} params {params}") | ||
|
||
|
||
if __name__ == "__main__": | ||
print_model() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
numpy>=1.20.0 | ||
scipy | ||
torch>=1.9.0 | ||
pytorch-msssim==0.2.0 | ||
protobuf==3.20 | ||
tqdm | ||
bd-metric | ||
ptflops |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# RD numbers on Kodak | ||
|
||
EncL_DecL = { | ||
'bpp': [ | ||
0.328361083, | ||
0.500810833, | ||
0.731926708, | ||
0.981324417, | ||
], | ||
'PSNR': [ | ||
32.47535863, | ||
34.47508771, | ||
36.41845, | ||
37.86458587, | ||
] | ||
} | ||
|
||
EncM_DecM = { | ||
'bpp': [ | ||
0.332920792, | ||
0.505808708, | ||
0.738309958, | ||
0.987386125, | ||
], | ||
'PSNR': [ | ||
32.46911329, | ||
34.42935488, | ||
36.34560738, | ||
37.75338558, | ||
] | ||
} | ||
|
||
|
||
EncS_DecS = { | ||
'bpp': [ | ||
0.339319125, | ||
0.51432875, | ||
0.750090625, | ||
1.001192042, | ||
], | ||
'PSNR': [ | ||
32.31969629, | ||
34.24087479, | ||
36.08858388, | ||
37.40119083, | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
|
||
|
||
def test_one_model(i_frame_model, checkpoint): | ||
root_folder = "output" | ||
output_json_path = f"{root_folder}/output_json/arch:{i_frame_model}_ckpt:{checkpoint}.json" | ||
image_model = f"checkpoints/{checkpoint}" | ||
|
||
test_cfg = 'local_kodak.json' | ||
|
||
command_line = (" python test_image.py " | ||
f" --i_frame_model {i_frame_model}" | ||
f" --i_frame_model_path {image_model}" | ||
f" --test_config ./test_cfg/{test_cfg}" | ||
" --cuda 1 -w 1 --rate_num 4" | ||
" --write_stream 0 --ec_thread 1" | ||
" --verbose 1" | ||
# " --save_decoded_frame True" | ||
f" --output_path {output_json_path}") | ||
|
||
print(command_line) | ||
os.system(command_line) | ||
|
||
|
||
def main(): | ||
# i_frame_model = "EVC_LL" | ||
# checkpoint = 'EVC_LL.pth.tar' | ||
|
||
# i_frame_model = "EVC_ML" | ||
# checkpoint = 'EVC_ML_MD.pth.tar' | ||
|
||
# i_frame_model = "EVC_SL" | ||
# checkpoint = 'EVC_SL_MD.pth.tar' | ||
|
||
# i_frame_model = "EVC_LM" | ||
# checkpoint = 'EVC_LM_MD.pth.tar' | ||
|
||
# i_frame_model = "EVC_LS" | ||
# checkpoint = 'EVC_LS_MD.pth.tar' | ||
|
||
# i_frame_model = "EVC_MM" | ||
# checkpoint = 'EVC_MM_MD.pth.tar' | ||
|
||
i_frame_model = "EVC_SS" | ||
checkpoint = 'EVC_SS_MD.pth.tar' | ||
|
||
# i_frame_model = "Scale_EVC_SL" | ||
# checkpoint = 'Scale_EVC_SL_MDRRL.pth.tar' | ||
|
||
# i_frame_model = "Scale_EVC_SS" | ||
# checkpoint = 'Scale_EVC_SS_MDRRL.pth.tar' | ||
test_one_model(i_frame_model, checkpoint) | ||
|
||
|
||
# latency on kodak | ||
""" | ||
CUDA_VISIBLE_DEVICES=0 python test_all_image.py 2>&1 | tee "log.txt" | ||
cat log.txt | grep latency: | tail -n 94 | awk '{a+=$2}END{print a/NR}' | ||
""" | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
add_subdirectory(pybind11) | ||
add_subdirectory(ryg_rans) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
configure_file(CMakeLists.txt.in pybind11-download/CMakeLists.txt) | ||
execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . | ||
RESULT_VARIABLE result | ||
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) | ||
if(result) | ||
message(FATAL_ERROR "CMake step for pybind11 failed: ${result}") | ||
endif() | ||
execute_process(COMMAND ${CMAKE_COMMAND} --build . | ||
RESULT_VARIABLE result | ||
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) | ||
if(result) | ||
message(FATAL_ERROR "Build step for pybind11 failed: ${result}") | ||
endif() | ||
|
||
add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/ | ||
${CMAKE_CURRENT_BINARY_DIR}/pybind11-build/ | ||
EXCLUDE_FROM_ALL) | ||
|
||
set(PYBIND11_INCLUDE | ||
${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/include/ | ||
CACHE INTERNAL "") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
cmake_minimum_required(VERSION 3.6.3) | ||
|
||
project(pybind11-download NONE) | ||
|
||
include(ExternalProject) | ||
if(IS_DIRECTORY "${PROJECT_BINARY_DIR}/3rdparty/pybind11/pybind11-src/include") | ||
ExternalProject_Add(pybind11 | ||
GIT_REPOSITORY https://github.com/pybind/pybind11.git | ||
GIT_TAG v2.6.1 | ||
GIT_SHALLOW 1 | ||
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" | ||
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" | ||
DOWNLOAD_COMMAND "" | ||
UPDATE_COMMAND "" | ||
CONFIGURE_COMMAND "" | ||
BUILD_COMMAND "" | ||
INSTALL_COMMAND "" | ||
TEST_COMMAND "" | ||
) | ||
else() | ||
ExternalProject_Add(pybind11 | ||
GIT_REPOSITORY https://github.com/pybind/pybind11.git | ||
GIT_TAG v2.6.1 | ||
GIT_SHALLOW 1 | ||
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" | ||
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" | ||
UPDATE_COMMAND "" | ||
CONFIGURE_COMMAND "" | ||
BUILD_COMMAND "" | ||
INSTALL_COMMAND "" | ||
TEST_COMMAND "" | ||
) | ||
endif() |
Oops, something went wrong.