diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..f1521a8 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 120 +exclude= + **/build/ + **/checkpoints/download.py diff --git a/ICLR2023/README.md b/ICLR2023/README.md new file mode 100644 index 0000000..af7fa1d --- /dev/null +++ b/ICLR2023/README.md @@ -0,0 +1,134 @@ +# EVC: Towards Real-Time Neural Image Compression with Mask Decay + +This is the official Pytorch implementation for [EVC: Towards Real-Time Neural Image Compression with Mask Decay](https://openreview.net/forum?id=XUxad2Gj40n), in ICLR 2023. + +## Prerequisites + +Environment: +```bash +conda create -n $YOUR_PY38_ENV_NAME python=3.8 +conda activate $YOUR_PY38_ENV_NAME +conda install pytorch=1.11.0 torchvision=0.12.0 torchaudio=0.11.0 cudatoolkit=11.3 -c pytorch +pip install -r requirements.txt +``` + +If you want to test the model with writing bitstream, please build this project. + +### On Windows +```bash +cd src +mkdir build +cd build +conda activate $YOUR_PY38_ENV_NAME +cmake ../cpp -G "Visual Studio 16 2019" -A x64 +cmake --build . --config Release +``` + +### On Linux +```bash +sudo apt-get install cmake g++ +cd src +mkdir build +cd build +conda activate $YOUR_PY38_ENV_NAME +cmake ../cpp -DCMAKE_BUILD_TYPE=Release +make -j +``` + +## Test Dataset + +Please download the Kodak dataset from [http://r0k.us/graphics/kodak/](http://r0k.us/graphics/kodak/). + +The folder is organized as +``` +.../kodak/ + - kodim01.png + - kodim02.png + - ... +``` + +Please modify the `root_path` in `test_cfg/local_kodak.json` to match the data location. + +# Pretrained Models + +* Download our [pretrained models](https://1drv.ms/u/s!AozfVVwtWWYoiUhZLZDx7vJjHK1C?e=qETpA1) and put them into ./checkpoints folder. +* Or run the script in ./checkpoints directly to download the model. 
+ +| i_frame_model | i_frame_model_path | +| :---: | :----: | +| EVC_LL | EVC_LL.pth.tar | +| EVC_ML | EVC_ML_MD.pth.tar | +| EVC_SL | EVC_SL_MD.pth.tar | +| EVC_LM | EVC_LM_MD.pth.tar | +| EVC_LS | EVC_LS_MD.pth.tar | +| EVC_MM | EVC_MM_MD.pth.tar | +| EVC_SS | EVC_SS_MD.pth.tar | +| Scale_EVC_SL | Scale_EVC_SL_MDRRL.pth.tar | +| Scale_EVC_SS | Scale_EVC_SS_MDRRL.pth.tar | + +# Test the Models + +We use `test_image_codec.py` to test our models. Please check this file and make sure you have downloaded models in `checkpoints`. Then, just run +``` +python run_test_image.py +``` + +If you want to test with actual bitstream writing, set `--write_stream 1` in `command_line`. + +## CPU performance scaling + +Note that the arithmetic coding runs on the CPU, please make sure your CPU runs at high performance while writing the actual bitstream. Otherwise, the arithmetic coding may take a long time. + +Check the CPU frequency by +``` +grep -E '^model name|^cpu MHz' /proc/cpuinfo +``` + +Run the following command to maximum CPU frequency +``` +echo performance | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor +``` + +Run the following command to recover the default frequency +``` +echo ondemand | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor +``` + +# Architecture + +The framework: + +![framework](assets/framework.png) + +The encoder and decoder: + +![enc_dec](assets/enc_dec.png) + +The hyperprior: + +![hyperprior](assets/hyperprior.png) + +The dual prior: + +![dualprior](assets/dualprior.png) + + +# R-D Curves + +![PSNR RD Curve](assets/RD.png) + +# Citation + +If you find the work useful for your research, please cite: +``` +@inproceedings{wang2023EVC, + title={EVC: Towards Real-Time Neural Image Compression with Mask Decay}, + author={Wang, Guo-Hua and Li, Jiahao and Li, Bin and Lu, Yan}, + booktitle={International Conference on Learning Representations}, + year={2023} +} +``` + +# Acknowledgement + +* CompressAI: 
[https://github.com/InterDigitalInc/CompressAI](https://github.com/InterDigitalInc/CompressAI) \ No newline at end of file diff --git a/ICLR2023/assets/RD.png b/ICLR2023/assets/RD.png new file mode 100644 index 0000000..e8dfbb3 Binary files /dev/null and b/ICLR2023/assets/RD.png differ diff --git a/ICLR2023/assets/dualprior.png b/ICLR2023/assets/dualprior.png new file mode 100644 index 0000000..1a5ab15 Binary files /dev/null and b/ICLR2023/assets/dualprior.png differ diff --git a/ICLR2023/assets/enc_dec.png b/ICLR2023/assets/enc_dec.png new file mode 100644 index 0000000..fa8d666 Binary files /dev/null and b/ICLR2023/assets/enc_dec.png differ diff --git a/ICLR2023/assets/framework.png b/ICLR2023/assets/framework.png new file mode 100644 index 0000000..701c23c Binary files /dev/null and b/ICLR2023/assets/framework.png differ diff --git a/ICLR2023/assets/hyperprior.png b/ICLR2023/assets/hyperprior.png new file mode 100644 index 0000000..1a631c2 Binary files /dev/null and b/ICLR2023/assets/hyperprior.png differ diff --git a/ICLR2023/checkpoints/download.py b/ICLR2023/checkpoints/download.py new file mode 100644 index 0000000..4a247fc --- /dev/null +++ b/ICLR2023/checkpoints/download.py @@ -0,0 +1,28 @@ +import urllib.request + + +def download_one(url, target): + urllib.request.urlretrieve(url, target) + + +def main(): + urls = { + 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211227&authkey=AD8e586WrFlT6IE': 'EVC_LL.pth.tar', + 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211225&authkey=AOOYBdkfEmZ9rTo': 'EVC_LM_MD.pth.tar', + 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211226&authkey=ADp_pN4gvxbHMrw': 'EVC_LS_MD.pth.tar', + 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211228&authkey=AHCLXyxrm3UdXxU': 'EVC_ML_MD.pth.tar', + 
'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211229&authkey=AGT8gpE50lHHixI': 'EVC_MM_MD.pth.tar', + 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211230&authkey=ABwOafGhqBQcT9I': 'EVC_SL_MD.pth.tar', + 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211231&authkey=ANrIn85RgtBH2wM': 'EVC_SS_MD.pth.tar', + 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211233&authkey=AC8tZbxQdbJDXCU': 'Scale_EVC_SL_MDRRL.pth.tar', + 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211232&authkey=AAy8Q8QMM0dUxKg': 'Scale_EVC_SS_MDRRL.pth.tar', + } + for url in urls: + target = urls[url] + print("downloading", target) + download_one(url, target) + print("downloaded", target) + + +if __name__ == "__main__": + main() diff --git a/ICLR2023/model_complexity.py b/ICLR2023/model_complexity.py new file mode 100644 index 0000000..287d5f8 --- /dev/null +++ b/ICLR2023/model_complexity.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import torch +import torch.nn as nn +from src.models import build_model +from ptflops import get_model_complexity_info + + +class IntraCodec(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x): + with torch.no_grad(): + result = self.model.forward(x, q_scale=1.0) + return result + + +def print_model(): + net = build_model("EVC_SS") + Codec = IntraCodec + model = Codec(net) + img_size = (3, 1920, 1088) + + macs, params = get_model_complexity_info(model, img_size, as_strings=True, + print_per_layer_stat=True, verbose=True) + print('{:<30} {:<8}'.format('Computational complexity: ', macs)) + print('{:<30} {:<8}'.format('Number of parameters: ', params)) + + print(f" macs {macs} params {params}") + + +if __name__ == "__main__": + print_model() diff --git a/ICLR2023/requirements.txt b/ICLR2023/requirements.txt new file mode 100644 index 0000000..2adf711 --- /dev/null +++ b/ICLR2023/requirements.txt @@ -0,0 +1,8 @@ +numpy>=1.20.0 +scipy +torch>=1.9.0 +pytorch-msssim==0.2.0 +protobuf==3.20 +tqdm +bd-metric +ptflops diff --git a/ICLR2023/results/RD_numbers.py b/ICLR2023/results/RD_numbers.py new file mode 100644 index 0000000..d0491c4 --- /dev/null +++ b/ICLR2023/results/RD_numbers.py @@ -0,0 +1,47 @@ +# RD numbers on Kodak + +EncL_DecL = { + 'bpp': [ + 0.328361083, + 0.500810833, + 0.731926708, + 0.981324417, + ], + 'PSNR': [ + 32.47535863, + 34.47508771, + 36.41845, + 37.86458587, + ] +} + +EncM_DecM = { + 'bpp': [ + 0.332920792, + 0.505808708, + 0.738309958, + 0.987386125, + ], + 'PSNR': [ + 32.46911329, + 34.42935488, + 36.34560738, + 37.75338558, + ] +} + + +EncS_DecS = { + 'bpp': [ + 0.339319125, + 0.51432875, + 0.750090625, + 1.001192042, + ], + 'PSNR': [ + 32.31969629, + 34.24087479, + 36.08858388, + 37.40119083, + ] +} diff --git a/ICLR2023/run_test_image.py b/ICLR2023/run_test_image.py new file mode 100644 index 0000000..3c5270f --- /dev/null +++ b/ICLR2023/run_test_image.py @@ -0,0 +1,65 @@ +# Copyright 
(c) Microsoft Corporation. +# Licensed under the MIT License. + +import os + + +def test_one_model(i_frame_model, checkpoint): + root_folder = "output" + output_json_path = f"{root_folder}/output_json/arch:{i_frame_model}_ckpt:{checkpoint}.json" + image_model = f"checkpoints/{checkpoint}" + + test_cfg = 'local_kodak.json' + + command_line = (" python test_image.py " + f" --i_frame_model {i_frame_model}" + f" --i_frame_model_path {image_model}" + f" --test_config ./test_cfg/{test_cfg}" + " --cuda 1 -w 1 --rate_num 4" + " --write_stream 0 --ec_thread 1" + " --verbose 1" + # " --save_decoded_frame True" + f" --output_path {output_json_path}") + + print(command_line) + os.system(command_line) + + +def main(): + # i_frame_model = "EVC_LL" + # checkpoint = 'EVC_LL.pth.tar' + + # i_frame_model = "EVC_ML" + # checkpoint = 'EVC_ML_MD.pth.tar' + + # i_frame_model = "EVC_SL" + # checkpoint = 'EVC_SL_MD.pth.tar' + + # i_frame_model = "EVC_LM" + # checkpoint = 'EVC_LM_MD.pth.tar' + + # i_frame_model = "EVC_LS" + # checkpoint = 'EVC_LS_MD.pth.tar' + + # i_frame_model = "EVC_MM" + # checkpoint = 'EVC_MM_MD.pth.tar' + + i_frame_model = "EVC_SS" + checkpoint = 'EVC_SS_MD.pth.tar' + + # i_frame_model = "Scale_EVC_SL" + # checkpoint = 'Scale_EVC_SL_MDRRL.pth.tar' + + # i_frame_model = "Scale_EVC_SS" + # checkpoint = 'Scale_EVC_SS_MDRRL.pth.tar' + test_one_model(i_frame_model, checkpoint) + + +# latency on kodak +""" +CUDA_VISIBLE_DEVICES=0 python test_all_image.py 2>&1 | tee "log.txt" +cat log.txt | grep latency: | tail -n 94 | awk '{a+=$2}END{print a/NR}' +""" + +if __name__ == "__main__": + main() diff --git a/ICLR2023/src/cpp/3rdparty/CMakeLists.txt b/ICLR2023/src/cpp/3rdparty/CMakeLists.txt new file mode 100644 index 0000000..8f63698 --- /dev/null +++ b/ICLR2023/src/cpp/3rdparty/CMakeLists.txt @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +add_subdirectory(pybind11) +add_subdirectory(ryg_rans) \ No newline at end of file diff --git a/ICLR2023/src/cpp/3rdparty/pybind11/CMakeLists.txt b/ICLR2023/src/cpp/3rdparty/pybind11/CMakeLists.txt new file mode 100644 index 0000000..3c88809 --- /dev/null +++ b/ICLR2023/src/cpp/3rdparty/pybind11/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +configure_file(CMakeLists.txt.in pybind11-download/CMakeLists.txt) +execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) +if(result) + message(FATAL_ERROR "CMake step for pybind11 failed: ${result}") +endif() +execute_process(COMMAND ${CMAKE_COMMAND} --build . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) +if(result) + message(FATAL_ERROR "Build step for pybind11 failed: ${result}") +endif() + +add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/ + ${CMAKE_CURRENT_BINARY_DIR}/pybind11-build/ + EXCLUDE_FROM_ALL) + +set(PYBIND11_INCLUDE + ${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/include/ + CACHE INTERNAL "") diff --git a/ICLR2023/src/cpp/3rdparty/pybind11/CMakeLists.txt.in b/ICLR2023/src/cpp/3rdparty/pybind11/CMakeLists.txt.in new file mode 100644 index 0000000..f0b4565 --- /dev/null +++ b/ICLR2023/src/cpp/3rdparty/pybind11/CMakeLists.txt.in @@ -0,0 +1,33 @@ +cmake_minimum_required(VERSION 3.6.3) + +project(pybind11-download NONE) + +include(ExternalProject) +if(IS_DIRECTORY "${PROJECT_BINARY_DIR}/3rdparty/pybind11/pybind11-src/include") + ExternalProject_Add(pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.6.1 + GIT_SHALLOW 1 + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" + DOWNLOAD_COMMAND "" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +else() + 
ExternalProject_Add(pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.6.1 + GIT_SHALLOW 1 + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +endif() diff --git a/ICLR2023/src/cpp/3rdparty/ryg_rans/CMakeLists.txt b/ICLR2023/src/cpp/3rdparty/ryg_rans/CMakeLists.txt new file mode 100644 index 0000000..d7a23bf --- /dev/null +++ b/ICLR2023/src/cpp/3rdparty/ryg_rans/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +configure_file(CMakeLists.txt.in ryg_rans-download/CMakeLists.txt) +execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) +if(result) + message(FATAL_ERROR "CMake step for ryg_rans failed: ${result}") +endif() +execute_process(COMMAND ${CMAKE_COMMAND} --build . 
+ RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) +if(result) + message(FATAL_ERROR "Build step for ryg_rans failed: ${result}") +endif() + +# add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/ +# ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build +# EXCLUDE_FROM_ALL) + +set(RYG_RANS_INCLUDE + ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/ + CACHE INTERNAL "") diff --git a/ICLR2023/src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in b/ICLR2023/src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in new file mode 100644 index 0000000..3c62451 --- /dev/null +++ b/ICLR2023/src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in @@ -0,0 +1,33 @@ +cmake_minimum_required(VERSION 3.6.3) + +project(ryg_rans-download NONE) + +include(ExternalProject) +if(EXISTS "${PROJECT_BINARY_DIR}/3rdparty/ryg_rans/ryg_rans-src/rans64.h") + ExternalProject_Add(ryg_rans + GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git + GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d + GIT_SHALLOW 1 + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build" + DOWNLOAD_COMMAND "" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +else() + ExternalProject_Add(ryg_rans + GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git + GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d + GIT_SHALLOW 1 + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +endif() diff --git a/ICLR2023/src/cpp/CMakeLists.txt b/ICLR2023/src/cpp/CMakeLists.txt new file mode 100644 index 0000000..06001c6 --- /dev/null +++ b/ICLR2023/src/cpp/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +cmake_minimum_required (VERSION 3.6.3) +project (MLCodec) + +set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo" CACHE STRING "" FORCE) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +# treat warning as error +if (MSVC) + add_compile_options(/W4 /WX) +else() + add_compile_options(-Wall -Wextra -pedantic -Werror) +endif() + +# The sequence is tricky, put 3rd party first +add_subdirectory(3rdparty) +add_subdirectory (ops) +add_subdirectory (rans) +add_subdirectory (py_rans) diff --git a/ICLR2023/src/cpp/ops/CMakeLists.txt b/ICLR2023/src/cpp/ops/CMakeLists.txt new file mode 100644 index 0000000..ed31abb --- /dev/null +++ b/ICLR2023/src/cpp/ops/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +cmake_minimum_required(VERSION 3.7) +set(PROJECT_NAME MLCodec_CXX) +project(${PROJECT_NAME}) + +set(cxx_source + ops.cpp + ) + +set(include_dirs + ${CMAKE_CURRENT_SOURCE_DIR} + ${PYBIND11_INCLUDE} + ) + +pybind11_add_module(${PROJECT_NAME} ${cxx_source}) + +target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) + +# The post build argument is executed after make! +add_custom_command( + TARGET ${PROJECT_NAME} POST_BUILD + COMMAND + "${CMAKE_COMMAND}" -E copy + "$" + "${CMAKE_CURRENT_SOURCE_DIR}/../../models/" +) diff --git a/ICLR2023/src/cpp/ops/ops.cpp b/ICLR2023/src/cpp/ops/ops.cpp new file mode 100644 index 0000000..9463ab7 --- /dev/null +++ b/ICLR2023/src/cpp/ops/ops.cpp @@ -0,0 +1,91 @@ +/* Copyright 2020 InterDigital Communications, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include + +std::vector pmf_to_quantized_cdf(const std::vector &pmf, + int precision) { + /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal + * although it's only run once per model after training. See TF/compression + * implementation for an optimized version. */ + + std::vector cdf(pmf.size() + 1); + cdf[0] = 0; /* freq 0 */ + + std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) { + return static_cast(std::round(p * (1 << precision)) + 0.5); + }); + + const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0); + + std::transform( + cdf.begin(), cdf.end(), cdf.begin(), [precision, total](uint32_t p) { + return static_cast((((1ull << precision) * p) / total)); + }); + + std::partial_sum(cdf.begin(), cdf.end(), cdf.begin()); + cdf.back() = 1 << precision; + + for (int i = 0; i < static_cast(cdf.size() - 1); ++i) { + if (cdf[i] == cdf[i + 1]) { + /* Try to steal frequency from low-frequency symbols */ + uint32_t best_freq = ~0u; + int best_steal = -1; + for (int j = 0; j < static_cast(cdf.size()) - 1; ++j) { + uint32_t freq = cdf[j + 1] - cdf[j]; + if (freq > 1 && freq < best_freq) { + best_freq = freq; + best_steal = j; + } + } + + assert(best_steal != -1); + + if (best_steal < i) { + for (int j = best_steal + 1; j <= i; ++j) { + cdf[j]--; + } + } else { + assert(best_steal > i); + for (int j = i + 1; j <= best_steal; ++j) { + cdf[j]++; + } + } + } + } + + assert(cdf[0] == 0); + assert(cdf.back() == (1u << precision)); + for (int i = 0; i 
< static_cast(cdf.size()) - 1; ++i) { + assert(cdf[i + 1] > cdf[i]); + } + + return cdf; +} + +PYBIND11_MODULE(MLCodec_CXX, m) { + m.attr("__name__") = "MLCodec_CXX"; + + m.doc() = "C++ utils"; + + m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf, + "Return quantized CDF for a given PMF"); +} diff --git a/ICLR2023/src/cpp/py_rans/CMakeLists.txt b/ICLR2023/src/cpp/py_rans/CMakeLists.txt new file mode 100644 index 0000000..b99e3c6 --- /dev/null +++ b/ICLR2023/src/cpp/py_rans/CMakeLists.txt @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +cmake_minimum_required(VERSION 3.7) +set(PROJECT_NAME MLCodec_rans) +project(${PROJECT_NAME}) + +set(py_rans_source + py_rans.h + py_rans.cpp + ) + +set(include_dirs + ${CMAKE_CURRENT_SOURCE_DIR} + ${PYBIND11_INCLUDE} + ) + +pybind11_add_module(${PROJECT_NAME} ${py_rans_source}) + +target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) +target_link_libraries (${PROJECT_NAME} LINK_PUBLIC Rans) + +# The post build argument is executed after make! +add_custom_command( + TARGET ${PROJECT_NAME} POST_BUILD + COMMAND + "${CMAKE_COMMAND}" -E copy + "$" + "${CMAKE_CURRENT_SOURCE_DIR}/../../models/" +) diff --git a/ICLR2023/src/cpp/py_rans/py_rans.cpp b/ICLR2023/src/cpp/py_rans/py_rans.cpp new file mode 100644 index 0000000..33cd080 --- /dev/null +++ b/ICLR2023/src/cpp/py_rans/py_rans.cpp @@ -0,0 +1,269 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#include "py_rans.h" + +#include +#include + +namespace py = pybind11; + +RansEncoder::RansEncoder(bool multiThread, int streamPart=1) { + bool useMultiThread = multiThread || streamPart > 1; + if (useMultiThread) { + for (int i=0; i()); + } + } else { + for (int i=0; i()); + } + } +} + +void RansEncoder::encode_with_indexes(const py::array_t &symbols, + const py::array_t &indexes, + const py::array_t &cdfs, + const py::array_t &cdfs_sizes, + const py::array_t &offsets) { + py::buffer_info symbols_buf = symbols.request(); + py::buffer_info indexes_buf = indexes.request(); + py::buffer_info cdfs_sizes_buf = cdfs_sizes.request(); + py::buffer_info offsets_buf = offsets.request(); + int16_t *symbols_ptr = static_cast(symbols_buf.ptr); + int16_t *indexes_ptr = static_cast(indexes_buf.ptr); + int32_t *cdfs_sizes_ptr = static_cast(cdfs_sizes_buf.ptr); + int32_t *offsets_ptr = static_cast(offsets_buf.ptr); + + std::vector vec_cdfs_sizes(cdfs_sizes.size()); + memcpy(vec_cdfs_sizes.data(), cdfs_sizes_ptr, + sizeof(int32_t) * cdfs_sizes.size()); + std::vector vec_offsets(offsets.size()); + memcpy(vec_offsets.data(), offsets_ptr, sizeof(int32_t) * offsets.size()); + + std::vector> vec_cdfs; + int cdf_num = static_cast(cdfs_sizes.size()); + int per_vector_size = static_cast(cdfs.size() / cdf_num); + auto cdfs_raw = cdfs.unchecked<2>(); + for (int i = 0; i < cdf_num; i++) { + std::vector t(per_vector_size); + memcpy(t.data(), cdfs_raw.data(i, 0), sizeof(int32_t) * per_vector_size); + vec_cdfs.push_back(t); + } + + int encoderNum = m_encoders.size(); + int perEncoderSymbolSize = symbols.size() / encoderNum; + int lastEncoderSymbolSize = symbols.size() - perEncoderSymbolSize * (encoderNum - 1); + for (int i=0; i < encoderNum - 1; i++) { + std::vector vec_symbols(perEncoderSymbolSize); + memcpy(vec_symbols.data(), symbols_ptr + i*perEncoderSymbolSize, sizeof(int16_t) * perEncoderSymbolSize); + std::vector vec_indexes(perEncoderSymbolSize); + memcpy(vec_indexes.data(), 
indexes_ptr + i*perEncoderSymbolSize, sizeof(int16_t) * perEncoderSymbolSize); + m_encoders[i]->encode_with_indexes(vec_symbols, vec_indexes, vec_cdfs, + vec_cdfs_sizes, vec_offsets); + } + + std::vector vec_symbols(lastEncoderSymbolSize); + memcpy(vec_symbols.data(), symbols_ptr + (encoderNum - 1)*perEncoderSymbolSize, sizeof(int16_t) * lastEncoderSymbolSize); + std::vector vec_indexes(perEncoderSymbolSize); + memcpy(vec_indexes.data(), indexes_ptr + (encoderNum - 1)*perEncoderSymbolSize, sizeof(int16_t) * lastEncoderSymbolSize); + m_encoders[encoderNum - 1]->encode_with_indexes(vec_symbols, vec_indexes, vec_cdfs, + vec_cdfs_sizes, vec_offsets); +} + +void RansEncoder::flush() { + for (int i=0; i(m_encoders.size()); i++) { + m_encoders[i]->flush(); + } +} + +py::array_t RansEncoder::get_encoded_stream() { + std::vector> results; + int maximumSize = 0; + int totalSize = 0; + int encoderNumber = static_cast(m_encoders.size()); + for (int i=0; i result = m_encoders[i]->get_encoded_stream(); + results.push_back(result); + int nbytes = static_cast(result.size()); + if (i < encoderNumber - 1 && nbytes > maximumSize) { + maximumSize = nbytes; + } + totalSize += nbytes; + } + + int overhead = 1; + int perStreamHeader = maximumSize > 65535 ? 4 : 2; + if (encoderNumber > 1) { + overhead += ((encoderNumber - 1) * perStreamHeader); + } + + py::array_t stream(totalSize + overhead); + py::buffer_info stream_buf = stream.request(); + uint8_t *stream_ptr = static_cast(stream_buf.ptr); + + uint8_t flag = ((encoderNumber - 1) << 4) + (perStreamHeader == 2 ? 
1 : 0); + memcpy(stream_ptr, &flag, 1); + for (int i=0; i(results[i].size()); + memcpy(stream_ptr + 1 + 2 * i, &perStreamSize, 2); + } + else { + uint32_t perStreamSize = static_cast(results[i].size()); + memcpy(stream_ptr + 1 + 4 * i, &perStreamSize, 4); + } + } + + int offset = overhead; + for (int i=0; i(results[i].size()); + memcpy(stream_ptr + offset, results[i].data(), nbytes); + offset += nbytes; + } + return stream; +} + +void RansEncoder::reset() { + for (int i=0; i(m_encoders.size()); i++) { + m_encoders[i]->reset(); + } +} + +RansDecoder::RansDecoder(int streamPart) { + for (int i=0; i()); + } +} + +void RansDecoder::set_stream(const py::array_t &encoded) { + py::buffer_info encoded_buf = encoded.request(); + uint8_t flag = *(static_cast(encoded_buf.ptr)); + int numberOfStreams = (flag >> 4) + 1; + assert(numberOfStreams == m_decoders.size()); + + uint8_t perStreamSizeLength = (flag & 0x0f) == 1 ? 2 : 4; + if (numberOfStreams == 1) { + int nbytes = static_cast(encoded.size()) - 1; + std::vector stream(nbytes); + memcpy(stream.data(), static_cast(static_cast(encoded_buf.ptr) + 1), nbytes); + m_decoders[0]->set_stream(stream); + } + else { + std::vector perStreamSize; + int offset = 1; + int totalSize = 0; + for (int i=0; i(static_cast(encoded_buf.ptr) + offset)); + offset += 2; + perStreamSize.push_back(streamSize); + totalSize += streamSize; + } + else { + uint32_t streamSize = *(reinterpret_cast(static_cast(encoded_buf.ptr) + offset)); + offset += 4; + perStreamSize.push_back(streamSize); + totalSize += streamSize; + } + } + perStreamSize.push_back(static_cast(encoded.size()) - offset - totalSize); + for (int i=0; i stream(perStreamSize[i]); + memcpy(stream.data(), static_cast(static_cast(encoded_buf.ptr) + offset), perStreamSize[i]); + m_decoders[i]->set_stream(stream); + offset += perStreamSize[i]; + } + } +} + +std::vector decode_future(RansDecoderLib* pDecoder, PendingDecoding* pendingDecoding) { + return 
pDecoder->decode_stream(*(pendingDecoding->indexes), *(pendingDecoding->cdfs), *(pendingDecoding->cdfs_sizes), *(pendingDecoding->offsets)); +} + +py::array_t +RansDecoder::decode_stream(const py::array_t &indexes, + const py::array_t &cdfs, + const py::array_t &cdfs_sizes, + const py::array_t &offsets) { + py::buffer_info indexes_buf = indexes.request(); + py::buffer_info cdfs_sizes_buf = cdfs_sizes.request(); + py::buffer_info offsets_buf = offsets.request(); + int16_t *indexes_ptr = static_cast(indexes_buf.ptr); + int32_t *cdfs_sizes_ptr = static_cast(cdfs_sizes_buf.ptr); + int32_t *offsets_ptr = static_cast(offsets_buf.ptr); + + std::vector vec_cdfs_sizes(cdfs_sizes.size()); + memcpy(vec_cdfs_sizes.data(), cdfs_sizes_ptr, + sizeof(int32_t) * cdfs_sizes.size()); + std::vector vec_offsets(offsets.size()); + memcpy(vec_offsets.data(), offsets_ptr, sizeof(int32_t) * offsets.size()); + + std::vector> vec_cdfs; + int cdf_num = static_cast(cdfs_sizes.size()); + int per_vector_size = static_cast(cdfs.size() / cdf_num); + auto cdfs_raw = cdfs.unchecked<2>(); + for (int i = 0; i < cdf_num; i++) { + std::vector t(per_vector_size); + memcpy(t.data(), cdfs_raw.data(i, 0), sizeof(int32_t) * per_vector_size); + vec_cdfs.push_back(t); + } + int decoderNum = m_decoders.size(); + int perDecoderSymbolSize = indexes.size() / decoderNum; + int lastDecoderSymbolSize = indexes.size() - perDecoderSymbolSize * (decoderNum - 1); + + std::vector>> results; + std::vector> vec_indexes(decoderNum); + std::vector pendingDecoding(decoderNum); + + for (int i=0; i> result = std::async(std::launch::async, decode_future, m_decoders[i].get(), &(pendingDecoding[i])); + results.push_back(result); + } + vec_indexes[decoderNum - 1].resize(lastDecoderSymbolSize); + memcpy(vec_indexes[decoderNum - 1].data(), indexes_ptr + (decoderNum - 1) * perDecoderSymbolSize, sizeof(int16_t) * lastDecoderSymbolSize); + + pendingDecoding[decoderNum - 1].indexes = &(vec_indexes[decoderNum - 1]); + 
pendingDecoding[decoderNum - 1].cdfs = &(vec_cdfs); + pendingDecoding[decoderNum - 1].cdfs_sizes = &(vec_cdfs_sizes); + pendingDecoding[decoderNum - 1].offsets = &(vec_offsets); + std::shared_future> result = std::async(std::launch::async, decode_future, m_decoders[decoderNum - 1].get(), &(pendingDecoding[decoderNum - 1])); + results.push_back(result); + + py::array_t output(indexes.size()); + py::buffer_info buf = output.request(); + int offset = 0; + for (int i=0; i result = results[i].get(); + int resultSize = static_cast(result.size()); + memcpy(static_cast(static_cast(buf.ptr) + offset), result.data(), sizeof(int16_t) * resultSize); + offset += resultSize; + } + + return output; +} + +PYBIND11_MODULE(MLCodec_rans, m) { + m.attr("__name__") = "MLCodec_rans"; + + m.doc() = "range Asymmetric Numeral System python bindings"; + + py::class_(m, "RansEncoder") + .def(py::init()) + .def("encode_with_indexes", &RansEncoder::encode_with_indexes) + .def("flush", &RansEncoder::flush) + .def("get_encoded_stream", &RansEncoder::get_encoded_stream) + .def("reset", &RansEncoder::reset); + + py::class_(m, "RansDecoder") + .def(py::init()) + .def("set_stream", &RansDecoder::set_stream) + .def("decode_stream", &RansDecoder::decode_stream); +} diff --git a/ICLR2023/src/cpp/py_rans/py_rans.h b/ICLR2023/src/cpp/py_rans/py_rans.h new file mode 100644 index 0000000..c3c9859 --- /dev/null +++ b/ICLR2023/src/cpp/py_rans/py_rans.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#pragma once +#include "rans.h" +#include +#include + +namespace py = pybind11; + +// the classes in this file only perform the type conversion +// from python type (numpy) to C++ type (vector) +class RansEncoder { +public: + RansEncoder(bool multiThread, int streamPart); + + RansEncoder(const RansEncoder &) = delete; + RansEncoder(RansEncoder &&) = delete; + RansEncoder &operator=(const RansEncoder &) = delete; + RansEncoder &operator=(RansEncoder &&) = delete; + + void encode_with_indexes(const py::array_t &symbols, + const py::array_t &indexes, + const py::array_t &cdfs, + const py::array_t &cdfs_sizes, + const py::array_t &offsets); + void flush(); + py::array_t get_encoded_stream(); + void reset(); + +private: + std::vector> m_encoders; +}; + +class RansDecoder { +public: + RansDecoder(int streamPart); + + RansDecoder(const RansDecoder &) = delete; + RansDecoder(RansDecoder &&) = delete; + RansDecoder &operator=(const RansDecoder &) = delete; + RansDecoder &operator=(RansDecoder &&) = delete; + + void set_stream(const py::array_t &); + + py::array_t decode_stream(const py::array_t &indexes, + const py::array_t &cdfs, + const py::array_t &cdfs_sizes, + const py::array_t &offsets); + +private: + std::vector> m_decoders; +}; diff --git a/ICLR2023/src/cpp/rans/CMakeLists.txt b/ICLR2023/src/cpp/rans/CMakeLists.txt new file mode 100644 index 0000000..6a734fb --- /dev/null +++ b/ICLR2023/src/cpp/rans/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +cmake_minimum_required(VERSION 3.7) +set(PROJECT_NAME Rans) +project(${PROJECT_NAME}) + +set(rans_source + rans.h + rans.cpp + ) + +set(include_dirs + ${CMAKE_CURRENT_SOURCE_DIR} + ${RYG_RANS_INCLUDE} + ) + +if (NOT MSVC) + add_compile_options(-fPIC) +endif() +add_library (${PROJECT_NAME} ${rans_source}) +target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) diff --git a/ICLR2023/src/cpp/rans/rans.cpp b/ICLR2023/src/cpp/rans/rans.cpp new file mode 100644 index 0000000..1a0c4a7 --- /dev/null +++ b/ICLR2023/src/cpp/rans/rans.cpp @@ -0,0 +1,332 @@ +/* Copyright 2020 InterDigital Communications, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Rans64 extensions from: + * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/ + * Unbounded range coding from: + * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc + **/ + +#include "rans.h" + +#include +#include +#include + +/* probability range, this could be a parameter... 
*/ +constexpr int precision = 16; + +constexpr uint16_t bypass_precision = 4; /* number of bits in bypass mode */ +constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1; + +namespace { + +/* Support only 16 bits word max */ +inline void Rans64EncPutBits(Rans64State *r, uint32_t **pptr, uint32_t val, + uint32_t nbits) { + assert(nbits <= 16); + assert(val < (1u << nbits)); + + /* Re-normalize */ + uint64_t x = *r; + uint32_t freq = 1 << (16 - nbits); + uint64_t x_max = ((RANS64_L >> 16) << 32) * freq; + if (x >= x_max) { + *pptr -= 1; + **pptr = (uint32_t)x; + x >>= 32; + Rans64Assert(x < x_max); + } + + /* x = C(s, x) */ + *r = (x << nbits) | val; +} + +inline uint32_t Rans64DecGetBits(Rans64State *r, uint32_t **pptr, + uint32_t n_bits) { + uint64_t x = *r; + uint32_t val = x & ((1u << n_bits) - 1); + + /* Re-normalize */ + x = x >> n_bits; + if (x < RANS64_L) { + x = (x << 32) | **pptr; + *pptr += 1; + Rans64Assert(x >= RANS64_L); + } + + *r = x; + + return val; +} +} // namespace + +void RansEncoderLib::encode_with_indexes( + const std::vector &symbols, const std::vector &indexes, + const std::vector> &cdfs, + const std::vector &cdfs_sizes, + const std::vector &offsets) { + + // backward loop on symbols from the end; + const int16_t *symbols_ptr = symbols.data(); + const int16_t *indexes_ptr = indexes.data(); + const int32_t *cdfs_sizes_ptr = cdfs_sizes.data(); + const int32_t *offsets_ptr = offsets.data(); + const int symbol_size = static_cast(symbols.size()); + _syms.reserve(symbol_size * 3 / 2); + for (int i = 0; i < symbol_size; ++i) { + const int32_t cdf_idx = indexes_ptr[i]; + if (cdf_idx < 0) { + continue; + } + const int32_t *cdf = cdfs[cdf_idx].data(); + const int32_t max_value = cdfs_sizes_ptr[cdf_idx] - 2; + int32_t value = symbols_ptr[i] - offsets_ptr[cdf_idx]; + + uint32_t raw_val = 0; + if (value < 0) { + raw_val = -2 * value - 1; + value = max_value; + } else if (value >= max_value) { + raw_val = 2 * (value - max_value); + value = 
max_value; + } + + _syms.push_back({static_cast(cdf[value]), + static_cast(cdf[value + 1] - cdf[value]), + false}); + + /* Bypass coding mode (value == max_value -> sentinel flag) */ + if (value == max_value) { + /* Determine the number of bypasses (in bypass_precision size) needed to + * encode the raw value. */ + int32_t n_bypass = 0; + while ((raw_val >> (n_bypass * bypass_precision)) != 0) { + ++n_bypass; + } + + /* Encode number of bypasses */ + int32_t val = n_bypass; + while (val >= max_bypass_val) { + _syms.push_back({max_bypass_val, max_bypass_val + 1, true}); + val -= max_bypass_val; + } + _syms.push_back( + {static_cast(val), static_cast(val + 1), true}); + + /* Encode raw value */ + for (int32_t j = 0; j < n_bypass; ++j) { + const int32_t val1 = + (raw_val >> (j * bypass_precision)) & max_bypass_val; + _syms.push_back({static_cast(val1), + static_cast(val1 + 1), true}); + } + } + } +} + +void RansEncoderLib::flush() { + Rans64State rans; + Rans64EncInit(&rans); + + std::vector output(_syms.size()); // too much space ? + uint32_t *ptr = output.data() + output.size(); + assert(ptr != nullptr); + + while (!_syms.empty()) { + const RansSymbol sym = _syms.back(); + + if (!sym.bypass) { + Rans64EncPut(&rans, &ptr, sym.start, sym.range, precision); + } else { + // unlikely... 
+ Rans64EncPutBits(&rans, &ptr, sym.start, bypass_precision); + } + _syms.pop_back(); + } + + Rans64EncFlush(&rans, &ptr); + + const int nbytes = static_cast( + std::distance(ptr, output.data() + output.size()) * sizeof(uint32_t)); + + _stream.resize(nbytes); + memcpy(_stream.data(), ptr, nbytes); +} + +std::vector RansEncoderLib::get_encoded_stream() { return _stream; } + +void RansEncoderLib::reset() { _syms.clear(); } + +RansEncoderLibMultiThread::RansEncoderLibMultiThread() + : RansEncoderLib(), m_finish(false), m_result_ready(false), + m_thread(std::thread(&RansEncoderLibMultiThread::worker, this)) {} + +RansEncoderLibMultiThread::~RansEncoderLibMultiThread() { + { + std::lock_guard lk(m_mutex_pending); + std::lock_guard lk1(m_mutex_result); + m_finish = true; + } + m_cv_pending.notify_one(); + m_cv_result.notify_one(); + m_thread.join(); +} + +void RansEncoderLibMultiThread::encode_with_indexes( + const std::vector &symbols, const std::vector &indexes, + const std::vector> &cdfs, + const std::vector &cdfs_sizes, + const std::vector &offsets) { + PendingEncoding p; + p.workType = WorkType::Encode; + p.symbols = symbols; + p.indexes = indexes; + p.cdfs = cdfs; + p.cdfs_sizes = cdfs_sizes; + p.offsets = offsets; + { + std::unique_lock lk(m_mutex_pending); + m_pending.push_back(p); + } + m_cv_pending.notify_one(); +} + +void RansEncoderLibMultiThread::flush() { + PendingEncoding p; + p.workType = WorkType::Flush; + { + std::unique_lock lk(m_mutex_pending); + m_pending.push_back(p); + } + m_cv_pending.notify_one(); +} + +std::vector RansEncoderLibMultiThread::get_encoded_stream() { + std::unique_lock lk(m_mutex_result); + m_cv_result.wait(lk, [this] { return m_result_ready || m_finish; }); + return RansEncoderLib::get_encoded_stream(); +} + +void RansEncoderLibMultiThread::reset() { + RansEncoderLib::reset(); + std::lock_guard lk(m_mutex_result); + m_result_ready = false; +} + +void RansEncoderLibMultiThread::worker() { + while (!m_finish) { + std::unique_lock 
lk(m_mutex_pending); + m_cv_pending.wait(lk, [this] { return m_pending.size() > 0 || m_finish; }); + if (m_finish) { + lk.unlock(); + break; + } + if (m_pending.size() == 0) { + lk.unlock(); + // std::cout << "contine in worker" << std::endl; + continue; + } + while (m_pending.size() > 0) { + auto p = m_pending.front(); + m_pending.pop_front(); + lk.unlock(); + if (p.workType == WorkType::Encode) { + RansEncoderLib::encode_with_indexes(p.symbols, p.indexes, p.cdfs, + p.cdfs_sizes, p.offsets); + } else if (p.workType == WorkType::Flush) { + RansEncoderLib::flush(); + { + std::lock_guard lk_result(m_mutex_result); + m_result_ready = true; + } + m_cv_result.notify_one(); + } + lk.lock(); + } + lk.unlock(); + } +} + +void RansDecoderLib::set_stream(const std::vector &encoded) { + _stream = encoded; + _ptr = (uint32_t *)(_stream.data()); + Rans64DecInit(&_rans, &_ptr); +} + +std::vector +RansDecoderLib::decode_stream(const std::vector &indexes, + const std::vector> &cdfs, + const std::vector &cdfs_sizes, + const std::vector &offsets) { + int index_size = static_cast(indexes.size()); + std::vector output(index_size); + + int16_t *outout_ptr = output.data(); + const int16_t *indexes_ptr = indexes.data(); + const int32_t *cdfs_sizes_ptr = cdfs_sizes.data(); + const int32_t *offsets_ptr = offsets.data(); + for (int i = 0; i < index_size; ++i) { + const int32_t cdf_idx = indexes_ptr[i]; + const int32_t offset = offsets_ptr[cdf_idx]; + if (cdf_idx < 0) { + outout_ptr[i] = static_cast(offset); + continue; + } + const int32_t *cdf = cdfs[cdf_idx].data(); + const int32_t max_value = cdfs_sizes_ptr[cdf_idx] - 2; + const uint32_t cum_freq = Rans64DecGet(&_rans, precision); + + const auto cdf_end = cdf + cdfs_sizes_ptr[cdf_idx]; + const auto it = std::find_if(cdf, cdf_end, [cum_freq](int v) { + return static_cast(v) > cum_freq; + }); + const uint32_t s = static_cast(std::distance(cdf, it) - 1); + + Rans64DecAdvance(&_rans, &_ptr, cdf[s], cdf[s + 1] - cdf[s], precision); + + int32_t 
value = static_cast(s); + + if (value == max_value) { + /* Bypass decoding mode */ + int32_t val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); + int32_t n_bypass = val; + + while (val == max_bypass_val) { + val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); + n_bypass += val; + } + + int32_t raw_val = 0; + for (int j = 0; j < n_bypass; ++j) { + val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); + raw_val |= val << (j * bypass_precision); + } + value = raw_val >> 1; + if (raw_val & 1) { + value = -value - 1; + } else { + value += max_value; + } + } + + outout_ptr[i] = static_cast(value + offset); + } + return output; +} + +std::vector RansDecoderLib::get_decoded_symbols() { + return m_decodedSymbols; +} diff --git a/ICLR2023/src/cpp/rans/rans.h b/ICLR2023/src/cpp/rans/rans.h new file mode 100644 index 0000000..414396d --- /dev/null +++ b/ICLR2023/src/cpp/rans/rans.h @@ -0,0 +1,144 @@ +/* Copyright 2020 InterDigital Communications, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpedantic" +#pragma GCC diagnostic ignored "-Wsign-compare" +#endif + +#include + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +struct RansSymbol { + uint16_t start; + uint16_t range; + bool bypass; // bypass flag to write raw bits to the stream +}; + +enum class WorkType { + Encode, + Flush, +}; + +struct PendingEncoding { + WorkType workType; + std::vector symbols; + std::vector indexes; + std::vector> cdfs; + std::vector cdfs_sizes; + std::vector offsets; +}; + +struct PendingDecoding { + std::vector* output; + std::vector* indexes; + std::vector>* cdfs; + std::vector* cdfs_sizes; + std::vector* offsets; +}; + +/* NOTE: Warning, we buffer everything for now... In case of large files we + * should split the bitstream into chunks... Or for a memory-bounded encoder + **/ +class RansEncoderLib { +public: + RansEncoderLib() = default; + virtual ~RansEncoderLib() = default; + + RansEncoderLib(const RansEncoderLib &) = delete; + RansEncoderLib(RansEncoderLib &&) = delete; + RansEncoderLib &operator=(const RansEncoderLib &) = delete; + RansEncoderLib &operator=(RansEncoderLib &&) = delete; + + virtual void + encode_with_indexes(const std::vector &symbols, + const std::vector &indexes, + const std::vector> &cdfs, + const std::vector &cdfs_sizes, + const std::vector &offsets); + virtual void flush(); + virtual std::vector get_encoded_stream(); + virtual void reset(); + +private: + std::vector _syms; + std::vector _stream; +}; + +class RansEncoderLibMultiThread : public RansEncoderLib { +public: + RansEncoderLibMultiThread(); + virtual ~RansEncoderLibMultiThread(); + + virtual void + encode_with_indexes(const std::vector &symbols, + const std::vector &indexes, + const std::vector> &cdfs, + const std::vector &cdfs_sizes, + const std::vector &offsets) override; + virtual void flush() override; + virtual std::vector 
get_encoded_stream() override; + virtual void reset() override; + + void worker(); + +private: + bool m_finish; + bool m_result_ready; + std::thread m_thread; + std::mutex m_mutex_result; + std::mutex m_mutex_pending; + std::condition_variable m_cv_pending; + std::condition_variable m_cv_result; + std::list m_pending; +}; + +class RansDecoderLib { +public: + RansDecoderLib() = default; + virtual ~RansDecoderLib() = default; + + RansDecoderLib(const RansDecoderLib &) = delete; + RansDecoderLib(RansDecoderLib &&) = delete; + RansDecoderLib &operator=(const RansDecoderLib &) = delete; + RansDecoderLib &operator=(RansDecoderLib &&) = delete; + + void set_stream(const std::vector &); + + virtual std::vector + decode_stream(const std::vector &indexes, + const std::vector> &cdfs, + const std::vector &cdfs_sizes, + const std::vector &offsets); + + virtual std::vector get_decoded_symbols(); + +private: + Rans64State _rans; + uint32_t *_ptr; + std::vector _stream; + std::vector m_decodedSymbols; +}; diff --git a/ICLR2023/src/models/__init__.py b/ICLR2023/src/models/__init__.py new file mode 100644 index 0000000..d30ebf9 --- /dev/null +++ b/ICLR2023/src/models/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from .image_model import EVC_LL, EVC_LM, EVC_LS +from .image_model import EVC_ML, EVC_SL +from .image_model import EVC_MM, EVC_SS +from .scalable_encoder_model import Scale_EVC_SS, Scale_EVC_SL + + +model_architectures = { + 'EVC_LL': EVC_LL, + 'EVC_LM': EVC_LM, + 'EVC_LS': EVC_LS, + + 'EVC_ML': EVC_ML, + 'EVC_SL': EVC_SL, + + 'EVC_MM': EVC_MM, + 'EVC_SS': EVC_SS, + + 'Scale_EVC_SS': Scale_EVC_SS, + 'Scale_EVC_SL': Scale_EVC_SL, +} + + +def build_model(model_name, **kwargs): + # print(f'=> build model: {model_name}') + if model_name in model_architectures: + return model_architectures[model_name](**kwargs) + else: + raise ValueError(model_name) diff --git a/ICLR2023/src/models/common_model.py b/ICLR2023/src/models/common_model.py new file mode 100644 index 0000000..56afdfa --- /dev/null +++ b/ICLR2023/src/models/common_model.py @@ -0,0 +1,164 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import math +import torch +from torch import nn + +from .layers import LowerBound +from .entropy_models import BitEstimator, GaussianEncoder, EntropyCoder + + +class CompressionModel(nn.Module): + def __init__(self, y_distribution, z_channel, ec_thread=False): + super().__init__() + self.y_distribution = y_distribution + self.z_channel = z_channel + self.entropy_coder = None + self.bit_estimator_z = BitEstimator(z_channel) + self.gaussian_encoder = GaussianEncoder(distribution=y_distribution) + self.ec_thread = ec_thread + self.mse = nn.MSELoss(reduction='none') + + def quant(self, x): + return torch.round(x) + + @staticmethod + def get_curr_q(q_scale, q_basic): + q_basic = LowerBound.apply(q_basic, 0.5) + return q_basic * q_scale + + @staticmethod + def probs_to_bits(probs): + bits = -1.0 * torch.log(probs + 1e-5) / math.log(2.0) + bits = LowerBound.apply(bits, 0) + return bits + + def get_y_gaussian_bits(self, y, sigma): + mu = torch.zeros_like(sigma) + sigma = sigma.clamp(0.11, 1e10) + gaussian = torch.distributions.normal.Normal(mu, sigma) 
+ probs = gaussian.cdf(y + 0.5) - gaussian.cdf(y - 0.5) + return CompressionModel.probs_to_bits(probs) + + def get_z_bits(self, z, bit_estimator): + probs = bit_estimator.get_cdf(z + 0.5) - bit_estimator.get_cdf(z - 0.5) + return CompressionModel.probs_to_bits(probs) + + def update(self, force=False): + self.entropy_coder = EntropyCoder(self.ec_thread) + self.gaussian_encoder.update(force=force, entropy_coder=self.entropy_coder) + self.bit_estimator_z.update(force=force, entropy_coder=self.entropy_coder) + + @staticmethod + def separate_prior(params): + return params.chunk(3, 1) + + @staticmethod + def get_mask(height, width, dtype, device): + micro_mask = torch.tensor(((1, 0), (0, 1)), dtype=dtype, device=device) + mask_0 = micro_mask.repeat(height // 2, width // 2) + mask_0 = torch.unsqueeze(mask_0, 0) + mask_0 = torch.unsqueeze(mask_0, 0) + mask_1 = torch.ones_like(mask_0) - mask_0 + return mask_0, mask_1 + + def process_with_mask(self, y, scales, means, mask): + scales_hat = scales * mask + means_hat = means * mask + + y_res = (y - means_hat) * mask + y_q = self.quant(y_res) + y_hat = y_q + means_hat + + return y_res, y_q, y_hat, scales_hat + + def forward_dual_prior(self, y, means, scales, quant_step, y_spatial_prior, write=False): + ''' + y_0 means split in channel, the first half + y_1 means split in channel, the second half + y_?_0, means multiply with mask_0 + y_?_1, means multiply with mask_1 + ''' + dtype = y.dtype + device = y.device + _, _, H, W = y.size() + mask_0, mask_1 = self.get_mask(H, W, dtype, device) + + quant_step = torch.clamp_min(quant_step, 0.5) + y = y / quant_step + y_0, y_1 = y.chunk(2, 1) + + scales_0, scales_1 = scales.chunk(2, 1) + means_0, means_1 = means.chunk(2, 1) + + y_res_0_0, y_q_0_0, y_hat_0_0, scales_hat_0_0 = \ + self.process_with_mask(y_0, scales_0, means_0, mask_0) + y_res_1_1, y_q_1_1, y_hat_1_1, scales_hat_1_1 = \ + self.process_with_mask(y_1, scales_1, means_1, mask_1) + + params = torch.cat((y_hat_0_0, y_hat_1_1, 
means, scales, quant_step), dim=1) + scales_0, means_0, scales_1, means_1 = y_spatial_prior(params).chunk(4, 1) + + y_res_0_1, y_q_0_1, y_hat_0_1, scales_hat_0_1 = \ + self.process_with_mask(y_0, scales_0, means_0, mask_1) + y_res_1_0, y_q_1_0, y_hat_1_0, scales_hat_1_0 = \ + self.process_with_mask(y_1, scales_1, means_1, mask_0) + + y_res_0 = y_res_0_0 + y_res_0_1 + y_q_0 = y_q_0_0 + y_q_0_1 + y_hat_0 = y_hat_0_0 + y_hat_0_1 + scales_hat_0 = scales_hat_0_0 + scales_hat_0_1 + + y_res_1 = y_res_1_1 + y_res_1_0 + y_q_1 = y_q_1_1 + y_q_1_0 + y_hat_1 = y_hat_1_1 + y_hat_1_0 + scales_hat_1 = scales_hat_1_1 + scales_hat_1_0 + + y_res = torch.cat((y_res_0, y_res_1), dim=1) + y_q = torch.cat((y_q_0, y_q_1), dim=1) + y_hat = torch.cat((y_hat_0, y_hat_1), dim=1) + scales_hat = torch.cat((scales_hat_0, scales_hat_1), dim=1) + + y_hat = y_hat * quant_step + + if write: + y_q_w_0 = y_q_0_0 + y_q_1_1 + y_q_w_1 = y_q_0_1 + y_q_1_0 + scales_w_0 = scales_hat_0_0 + scales_hat_1_1 + scales_w_1 = scales_hat_0_1 + scales_hat_1_0 + return y_q_w_0, y_q_w_1, scales_w_0, scales_w_1, y_hat + return y_res, y_q, y_hat, scales_hat + + def compress_dual_prior(self, y, means, scales, quant_step, y_spatial_prior): + return self.forward_dual_prior(y, means, scales, quant_step, y_spatial_prior, write=True) + + def decompress_dual_prior(self, means, scales, quant_step, y_spatial_prior): + dtype = means.dtype + device = means.device + _, _, H, W = means.size() + mask_0, mask_1 = self.get_mask(H, W, dtype, device) + quant_step = torch.clamp_min(quant_step, 0.5) + + scales_0, scales_1 = scales.chunk(2, 1) + means_0, means_1 = means.chunk(2, 1) + + scales_r_0 = scales_0 * mask_0 + scales_1 * mask_1 + y_q_r_0 = self.gaussian_encoder.decode_stream(scales_r_0, dtype, device) + y_hat_0_0 = (y_q_r_0 + means_0) * mask_0 + y_hat_1_1 = (y_q_r_0 + means_1) * mask_1 + + params = torch.cat((y_hat_0_0, y_hat_1_1, means, scales, quant_step), dim=1) + scales_0, means_0, scales_1, means_1 = 
y_spatial_prior(params).chunk(4, 1) + + scales_r_1 = scales_0 * mask_1 + scales_1 * mask_0 + y_q_r_1 = self.gaussian_encoder.decode_stream(scales_r_1, dtype, device) + y_hat_0_1 = (y_q_r_1 + means_0) * mask_1 + y_hat_1_0 = (y_q_r_1 + means_1) * mask_0 + + y_hat_0 = y_hat_0_0 + y_hat_0_1 + y_hat_1 = y_hat_1_1 + y_hat_1_0 + y_hat = torch.cat((y_hat_0, y_hat_1), dim=1) + y_hat = y_hat * quant_step + + return y_hat diff --git a/ICLR2023/src/models/entropy_models.py b/ICLR2023/src/models/entropy_models.py new file mode 100644 index 0000000..6914c2b --- /dev/null +++ b/ICLR2023/src/models/entropy_models.py @@ -0,0 +1,286 @@ +import math + +import torch +import numpy as np +from torch import nn +import torch.nn.functional as F + + +class EntropyCoder(): + def __init__(self, ec_thread=False): + super().__init__() + + from .MLCodec_rans import RansEncoder, RansDecoder + self.encoder = RansEncoder(ec_thread, 2) + self.decoder = RansDecoder(2) + + @staticmethod + def pmf_to_quantized_cdf(pmf, precision=16): + from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf + cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision) + cdf = torch.IntTensor(cdf) + return cdf + + @staticmethod + def pmf_to_cdf(pmf, tail_mass, pmf_length, max_length): + entropy_coder_precision = 16 + cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32) + for i, p in enumerate(pmf): + prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0) + _cdf = EntropyCoder.pmf_to_quantized_cdf(prob, entropy_coder_precision) + cdf[i, : _cdf.size(0)] = _cdf + return cdf + + def reset(self): + self.encoder.reset() + + def encode_with_indexes(self, symbols, indexes, cdf, cdf_length, offset): + self.encoder.encode_with_indexes(symbols.clamp(-30000, 30000).to(torch.int16).cpu().numpy(), + indexes.to(torch.int16).cpu().numpy(), + cdf, cdf_length, offset) + + def flush(self): + self.encoder.flush() + + def get_encoded_stream(self): + return self.encoder.get_encoded_stream().tobytes() + + def 
set_stream(self, stream): + self.decoder.set_stream((np.frombuffer(stream, dtype=np.uint8))) + + def decode_stream(self, indexes, cdf, cdf_length, offset): + rv = self.decoder.decode_stream(indexes.to(torch.int16).cpu().numpy(), + cdf, cdf_length, offset) + rv = torch.Tensor(rv) + return rv + + +class Bitparm(nn.Module): + def __init__(self, channel, final=False): + super().__init__() + self.final = final + self.h = nn.Parameter(torch.nn.init.normal_( + torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) + self.b = nn.Parameter(torch.nn.init.normal_( + torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) + if not final: + self.a = nn.Parameter(torch.nn.init.normal_( + torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) + else: + self.a = None + + def forward(self, x): + x = x * F.softplus(self.h) + self.b + if self.final: + return x + + return x + torch.tanh(x) * torch.tanh(self.a) + + +class AEHelper(): + def __init__(self): + super().__init__() + self.entropy_coder = None + self._offset = None + self._quantized_cdf = None + self._cdf_length = None + + def set_entropy_coder(self, coder): + self.entropy_coder = coder + + def set_cdf_info(self, quantized_cdf, cdf_length, offset): + self._quantized_cdf = quantized_cdf.cpu().numpy() + self._cdf_length = cdf_length.reshape(-1).int().cpu().numpy() + self._offset = offset.reshape(-1).int().cpu().numpy() + + def get_cdf_info(self): + return self._quantized_cdf, \ + self._cdf_length, \ + self._offset + + def get_cdf_info_tensor(self): + return torch.tensor(self._quantized_cdf), \ + torch.tensor(self._cdf_length), \ + torch.tensor(self._offset) + + +class BitEstimator(AEHelper, nn.Module): + def __init__(self, channel): + super().__init__() + self.f1 = Bitparm(channel) + self.f2 = Bitparm(channel) + self.f3 = Bitparm(channel) + self.f4 = Bitparm(channel, True) + self.channel = channel + + def forward(self, x): + return self.get_cdf(x) + + def get_logits_cdf(self, x): + x = self.f1(x) + x = self.f2(x) + x = self.f3(x) + x = 
self.f4(x) + return x + + def get_cdf(self, x): + return torch.sigmoid(self.get_logits_cdf(x)) + + def update(self, force=False, entropy_coder=None): + if entropy_coder is not None: + self.entropy_coder = entropy_coder + + if not force and self._offset is not None: + return + + with torch.no_grad(): + device = next(self.parameters()).device + medians = torch.zeros((self.channel), device=device) + + minima = medians + 50 + for i in range(50, 1, -1): + samples = torch.zeros_like(medians) - i + samples = samples[None, :, None, None] + probs = self.forward(samples) + probs = torch.squeeze(probs) + minima = torch.where(probs < torch.zeros_like(medians) + 0.0001, + torch.zeros_like(medians) + i, minima) + + maxima = medians + 50 + for i in range(50, 1, -1): + samples = torch.zeros_like(medians) + i + samples = samples[None, :, None, None] + probs = self.forward(samples) + probs = torch.squeeze(probs) + maxima = torch.where(probs > torch.zeros_like(medians) + 0.9999, + torch.zeros_like(medians) + i, maxima) + + minima = minima.int() + maxima = maxima.int() + + offset = -minima + + pmf_start = medians - minima + pmf_length = maxima + minima + 1 + + max_length = pmf_length.max() + device = pmf_start.device + samples = torch.arange(max_length, device=device) + + samples = samples[None, :] + pmf_start[:, None, None] + + half = float(0.5) + + lower = self.forward(samples - half).squeeze(0) + upper = self.forward(samples + half).squeeze(0) + pmf = upper - lower + + pmf = pmf[:, 0, :] + tail_mass = lower[:, 0, :1] + (1.0 - upper[:, 0, -1:]) + + quantized_cdf = EntropyCoder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) + cdf_length = pmf_length + 2 + self.set_cdf_info(quantized_cdf, cdf_length, offset) + + @staticmethod + def build_indexes(size): + N, C, H, W = size + indexes = torch.arange(C, dtype=torch.int).view(1, -1, 1, 1) + return indexes.repeat(N, 1, H, W) + + def encode(self, x): + indexes = self.build_indexes(x.size()) + return 
self.entropy_coder.encode_with_indexes(x.reshape(-1), indexes.reshape(-1), + *self.get_cdf_info()) + + def decode_stream(self, size, dtype, device): + output_size = (1, self.channel, size[0], size[1]) + indexes = self.build_indexes(output_size) + val = self.entropy_coder.decode_stream(indexes.reshape(-1), *self.get_cdf_info()) + val = val.reshape(indexes.shape) + return val.to(dtype).to(device) + + +class GaussianEncoder(AEHelper): + def __init__(self, distribution='laplace'): + super().__init__() + assert distribution in ['laplace', 'gaussian'] + self.distribution = distribution + if distribution == 'laplace': + self.cdf_distribution = torch.distributions.laplace.Laplace + self.scale_min = 0.01 + self.scale_max = 64.0 + self.scale_level = 256 + elif distribution == 'gaussian': + self.cdf_distribution = torch.distributions.normal.Normal + self.scale_min = 0.11 + self.scale_max = 64.0 + self.scale_level = 256 + self.scale_table = self.get_scale_table(self.scale_min, self.scale_max, self.scale_level) + + self.log_scale_min = math.log(self.scale_min) + self.log_scale_max = math.log(self.scale_max) + self.log_scale_step = (self.log_scale_max - self.log_scale_min) / (self.scale_level - 1) + + @staticmethod + def get_scale_table(min_val, max_val, levels): + return torch.exp(torch.linspace(math.log(min_val), math.log(max_val), levels)) + + def update(self, force=False, entropy_coder=None): + if entropy_coder is not None: + self.entropy_coder = entropy_coder + + if not force and self._offset is not None: + return + + pmf_center = torch.zeros_like(self.scale_table) + 50 + scales = torch.zeros_like(pmf_center) + self.scale_table + mu = torch.zeros_like(scales) + cdf_distribution = self.cdf_distribution(mu, scales) + for i in range(50, 1, -1): + samples = torch.zeros_like(pmf_center) + i + probs = cdf_distribution.cdf(samples) + probs = torch.squeeze(probs) + pmf_center = torch.where(probs > torch.zeros_like(pmf_center) + 0.9999, + torch.zeros_like(pmf_center) + i, 
pmf_center) + + pmf_center = pmf_center.int() + pmf_length = 2 * pmf_center + 1 + max_length = torch.max(pmf_length).item() + + device = pmf_center.device + samples = torch.arange(max_length, device=device) - pmf_center[:, None] + samples = samples.float() + + scales = torch.zeros_like(samples) + self.scale_table[:, None] + mu = torch.zeros_like(scales) + cdf_distribution = self.cdf_distribution(mu, scales) + + upper = cdf_distribution.cdf(samples + 0.5) + lower = cdf_distribution.cdf(samples - 0.5) + pmf = upper - lower + + tail_mass = 2 * lower[:, :1] + + quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2) + quantized_cdf = EntropyCoder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) + + self.set_cdf_info(quantized_cdf, pmf_length+2, -pmf_center) + + def build_indexes(self, scales): + scales = torch.maximum(scales, torch.zeros_like(scales) + 1e-5) + indexes = (torch.log(scales) - self.log_scale_min) / self.log_scale_step + indexes = indexes.clamp_(0, self.scale_level - 1) + return indexes.int() + + def encode(self, x, scales): + indexes = self.build_indexes(scales) + return self.entropy_coder.encode_with_indexes(x.reshape(-1), indexes.reshape(-1), + *self.get_cdf_info()) + + def decode_stream(self, scales, dtype, device): + indexes = self.build_indexes(scales) + val = self.entropy_coder.decode_stream(indexes.reshape(-1), + *self.get_cdf_info()) + val = val.reshape(scales.shape) + return val.to(device).to(dtype) diff --git a/ICLR2023/src/models/hyperprior.py b/ICLR2023/src/models/hyperprior.py new file mode 100644 index 0000000..5e1218d --- /dev/null +++ b/ICLR2023/src/models/hyperprior.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from torch import nn + + +class DepthConv(nn.Module): + def __init__(self, in_ch, out_ch, depth_kernel=3, stride=1, slope=0.01): + super().__init__() + dw_ch = in_ch * 1 + self.conv1 = nn.Sequential( + nn.Conv2d(in_ch, dw_ch, 1, stride=stride), + nn.LeakyReLU(negative_slope=slope), + ) + + self.depth_conv = nn.Conv2d(dw_ch, dw_ch, depth_kernel, padding=depth_kernel // 2, + groups=dw_ch) + self.conv2 = nn.Conv2d(dw_ch, out_ch, 1) + + self.adaptor = None + if stride != 1: + assert stride == 2 + self.adaptor = nn.Conv2d(in_ch, out_ch, 2, stride=2) + elif in_ch != out_ch: + self.adaptor = nn.Conv2d(in_ch, out_ch, 1) + + def forward(self, x): + identity = x + if self.adaptor is not None: + identity = self.adaptor(identity) + out = self.conv1(x) + out = self.depth_conv(out) + out = self.conv2(out) + return out + identity + + +class ConvFFN(nn.Module): + def __init__(self, in_ch, slope=0.1): + super().__init__() + internal_ch = max(min(in_ch * 4, 1024), in_ch * 2) + self.conv = nn.Sequential( + nn.Conv2d(in_ch, internal_ch, 1), + nn.LeakyReLU(negative_slope=slope), + nn.Conv2d(internal_ch, in_ch, 1), + nn.LeakyReLU(negative_slope=slope), + ) + + def forward(self, x): + identity = x + return identity + self.conv(x) + + +class DepthConvBlock(nn.Module): + def __init__(self, in_ch, out_ch, depth_kernel=3, stride=1, + slope_depth_conv=0.01, slope_ffn=0.1): + super().__init__() + self.block = nn.Sequential( + DepthConv(in_ch, out_ch, depth_kernel, stride, slope=slope_depth_conv), + ConvFFN(out_ch, slope=slope_ffn), + ) + + def forward(self, x): + return self.block(x) + + +class DepthConvBlockUpsample(nn.Module): + def __init__(self, in_ch, out_ch, depth_kernel=3, slope_depth_conv=0.01, slope_ffn=0.1): + super().__init__() + self.block = nn.Sequential( + DepthConv(in_ch, out_ch, depth_kernel, slope=slope_depth_conv), + ConvFFN(out_ch, slope=slope_ffn), + nn.Conv2d(out_ch, out_ch * 4, 1), + nn.PixelShuffle(2), + ) + + def forward(self, x): + return self.block(x) + + +def 
get_hyperprior(channel=192): + N = channel + hyper_enc = nn.Sequential( + DepthConvBlock(N, N, stride=1), + nn.Conv2d(N, N, 3, stride=2, padding=1), + nn.LeakyReLU(), + nn.Conv2d(N, N, 3, stride=2, padding=1), + ) + hyper_dec = nn.Sequential( + DepthConvBlockUpsample(N, N), + DepthConvBlockUpsample(N, N), + DepthConvBlock(N, N), + ) + y_prior_fusion = nn.Sequential( + DepthConvBlock(N, N * 2), + DepthConvBlock(N * 2, N * 3), + ) + return hyper_enc, hyper_dec, y_prior_fusion + + +def get_dualprior(channel=192): + N = channel + y_spatial_prior = nn.Sequential( + DepthConvBlock(N * 4, N * 3), + DepthConvBlock(N * 3, N * 2), + DepthConvBlock(N * 2, N * 2), + ) + return y_spatial_prior diff --git a/ICLR2023/src/models/image_model.py b/ICLR2023/src/models/image_model.py new file mode 100644 index 0000000..cccd733 --- /dev/null +++ b/ICLR2023/src/models/image_model.py @@ -0,0 +1,258 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import time +import torch +from torch import nn + +from .common_model import CompressionModel +from .layers import get_enc_dec_models +from .hyperprior import get_hyperprior, get_dualprior +from ..utils.stream_helper import encode_i, decode_i, get_downsampled_shape, filesize, \ + get_rounded_q, get_state_dict + + +class EVC(CompressionModel): + def __init__(self, N=192, anchor_num=4, ec_thread=False): + super().__init__(y_distribution='gaussian', z_channel=N, ec_thread=ec_thread) + channels = [192, 192, 192, 192] + self.enc, self.dec = get_enc_dec_models(3, 3, channels) + self.hyper_enc, self.hyper_dec, self.y_prior_fusion = get_hyperprior(N) + self.y_spatial_prior = get_dualprior(N) + + self.q_basic = nn.Parameter(torch.ones((1, N, 1, 1))) + self.q_scale = nn.Parameter(torch.ones((anchor_num, 1, 1, 1))) + # the exact q_step is q_basic * q_scale + self.N = int(N) + self.anchor_num = int(anchor_num) + + def single_encode(self, x, q_scale=None): + curr_q = self.get_curr_q(q_scale, self.q_basic) + y = 
self.enc(x) + y = y / curr_q + return x, y, curr_q + + def hyperprior(self, y): + z = self.hyper_enc(y) + z_hat = self.quant(z) + + params = self.hyper_dec(z_hat) + params = self.y_prior_fusion(params) + q_step, scales, means = self.separate_prior(params) + y_res, y_q, y_hat, scales_hat = self.forward_dual_prior( + y, means, scales, q_step, self.y_spatial_prior) + + y_for_bit = y_q + z_for_bit = z_hat + + bits_y = self.get_y_gaussian_bits(y_for_bit, scales_hat) + bits_z = self.get_z_bits(z_for_bit, self.bit_estimator_z) + return y_hat, bits_y, bits_z + + def forward(self, x, q_scale=None): + x, y, curr_q = self.single_encode(x, q_scale) + y_hat, bits_y, bits_z = self.hyperprior(y) + y_hat = y_hat * curr_q + x_hat = self.dec(y_hat) + return self.compute_loss(x, x_hat, bits_y, bits_z) + + @staticmethod + def get_q_scales_from_ckpt(ckpt_path): + ckpt = get_state_dict(ckpt_path) + q_scales = ckpt["q_scale"] + return q_scales.reshape(-1) + + def compute_loss(self, x, x_hat, bits_y, bits_z): + B, _, H, W = x.size() + pixel_num = H * W + bpp_y = torch.sum(bits_y, dim=(1, 2, 3)) / pixel_num + bpp_z = torch.sum(bits_z, dim=(1, 2, 3)) / pixel_num + + bits = torch.sum(bpp_y + bpp_z) * pixel_num + bpp = bpp_y + bpp_z + + return { + "x_hat": x_hat, + "bit": bits, + "bpp": bpp, + "bpp_y": bpp_y, + "bpp_z": bpp_z, + } + + def encode_decode(self, x, q_scale, output_path=None, pic_width=None, pic_height=None): + # pic_width and pic_height may be different from x's size. 
def compress(self, x, q_scale):
    """Encode image batch `x` into an actual bitstream.

    Mirrors the estimation path in `forward`, but quantises the hyper
    latent with a hard round and drives the entropy coder to emit real
    bits.

    Args:
        x: padded input image batch.
        q_scale: quantisation scale; the effective step is derived from
            q_scale and self.q_basic via get_curr_q.

    Returns:
        dict with "bit_stream" (the encoded stream) and "x_hat"
        (reconstruction clamped to [0, 1]).
    """
    curr_q = self.get_curr_q(q_scale, self.q_basic)
    y = self.enc(x)
    y = y / curr_q
    z = self.hyper_enc(y)
    # Hard rounding here, unlike self.quant(...) used in the training path.
    z_hat = torch.round(z)

    # Decode the hyper latent into the prior parameters for y.
    params = self.hyper_dec(z_hat)
    params = self.y_prior_fusion(params)
    q_step, scales, means = self.separate_prior(params)
    # The dual prior produces two groups of quantised symbols with their
    # scales; the grouping itself is implemented in the base class.
    y_q_w_0, y_q_w_1, scales_w_0, scales_w_1, y_hat = self.compress_dual_prior(
        y, means, scales, q_step, self.y_spatial_prior)
    y_hat = y_hat * curr_q

    # Entropy-code z first, then the two y groups; decompress() must read
    # them back in exactly this order.
    self.entropy_coder.reset()
    self.bit_estimator_z.encode(z_hat)
    self.gaussian_encoder.encode(y_q_w_0, scales_w_0)
    self.gaussian_encoder.encode(y_q_w_1, scales_w_1)
    self.entropy_coder.flush()

    x_hat = self.dec(y_hat).clamp_(0, 1)

    bit_stream = self.entropy_coder.get_encoded_stream()

    result = {
        "bit_stream": bit_stream,
        "x_hat": x_hat,
    }
    return result
class EVC_MS(EVC):
    """EVC variant with a medium encoder and a small decoder."""

    def __init__(self, N=192, anchor_num=4, ec_thread=False):
        super().__init__(N, anchor_num, ec_thread)
        # Replace the default (large) transforms built by EVC.__init__:
        # medium-capacity analysis transform, small synthesis transform.
        self.enc, _ = get_enc_dec_models(3, 3, [128, 128, 192, 192])
        _, self.dec = get_enc_dec_models(3, 3, [64, 64, 128, 192])
# pylint: disable=W0221
class LowerBound(Function):
    """`max(inputs, bound)` with a gradient that lets clamped values recover.

    Backward keeps the incoming gradient wherever the input was not clamped
    (input >= bound), and also wherever the gradient is negative — a
    negative gradient would move a clamped value upward under gradient
    descent, so it is allowed through; gradients that would push the value
    further below the bound are zeroed.
    """

    @staticmethod
    def forward(ctx, inputs, bound):
        bound_t = torch.ones_like(inputs) * bound
        ctx.save_for_backward(inputs, bound_t)
        return torch.max(inputs, bound_t)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, bound_t = ctx.saved_tensors
        keep = (inputs >= bound_t) | (grad_output < 0)
        return keep.type(grad_output.dtype) * grad_output, None
# pylint: enable=W0221
class ResidualBlockUpsample(nn.Module):
    """Residual block with sub-pixel upsampling on both branches.

    Args:
        in_ch1 (int): number of input channels
        in_ch2 (int): channels after the first sub-pixel convolution
        in_ch3 (int): number of output channels (defaults to ``in_ch2``)
        upsample (int): upsampling factor (default: 2)
    """

    def __init__(self, in_ch1, in_ch2, in_ch3=None, upsample=2):
        super().__init__()
        if in_ch3 is None:
            in_ch3 = in_ch2
        # Main branch: upsample, then refine with a 3x3 conv.
        self.subpel_conv = subpel_conv1x1(in_ch1, in_ch2, upsample)
        self.leaky_relu = nn.LeakyReLU()
        self.conv = conv3x3(in_ch2, in_ch3)
        self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1)
        # Skip branch: its own sub-pixel conv so shapes match for the add.
        self.upsample = subpel_conv1x1(in_ch1, in_ch3, upsample)

    def forward(self, x):
        out = self.subpel_conv(x)
        out = self.leaky_relu(out)
        out = self.conv(out)
        out = self.leaky_relu2(out)
        # Fix: removed the dead `identity = x` assignment that was
        # immediately overwritten; the skip path is always the upsampled x.
        identity = self.upsample(x)
        return out + identity
class ConvFFN(nn.Module):
    """Residual 1x1-conv feed-forward block (expand -> ReLU -> project -> ReLU)."""

    def __init__(self, in_ch1, in_ch2=None):
        super().__init__()
        # Default hidden width is a 4x expansion of the input channels.
        hidden_ch = in_ch1 * 4 if in_ch2 is None else in_ch2
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch1, hidden_ch, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_ch, in_ch1, 1),
            nn.ReLU(),
        )

    def forward(self, x):
        # Residual connection around the pointwise MLP.
        return x + self.conv(x)
def get_enc_dec_models(input_channel, output_channel, channels=None):
    """Build the (encoder, decoder) transform pair.

    Args:
        input_channel: channels of the image fed to the encoder (e.g. 3).
        output_channel: channels of the decoder output image.
        channels: per-stage widths [2x, 4x, 8x, 16x downsampling];
            defaults to the small configuration [64, 64, 128, 192].

    Returns:
        (enc, dec): a 16x-downsampling analysis transform and the matching
        16x-upsampling synthesis transform.
    """
    # Fix: avoid the mutable-default-argument anti-pattern; the previous
    # `channels=[...]` list default was shared between all calls.
    if channels is None:
        channels = [64, 64, 128, 192]
    channel_2x, channel_4x, channel_8x, channel_16x = channels

    enc = nn.Sequential(
        ResidualBlockWithStride(input_channel, channel_2x, stride=2),
        DepthConvBlock(channel_2x, channel_2x),
        ResidualBlockWithStride(channel_2x, channel_4x, stride=2),
        DepthConvBlock(channel_4x, channel_4x),
        ResidualBlockWithStride(channel_4x, channel_8x, stride=2),
        DepthConvBlock(channel_8x, channel_8x),
        conv3x3(channel_8x, channel_16x, stride=2),
    )

    dec = nn.Sequential(
        DepthConvBlock(channel_16x, channel_16x),
        ResidualBlockUpsample(channel_16x, channel_8x, upsample=2),
        DepthConvBlock(channel_8x, channel_8x),
        ResidualBlockUpsample(channel_8x, channel_4x, upsample=2),
        DepthConvBlock(channel_4x, channel_4x),
        ResidualBlockUpsample(channel_4x, channel_2x, upsample=2),
        DepthConvBlock(channel_2x, channel_2x),
        subpel_conv1x1(channel_2x, output_channel, 2),
    )
    return enc, dec
def scalable_add(inputs):
    """Running mean over the scalable (first) dimension.

    Args:
        inputs: S x B x C x H x W stack of encoder outputs.

    Returns:
        A tensor of the same shape where entry i is the mean of slices
        0..i. Earlier slices are detached, so gradients of output i only
        flow into ``inputs[i]``.
    """
    detached = inputs.detach()
    averaged = [
        (detached[:idx].sum(0) + inputs[idx]) / (idx + 1)
        for idx in range(inputs.size(0))
    ]
    return torch.stack(averaged)
def multi_encode(self, x, q_scale=None):
    """Run all (or a prefix of) the scalable encoders on `x`.

    Each encoder output is averaged with the detached outputs of the
    earlier encoders (scalable_add), so gradients of branch i only reach
    encoder i.

    Args:
        x: input image batch.
        q_scale: optional quantisation scale forwarded to get_curr_q.

    Returns:
        (x, y, curr_q). When forward_enc_id is None, x and y are the
        batch-concatenation of every encoder branch (batch grows by the
        number of encoders run) and curr_q is repeated to match;
        otherwise only the pinned encoder's branch is returned.
    """
    curr_q = self.get_curr_q(q_scale, self.q_basic)

    x_list = []
    y_list = []
    for enc_id in range(self.enc_num):
        y = self.encs[enc_id](x)
        x_list.append(x)
        y_list.append(y)
        # When a single encoder is pinned, stop as soon as it has run.
        if self.forward_enc_id is not None and self.forward_enc_id == enc_id:
            break

    ys = torch.stack(y_list)  # S x B x C x H x W
    y_out = self.scalable_add(ys)

    if self.forward_enc_id is not None:
        # Keep only the pinned encoder's (averaged) branch.
        y = y_out[self.forward_enc_id]
        x = x_list[self.forward_enc_id]
    else:
        # Fold the scalable dimension S into the batch dimension.
        S, B, C, H, W = y_out.shape
        y = y_out.reshape(S * B, C, H, W)
        curr_q = curr_q.repeat(len(y_list), 1, 1, 1)
        x = torch.cat(x_list, dim=0)

    y = y / curr_q
    return x, y, curr_q
def encode_decode(self, x, q_scale, output_path=None, pic_width=None, pic_height=None):
    """Estimation-mode or bitstream-mode encode/decode round trip.

    With output_path=None: runs the multi-encoder forward pass and picks
    the encoder branch with the lowest rate-distortion cost
    (lambda * 255^2 * MSE + bpp) for the current self.rate. The returned
    'bit' is an entropy estimate, not a real stream.

    With an output_path: writes an actual bitstream to disk, reads it
    back, decodes it, and reports per-stage timing. pic_width/pic_height
    give the un-padded image size stored in the stream header.
    """
    if output_path is None:
        result = {
            'bit': None,
            'x_hat': None,
        }
        chose_id = 0
        # R-D trade-off weight for the currently selected rate point.
        lmbda = self.lmbdas[self.rate]
        encoded = self.forward(x, q_scale)
        mse, bpp = encoded['mse'], encoded['bpp']
        cost = (lmbda * 255 * 255 * mse + bpp).flatten()
        if len(cost) == 1:
            chose_id = 0
        else:
            chose_id = cost.argmin()
        result['bit'] = encoded['bit'][chose_id].item()
        result['x_hat'] = encoded['x_hat'][chose_id].unsqueeze(0)
        return result
    assert pic_height is not None
    assert pic_width is not None
    # Round q_scale so its integer index fits the bitstream header.
    q_scale, q_index = get_rounded_q(q_scale)
    torch.cuda.synchronize()
    start_time = time.time()
    compressed = self.compress(x, q_scale)
    torch.cuda.synchronize()
    enc_time = time.time() - start_time

    bit_stream = compressed['bit_stream']
    encode_i(pic_height, pic_width, q_index, bit_stream, output_path)
    # Measure the rate from the file actually written, header included.
    bit = filesize(output_path) * 8

    height, width, q_index, bit_stream = decode_i(output_path)
    decompressed = self.decompress(bit_stream, height, width, q_index / 100)

    x_hat = decompressed['x_hat']
    dec_time = decompressed['dec_time']

    result = {
        'bit': bit,
        'x_hat': x_hat,
        'enc_time': enc_time,
        'dec_time': dec_time,
        'latency': enc_time + dec_time,
    }
    return result
class Scale_EVC_SS(ScalableEnc):
    """Scalable-encoder EVC paired with the small decoder."""

    def __init__(self, N=192, anchor_num=4, ec_thread=False, enc_num=4, forward_enc_id=None):
        super().__init__(N, anchor_num, ec_thread, enc_num, forward_enc_id)
        # Swap in the small synthesis transform; the encoders built by the
        # parent class are kept as-is.
        _, self.dec = get_enc_dec_models(3, 3, [64, 64, 128, 192])
def interpolate_log(min_val, max_val, num, decending=True):
    """Return `num` values spaced log-uniformly between min_val and max_val.

    Note: the `decending` keyword keeps its original (misspelled) name for
    compatibility with existing callers.
    """
    assert max_val > min_val
    assert min_val > 0
    lo = np.log(min_val)
    hi = np.log(max_val)
    endpoints = (hi, lo) if decending else (lo, hi)
    return np.exp(np.linspace(*endpoints, num))
class PNGReader():
    """Reads one PNG image and returns it as a CHW float32 array in [0, 1]."""

    def __init__(self, filepath):
        # Path of the single image this reader serves.
        self.filepath = filepath
        # Set once the file has been found to be absent.
        self.eof = False

    def read_one_frame(self, src_format="rgb"):
        # `src_format` is kept for interface compatibility; output is RGB.
        if self.eof:
            return None
        if not os.path.exists(self.filepath):
            self.eof = True
            return None
        image = Image.open(self.filepath).convert('RGB')
        chw = np.asarray(image).astype('float32').transpose(2, 0, 1)
        return chw / 255.
def get_padding_size(height, width, p=64):
    """Zero-padding needed to round (height, width) up to multiples of p.

    All padding goes on the right/bottom edges; the returned tuple is
    (padding_left, padding_right, padding_top, padding_bottom) with the
    left/top entries always 0.
    """
    pad_right = -width % p
    pad_bottom = -height % p
    return 0, pad_right, 0, pad_bottom
def filesize(filepath: str) -> int:
    """Return the size in bytes of the regular file at *filepath*.

    Raises:
        ValueError: if the path does not exist or is not a regular file.
    """
    # Construct the Path once instead of twice as before.
    path = Path(filepath)
    if not path.is_file():
        raise ValueError(f'Invalid file "{filepath}".')
    return path.stat().st_size
def decode_i(inputpath):
    """Parse an image bitstream file written by encode_i.

    Returns:
        (height, width, q_index, bit_stream) read back in the exact order
        encode_i wrote them.
    """
    with Path(inputpath).open("rb") as f:
        height, width = read_uints(f, 2)
        q_index = read_ushorts(f, 1)[0]
        stream_length = read_uints(f, 1)[0]
        bit_stream = read_bytes(f, stream_length)
    return height, width, q_index, bit_stream
def PSNR(input1, input2):
    """Peak signal-to-noise ratio (dB) between two tensors scaled to [0, 1]."""
    squared_error = (input1 - input2) ** 2
    mse = torch.mean(squared_error)
    psnr = 20 * torch.log10(1 / torch.sqrt(mse))
    return psnr.item()
def run_test(i_frame_net, args, device):
    """Encode/decode one test image and collect rate/quality metrics.

    Args:
        i_frame_net: image codec exposing encode_decode().
        args: dict-style test configuration (paths, flags, q_scale, ...).
        device: torch device the input tensor is moved to.

    Returns:
        dict with frame_pixel_num, bpp, psnr, msssim and wall-clock
        test_time.
    """
    write_stream = 'write_stream' in args and args['write_stream']
    save_decoded_frame = 'save_decoded_frame' in args and args['save_decoded_frame']
    verbose = args['verbose'] if 'verbose' in args else 0

    # Only single-PNG inputs are supported by this test path.
    if args['src_type'] == 'png' and args['img_folder'] == 1:
        src_reader = PNGReader(args['img_path'])
    else:
        raise NotImplementedError()

    frame_pixel_num = 0

    start_time = time.time()
    with torch.no_grad():
        frame_start_time = time.time()
        rgb = src_reader.read_one_frame(src_format="rgb")
        x = np_image_to_tensor(rgb)
        x = x.to(device)
        pic_height = x.shape[2]
        pic_width = x.shape[3]

        if frame_pixel_num == 0:
            frame_pixel_num = x.shape[2] * x.shape[3]
        else:
            assert frame_pixel_num == x.shape[2] * x.shape[3]

        # pad if necessary
        padding_l, padding_r, padding_t, padding_b = get_padding_size(pic_height, pic_width)
        x_padded = torch.nn.functional.pad(
            x,
            (padding_l, padding_r, padding_t, padding_b),
            mode="constant",
            value=0,
        )

        bin_path = os.path.join(args['bin_folder'], "0.bin") \
            if write_stream else None

        result = i_frame_net.encode_decode(x_padded, args['i_frame_q_scale'], bin_path,
                                           pic_height=pic_height, pic_width=pic_width)

        bits = result["bit"]
        bpp = bits / pic_width / pic_height
        encoding_time = result.get('enc_time', 0)
        decoding_time = result.get('dec_time', 0)
        recon_frame = result["x_hat"]
        recon_frame = recon_frame.clamp_(0, 1)
        # Negative padding crops the reconstruction back to the source size.
        x_hat = F.pad(recon_frame, (-padding_l, -padding_r, -padding_t, -padding_b))
        psnr = PSNR(x_hat, x)
        msssim = ms_ssim(x_hat, x, data_range=1).item()
        frame_end_time = time.time()

        if verbose >= 1:
            print(f"{os.path.basename(args['img_path'])}, rate: {args['rate_idx']}, "
                  f"{frame_end_time - frame_start_time:.3f} seconds,",
                  f"bpp: {bpp:.3f}, PSNR: {psnr:.4f}, MS-SSIM: {msssim:.4f}, "
                  f"encoding_time: {encoding_time:.4f}, decoding_time: {decoding_time:.4f} ")

        if save_decoded_frame:
            # Folder name encodes the rate point plus the measured metrics.
            folder_name = f"{args['rate_idx']}_{bpp:.4f}_{psnr:.4f}_{encoding_time:.4f}_{decoding_time:.4f}"
            save_path = os.path.join(args['decoded_frame_folder'], f"{os.path.basename(args['img_path'])}.png")
            save_torch_image(x_hat, save_path)
            # NOTE(review): renames the whole decoded_frame_folder relative to
            # its parent; appears to assume one image per folder — confirm.
            os.rename(args['decoded_frame_folder'], args['decoded_frame_folder'] + f'/../{folder_name}')

    test_time = time.time() - start_time

    log_result = {}
    log_result['frame_pixel_num'] = frame_pixel_num
    log_result['bpp'] = bpp
    log_result['psnr'] = psnr
    log_result['msssim'] = msssim
    log_result['test_time'] = test_time
    return log_result
+ + if args['save_decoded_frame']: + decoded_frame_folder = os.path.join(args['decoded_frame_path'], sub_dir_name, + str(args['rate_idx'])) + create_folder(decoded_frame_folder) + else: + decoded_frame_folder = None + + if 'img_path' not in args: + args['img_path'] = os.path.join(args['dataset_path'], sub_dir_name) + args['bin_folder'] = bin_folder + args['decoded_frame_folder'] = decoded_frame_folder + + result = run_test(i_frame_net, args, device=device) + + result['ds_name'] = args['ds_name'] + result['img_path'] = args['img_path'] + result['rate_idx'] = args['rate_idx'] + + return result + + +def worker(use_cuda, args): + torch.backends.cudnn.benchmark = False + try: + torch.use_deterministic_algorithms(True) + except Exception: + torch.set_deterministic(True) + torch.manual_seed(0) + torch.set_num_threads(1) + np.random.seed(seed=0) + gpu_num = 0 + if use_cuda: + gpu_num = torch.cuda.device_count() + + process_name = multiprocessing.current_process().name + if process_name == 'MainProcess': + process_idx = 0 + else: + process_idx = int(process_name[process_name.rfind('-') + 1:]) + gpu_id = -1 + if gpu_num > 0: + gpu_id = process_idx % gpu_num + if gpu_id >= 0: + device = f"cuda:{gpu_id}" + else: + device = "cpu" + + result = encode_one(args, device) + return result + + +def prepare_args(args, config, ds_name, img_name, rate_idx, i_frame_q_scales): + cur_args = {} + cur_args['rate_idx'] = rate_idx + cur_args['ec_thread'] = args.ec_thread + cur_args['i_frame_model'] = args.i_frame_model + cur_args['i_frame_model_path'] = args.i_frame_model_path + if len(i_frame_q_scales) > 0: + cur_args['i_frame_q_scale'] = i_frame_q_scales[rate_idx].to(torch.float32) + else: + cur_args['i_frame_q_scale'] = [] + cur_args['img_path'] = img_name + cur_args['src_type'] = config[ds_name]['src_type'] + cur_args['img_folder'] = config[ds_name].get('img_folder', 0) + if cur_args['img_folder'] == 1: + cur_args['img_path'] = os.path.join(args.root_path, config[ds_name]['base_path'], 
img_name) + else: + cur_args['src_height'] = config[ds_name]['sequences'][img_name]['height'] + cur_args['src_width'] = config[ds_name]['sequences'][img_name]['width'] + cur_args['dataset_path'] = os.path.join(args.root_path, config[ds_name]['base_path']) + cur_args['write_stream'] = args.write_stream + cur_args['stream_path'] = args.stream_path + cur_args['save_decoded_frame'] = args.save_decoded_frame + cur_args['decoded_frame_path'] = f'{args.decoded_frame_path}' + cur_args['ds_name'] = ds_name + cur_args['verbose'] = args.verbose + return cur_args + + +def main(): + begin_time = time.time() + + torch.backends.cudnn.enabled = True + args = parse_args() + + if args.cuda_device is not None and args.cuda_device != '': + os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_device + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + + worker_num = args.worker + assert worker_num >= 1 + + with open(args.test_config) as f: + config = json.load(f) + + if worker_num > 1: + multiprocessing.set_start_method("spawn") + threadpool_executor = concurrent.futures.ProcessPoolExecutor(max_workers=worker_num) + objs = [] + + count_frames = 0 + count_imgs = 0 + + rate_num = args.rate_num + q_scales_list = [] + + ckpt = get_state_dict(args.i_frame_model_path) + if "q_scale" in ckpt: + q_scales = ckpt["q_scale"] + elif "student.q_scale" in ckpt: + q_scales = ckpt["student.q_scale"] + elif "teacher.q_scale" in ckpt: + q_scales = ckpt["teacher.q_scale"] + else: + raise ValueError("q_scale") + q_scales_list.append(q_scales.reshape(-1)) + if q_scales_list: + i_frame_q_scales = torch.cat(q_scales_list) + else: + i_frame_q_scales = [] + + print("q_scales in intra ckpt: ", end='') + for q in i_frame_q_scales: + print(f"{q:.3f}, ", end='') + print() + if args.i_frame_q_scales is not None: + assert len(args.i_frame_q_scales) == rate_num + i_frame_q_scales = torch.tensor(args.i_frame_q_scales) + print(f"testing {rate_num} rate points with pre-defined intra y q_scales: ", end='') + elif 
len(i_frame_q_scales) == rate_num: + print(f"testing {rate_num} rate points with intra y q_scales in ckpt: ", end='') + elif len(i_frame_q_scales) > 0: + max_q_scale = i_frame_q_scales[0] + min_q_scale = i_frame_q_scales[-1] + i_frame_q_scales = interpolate_log(min_q_scale, max_q_scale, rate_num) + i_frame_q_scales = torch.tensor(i_frame_q_scales) + print(f"testing {rate_num} rates, using intra y q_scales: ", end='') + + for q in i_frame_q_scales: + print(f"{q:.3f}, ", end='') + print() + + results = [] + args.root_path = config['root_path'] + config = config['test_classes'] + for ds_name in config: + if config[ds_name]['test'] == 0: + continue + if config[ds_name].get('img_folder', 0) == 1: + imgs = os.listdir(os.path.join(args.root_path, config[ds_name]['base_path'])) + imgs.sort() + imgs = [f for f in imgs if f.endswith(config[ds_name]['src_type'])] + else: + imgs = config[ds_name]['sequences'] + for img_name in imgs: + count_imgs += 1 + for rate_idx in range(rate_num): + cur_args = prepare_args(args, config, ds_name, img_name, rate_idx, i_frame_q_scales) + count_frames += 1 + + if worker_num > 1: + obj = threadpool_executor.submit( + worker, + args.cuda, + cur_args) + objs.append(obj) + else: + result = worker(args.cuda, cur_args) + results.append(result) + + if worker_num > 1: + for obj in tqdm(objs): + result = obj.result() + results.append(result) + + log_result = {} + for ds_name in config: + if config[ds_name]['test'] == 0: + continue + log_result[ds_name] = {} + if config[ds_name].get('img_folder', 0) == 1: + imgs = os.listdir(os.path.join(args.root_path, config[ds_name]['base_path'])) + imgs.sort() + imgs = [f for f in imgs if f.endswith(config[ds_name]['src_type'])] + else: + imgs = config[ds_name]['sequences'] + for img in imgs: + log_result[ds_name][img] = {} + for rate in range(rate_num): + for res in results: + if res['rate_idx'] == rate and ds_name == res['ds_name'] \ + and img == os.path.basename(res['img_path']): + 
log_result[ds_name][img][f"{rate:03d}"] = res + + out_json_dir = os.path.dirname(args.output_path) + if len(out_json_dir) > 0: + create_folder(out_json_dir, True) + with open(args.output_path, 'w') as fp: + dump_json(log_result, fp, float_digits=6, indent=2) + + avg_imagejson(args.output_path) + + total_minutes = (time.time() - begin_time) / 60 + print('Test finished') + print(f'Tested {count_frames} frames from {count_imgs} images') + print(f'Total elapsed time: {total_minutes:.1f} min') + + +if __name__ == "__main__": + main() diff --git a/README.md b/README.md index 10b9549..8e2efd4 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ Official Pytorch implementation for Neural Video Compression including: * [Hybrid Spatial-Temporal Entropy Modelling for Neural Video Compression](https://arxiv.org/abs/2207.05894), ACM MM 2022, in [this folder](./ACMMM2022/). - The first end-to-end neural video codec to exceed H.266 (VTM) using the highest compression ratio configuration, in terms of both PSNR and MS-SSIM. - The first end-to-end neural video codec to achieve rate adjustment in single model. +* [EVC: Towards Real-Time Neural Image Compression with Mask Decay](https://openreview.net/forum?id=XUxad2Gj40n), ICLR 2023, in [this folder](./ICLR2023). # On the comparison @@ -38,6 +39,13 @@ If you find this work useful for your research, please cite: booktitle={Proceedings of the 30th ACM International Conference on Multimedia}, year={2022} } + +@inproceedings{wang2023EVC, + title={EVC: Towards Real-Time Neural Image Compression with Mask Decay}, + author={Wang, Guo-Hua and Li, Jiahao and Li, Bin and Lu, Yan}, + booktitle={International Conference on Learning Representations}, + year={2023} +} ``` # Trademarks