diff --git a/.gitignore b/.gitignore index 5ac5364..215effa 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ *.bin *.png *.so -build/ \ No newline at end of file +build/ +*.tar +*.pyd diff --git a/ACMMM2022/README.md b/ACMMM2022/README.md index fcc39fe..3751c1b 100644 --- a/ACMMM2022/README.md +++ b/ACMMM2022/README.md @@ -1 +1,108 @@ -Coming soon. \ No newline at end of file +# Introduction + +Official PyTorch implementation of [Hybrid Spatial-Temporal Entropy Modelling for Neural Video Compression](https://arxiv.org/abs/2207.05894), ACM MM 2022. + +# Prerequisites +* Python 3.8 and conda; get [Conda](https://www.anaconda.com/) +* CUDA, if you want to use a GPU +* Environment + ``` + conda create -n $YOUR_PY38_ENV_NAME python=3.8 + conda activate $YOUR_PY38_ENV_NAME + + conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch + pip install -r requirements.txt + ``` + +# Test dataset +We support arbitrary original resolutions. The input video is automatically padded so that its width and height are multiples of 64; the reconstructed video is cropped back to the original size, and the distortion (PSNR/MS-SSIM) is calculated at the original resolution. + +The dataset format is described in dataset_config_example.json. + +For example, one HEVC Class B video can be prepared as follows: +* Create the video folder: + ``` + mkdir BasketballDrive_1920x1080_50 + ``` +* Convert YUV to PNG: + ``` + ffmpeg -pix_fmt yuv420p -s 1920x1080 -i BasketballDrive_1920x1080_50.yuv -f image2 BasketballDrive_1920x1080_50/im%05d.png + ``` +Finally, the dataset folder structure looks like: + + /media/data/HEVC_B/ + * BQTerrace_1920x1080_60/ + - im00001.png + - im00002.png + - im00003.png + - ... + * BasketballDrive_1920x1080_50/ + - im00001.png + - im00002.png + - im00003.png + - ... + * ... + /media/data/HEVC_D/ + /media/data/HEVC_C/ + ... + +# Build the project +Please build the C++ code if you want to test with actual bitstream writing. There is a minor difference between the bit count estimated from entropy (the method used to report the numbers in the paper) and the size of the actually written bitstream: writing the bitstream to a file adds some overhead, and the relative difference depends on the bitstream size. Usually, the overhead for 1080p content is less than 0.5%. +## On Windows +```bash +cd src +mkdir build +cd build +conda activate $YOUR_PY38_ENV_NAME +cmake ../cpp -G "Visual Studio 16 2019" -A x64 +cmake --build . --config Release +``` + +## On Linux +```bash +sudo apt-get install cmake g++ +cd src +mkdir build +cd build +conda activate $YOUR_PY38_ENV_NAME +cmake ../cpp -DCMAKE_BUILD_TYPE=Release +make -j +``` + +# Pretrained models + +* Download [our pretrained models](https://1drv.ms/u/s!AozfVVwtWWYoiUAGk6xr-oELbodn?e=kry2Nk) and put them into the ./checkpoints folder. +* Or run the script in ./checkpoints to download the models directly. +# Test the models + +Example of testing a pretrained model with four rate points: +```bash +python test_video.py --i_frame_model_path ./checkpoints/acmmm2022_image_psnr.pth.tar --model_path ./checkpoints/acmmm2022_video_psnr.pth.tar --rate_num 4 --test_config ./dataset_config_example.json --cuda 1 -w 1 --write_stream 0 --output_path output.json --force_intra_period 32 --force_frame_num 96 +``` +It is recommended to set the ```--worker``` number equal to your GPU count. + +You can also specify different q_scales values to test other bitrate points. It is suggested to change all three q_scales together and to generate interpolated q_scales between the minimum and maximum values. +For example: intra_q_scales = scale_list_to_str(interpolate_log(minimum_value, maximum_value, number_of_rate_points)); a sketch of this interpolation is given below. +Please use --rate_num to specify the number of rate points and --i_frame_q_scales, --p_frame_mv_y_q_scales, --p_frame_y_q_scales to specify the q_scales. +Please note that using q_scales outside the range [minimum_value, maximum_value] has not been tested and may give poor encoding results.
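A minimal sketch of such a log-domain interpolation is given here for reference. The repository ships its own interpolate_log and scale_list_to_str helpers; the endpoint values (0.5 and 2.5), the 4-digit formatting, and the comma separator below are illustrative assumptions, not the checkpoints' actual settings.

```python
import numpy as np

def interpolate_log(min_val, max_val, num_points):
    # Evenly spaced q_scale values in the log domain, from min_val to max_val.
    return np.exp(np.linspace(np.log(min_val), np.log(max_val), num_points)).tolist()

def scale_list_to_str(scales):
    # Format the list for the --*_q_scales command line options.
    return ','.join(f'{s:.4f}' for s in scales)

# Hypothetical endpoints; use the minimum/maximum q_scale of the checkpoint under test.
i_frame_q_scales = scale_list_to_str(interpolate_log(0.5, 2.5, 4))
print(i_frame_q_scales)  # approximately 0.5000,0.8550,1.4620,2.5000
```

The same interpolation would be repeated for --p_frame_mv_y_q_scales and --p_frame_y_q_scales so that all three q_scales move together.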
+ + +# R-D Curves +![PSNR RD Curve](assets/rd_curve_psnr.png) + +# Acknowledgement +The implementation is based on [CompressAI](https://github.com/InterDigitalInc/CompressAI) and [PyTorchVideoCompression](https://github.com/ZhihaoHu/PyTorchVideoCompression). +# Citation +If you find this work useful for your research, please cite: + +``` +@inproceedings{li2022hybrid, + title={Hybrid Spatial-Temporal Entropy Modelling for Neural Video Compression}, + author={Li, Jiahao and Li, Bin and Lu, Yan}, + booktitle={Proceedings of the 30th ACM International Conference on Multimedia}, + year={2022} +} +``` + +# Trademarks +This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft’s Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties’ policies. diff --git a/ACMMM2022/assets/rd_curve_psnr.png b/ACMMM2022/assets/rd_curve_psnr.png new file mode 100644 index 0000000..20fc1f0 Binary files /dev/null and b/ACMMM2022/assets/rd_curve_psnr.png differ diff --git a/ACMMM2022/checkpoints/download.py b/ACMMM2022/checkpoints/download.py new file mode 100644 index 0000000..bc08acc --- /dev/null +++ b/ACMMM2022/checkpoints/download.py @@ -0,0 +1,23 @@ +import urllib.request + + +def download_one(url, target): + urllib.request.urlretrieve(url, target) + + +def main(): + urls = { + 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211220&authkey=AMRg1W3PVt_F3yc': 'acmmm2022_image_psnr.pth.tar', + 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211219&authkey=ACJnPOPf1ntw_w0': 'acmmm2022_image_ssim.pth.tar', + 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211217&authkey=AKpdgXQtvs-OxRs': 'acmmm2022_video_psnr.pth.tar', + 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211218&authkey=ANxapLv3PcCJ4Vw': 'acmmm2022_video_ssim.pth.tar', + } + for url in urls: + target = urls[url] + print("downloading", target) + download_one(url, target) + print("downloaded", target) + + +if __name__ == "__main__": + main() diff --git a/ACMMM2022/dataset_config_example.json b/ACMMM2022/dataset_config_example.json new file mode 100644 index 0000000..0b2c098 --- /dev/null +++ b/ACMMM2022/dataset_config_example.json @@ -0,0 +1,113 @@ +{ + "root_path": "/media/data/", + "test_classes": { + "HEVC_B": { + "test": 1, + "base_path": "HEVC_B", + "src_type": "png", + "sequences": { + "BQTerrace_1920x1080_60": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "BasketballDrive_1920x1080_50": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "Cactus_1920x1080_50": {"width": 1920, "height": 
1080, "frames": 96, "gop": 32}, + "Kimono1_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "ParkScene_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "gop": 32} + } + }, + "HEVC_C": { + "test": 1, + "base_path": "HEVC_C", + "src_type": "png", + "sequences": { + "BQMall_832x480_60": {"width": 832, "height": 480, "frames": 96, "gop": 32}, + "BasketballDrill_832x480_50": {"width": 832, "height": 480, "frames": 96, "gop": 32}, + "PartyScene_832x480_50": {"width": 832, "height": 480, "frames": 96, "gop": 32}, + "RaceHorses_832x480_30": {"width": 832, "height": 480, "frames": 96, "gop": 32} + } + }, + "HEVC_D": { + "test": 1, + "base_path": "HEVC_D", + "src_type": "png", + "sequences": { + "BasketballPass_416x240_50": {"width": 416, "height": 240, "frames": 96, "gop": 32}, + "BlowingBubbles_416x240_50": {"width": 416, "height": 240, "frames": 96, "gop": 32}, + "BQSquare_416x240_60": {"width": 416, "height": 240, "frames": 96, "gop": 32}, + "RaceHorses_416x240_30": {"width": 416, "height": 240, "frames": 96, "gop": 32} + } + }, + "HEVC_E": { + "test": 1, + "base_path": "HEVC_E", + "src_type": "png", + "sequences": { + "FourPeople_1280x720_60": {"width": 1280, "height": 720, "frames": 96, "gop": 32}, + "Johnny_1280x720_60": {"width": 1280, "height": 720, "frames": 96, "gop": 32}, + "KristenAndSara_1280x720_60": {"width": 1280, "height": 720, "frames": 96, "gop": 32} + } + }, + "HEVC_RGB": { + "test": 1, + "base_path": "HEVC_RGB", + "src_type": "png", + "sequences": { + "DucksAndLegs_1920x1080_30_RGB": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "EBULupoCandlelight_1920x1080_50_RGB": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "EBURainFruits_1920x1080_50_RGB": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "Kimono1_1920x1080_24_RGB": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "OldTownCross_1920x1080_50_RGB": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "ParkScene_1920x1080_24_RGB": {"width": 1920, "height": 1080, "frames": 96, "gop": 32} + } + }, + "UVG": { + "test": 1, + "base_path": "UVG", + "src_type": "png", + "sequences": { + "Beauty_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "Bosphorus_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "HoneyBee_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "Jockey_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "ReadySteadyGo_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "ShakeNDry_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "YachtRide_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "gop": 32} + } + }, + "MCL-JCV": { + "test": 1, + "base_path": "MCL-JCV", + "src_type": "png", + "sequences": { + "videoSRC01_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC02_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC03_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC04_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC05_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC06_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC07_1920x1080_25": {"width": 
1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC08_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC09_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC10_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC11_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC12_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC13_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC14_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC15_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC16_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC17_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC18_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC19_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC20_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC21_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC22_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC23_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC24_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC25_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC26_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC27_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC28_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC29_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "gop": 32}, + "videoSRC30_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "gop": 32} + } + } + } +} diff --git a/ACMMM2022/requirements.txt b/ACMMM2022/requirements.txt new file mode 100644 index 0000000..04dada3 --- /dev/null +++ b/ACMMM2022/requirements.txt @@ -0,0 +1,9 @@ +numpy>=1.20.0 +scipy +matplotlib==3.3.4 +torch>=1.10.0 +pytorch-msssim==0.2.0 +tensorboard +tqdm +bd-metric +ptflops diff --git a/ACMMM2022/src/cpp/3rdparty/CMakeLists.txt b/ACMMM2022/src/cpp/3rdparty/CMakeLists.txt new file mode 100644 index 0000000..8f63698 --- /dev/null +++ b/ACMMM2022/src/cpp/3rdparty/CMakeLists.txt @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +add_subdirectory(pybind11) +add_subdirectory(ryg_rans) \ No newline at end of file diff --git a/ACMMM2022/src/cpp/3rdparty/pybind11/CMakeLists.txt b/ACMMM2022/src/cpp/3rdparty/pybind11/CMakeLists.txt new file mode 100644 index 0000000..3c88809 --- /dev/null +++ b/ACMMM2022/src/cpp/3rdparty/pybind11/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +configure_file(CMakeLists.txt.in pybind11-download/CMakeLists.txt) +execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) +if(result) + message(FATAL_ERROR "CMake step for pybind11 failed: ${result}") +endif() +execute_process(COMMAND ${CMAKE_COMMAND} --build . 
+ RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) +if(result) + message(FATAL_ERROR "Build step for pybind11 failed: ${result}") +endif() + +add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/ + ${CMAKE_CURRENT_BINARY_DIR}/pybind11-build/ + EXCLUDE_FROM_ALL) + +set(PYBIND11_INCLUDE + ${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/include/ + CACHE INTERNAL "") diff --git a/ACMMM2022/src/cpp/3rdparty/pybind11/CMakeLists.txt.in b/ACMMM2022/src/cpp/3rdparty/pybind11/CMakeLists.txt.in new file mode 100644 index 0000000..f0b4565 --- /dev/null +++ b/ACMMM2022/src/cpp/3rdparty/pybind11/CMakeLists.txt.in @@ -0,0 +1,33 @@ +cmake_minimum_required(VERSION 3.6.3) + +project(pybind11-download NONE) + +include(ExternalProject) +if(IS_DIRECTORY "${PROJECT_BINARY_DIR}/3rdparty/pybind11/pybind11-src/include") + ExternalProject_Add(pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.6.1 + GIT_SHALLOW 1 + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" + DOWNLOAD_COMMAND "" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +else() + ExternalProject_Add(pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.6.1 + GIT_SHALLOW 1 + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +endif() diff --git a/ACMMM2022/src/cpp/3rdparty/ryg_rans/CMakeLists.txt b/ACMMM2022/src/cpp/3rdparty/ryg_rans/CMakeLists.txt new file mode 100644 index 0000000..d7a23bf --- /dev/null +++ b/ACMMM2022/src/cpp/3rdparty/ryg_rans/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +configure_file(CMakeLists.txt.in ryg_rans-download/CMakeLists.txt) +execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) +if(result) + message(FATAL_ERROR "CMake step for ryg_rans failed: ${result}") +endif() +execute_process(COMMAND ${CMAKE_COMMAND} --build . 
+ RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) +if(result) + message(FATAL_ERROR "Build step for ryg_rans failed: ${result}") +endif() + +# add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/ +# ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build +# EXCLUDE_FROM_ALL) + +set(RYG_RANS_INCLUDE + ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/ + CACHE INTERNAL "") diff --git a/ACMMM2022/src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in b/ACMMM2022/src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in new file mode 100644 index 0000000..3c62451 --- /dev/null +++ b/ACMMM2022/src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in @@ -0,0 +1,33 @@ +cmake_minimum_required(VERSION 3.6.3) + +project(ryg_rans-download NONE) + +include(ExternalProject) +if(EXISTS "${PROJECT_BINARY_DIR}/3rdparty/ryg_rans/ryg_rans-src/rans64.h") + ExternalProject_Add(ryg_rans + GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git + GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d + GIT_SHALLOW 1 + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build" + DOWNLOAD_COMMAND "" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +else() + ExternalProject_Add(ryg_rans + GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git + GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d + GIT_SHALLOW 1 + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +endif() diff --git a/ACMMM2022/src/cpp/CMakeLists.txt b/ACMMM2022/src/cpp/CMakeLists.txt new file mode 100644 index 0000000..069e920 --- /dev/null +++ b/ACMMM2022/src/cpp/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +cmake_minimum_required (VERSION 3.6.3) +project (ErrorRecovery) + +set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo" CACHE STRING "" FORCE) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +# treat warning as error +if (MSVC) + add_compile_options(/W4 /WX) +else() + add_compile_options(-Wall -Wextra -pedantic -Werror) +endif() + +# The sequence is tricky, put 3rd party first +add_subdirectory(3rdparty) +add_subdirectory (ops) +add_subdirectory (rans) diff --git a/ACMMM2022/src/cpp/ops/CMakeLists.txt b/ACMMM2022/src/cpp/ops/CMakeLists.txt new file mode 100644 index 0000000..03b72c4 --- /dev/null +++ b/ACMMM2022/src/cpp/ops/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +cmake_minimum_required(VERSION 3.7) +set(PROJECT_NAME MLCodec_CXX) +project(${PROJECT_NAME}) + +set(cxx_source + ops.cpp + ) + +set(include_dirs + ${CMAKE_CURRENT_SOURCE_DIR} + ${PYBIND11_INCLUDE} + ) + +pybind11_add_module(${PROJECT_NAME} ${cxx_source}) + +target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) + +# The post build argument is executed after make! +add_custom_command( + TARGET ${PROJECT_NAME} POST_BUILD + COMMAND + "${CMAKE_COMMAND}" -E copy + "$" + "${CMAKE_CURRENT_SOURCE_DIR}/../../entropy_models/" +) diff --git a/ACMMM2022/src/cpp/ops/ops.cpp b/ACMMM2022/src/cpp/ops/ops.cpp new file mode 100644 index 0000000..9463ab7 --- /dev/null +++ b/ACMMM2022/src/cpp/ops/ops.cpp @@ -0,0 +1,91 @@ +/* Copyright 2020 InterDigital Communications, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include + +std::vector pmf_to_quantized_cdf(const std::vector &pmf, + int precision) { + /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal + * although it's only run once per model after training. See TF/compression + * implementation for an optimized version. */ + + std::vector cdf(pmf.size() + 1); + cdf[0] = 0; /* freq 0 */ + + std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) { + return static_cast(std::round(p * (1 << precision)) + 0.5); + }); + + const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0); + + std::transform( + cdf.begin(), cdf.end(), cdf.begin(), [precision, total](uint32_t p) { + return static_cast((((1ull << precision) * p) / total)); + }); + + std::partial_sum(cdf.begin(), cdf.end(), cdf.begin()); + cdf.back() = 1 << precision; + + for (int i = 0; i < static_cast(cdf.size() - 1); ++i) { + if (cdf[i] == cdf[i + 1]) { + /* Try to steal frequency from low-frequency symbols */ + uint32_t best_freq = ~0u; + int best_steal = -1; + for (int j = 0; j < static_cast(cdf.size()) - 1; ++j) { + uint32_t freq = cdf[j + 1] - cdf[j]; + if (freq > 1 && freq < best_freq) { + best_freq = freq; + best_steal = j; + } + } + + assert(best_steal != -1); + + if (best_steal < i) { + for (int j = best_steal + 1; j <= i; ++j) { + cdf[j]--; + } + } else { + assert(best_steal > i); + for (int j = i + 1; j <= best_steal; ++j) { + cdf[j]++; + } + } + } + } + + assert(cdf[0] == 0); + assert(cdf.back() == (1u << precision)); + for (int i = 0; i < static_cast(cdf.size()) - 1; ++i) { + assert(cdf[i + 1] > cdf[i]); + } + + return cdf; +} + +PYBIND11_MODULE(MLCodec_CXX, m) { + m.attr("__name__") = "MLCodec_CXX"; + + m.doc() = "C++ utils"; + + m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf, + "Return quantized CDF for a given PMF"); +} diff --git a/ACMMM2022/src/cpp/rans/CMakeLists.txt b/ACMMM2022/src/cpp/rans/CMakeLists.txt new file mode 100644 index 0000000..1f13884 --- /dev/null +++ b/ACMMM2022/src/cpp/rans/CMakeLists.txt @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +cmake_minimum_required(VERSION 3.7) +set(PROJECT_NAME MLCodec_rans) +project(${PROJECT_NAME}) + +set(rans_source + rans_interface.hpp + rans_interface.cpp + ) + +set(include_dirs + ${CMAKE_CURRENT_SOURCE_DIR} + ${PYBIND11_INCLUDE} + ${RYG_RANS_INCLUDE} + ) + +pybind11_add_module(${PROJECT_NAME} ${rans_source}) + +target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) + +# The post build argument is executed after make! 
+add_custom_command( + TARGET ${PROJECT_NAME} POST_BUILD + COMMAND + "${CMAKE_COMMAND}" -E copy + "$" + "${CMAKE_CURRENT_SOURCE_DIR}/../../entropy_models/" +) diff --git a/ACMMM2022/src/cpp/rans/rans_interface.cpp b/ACMMM2022/src/cpp/rans/rans_interface.cpp new file mode 100644 index 0000000..a1baf1f --- /dev/null +++ b/ACMMM2022/src/cpp/rans/rans_interface.cpp @@ -0,0 +1,261 @@ +/* Copyright 2020 InterDigital Communications, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Rans64 extensions from: + * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/ + * Unbounded range coding from: + * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc + **/ + +#include "rans_interface.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +/* probability range, this could be a parameter... */ +constexpr int precision = 16; + +constexpr uint16_t bypass_precision = 4; /* number of bits in bypass mode */ +constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1; + +namespace { + +/* Support only 16 bits word max */ +inline void Rans64EncPutBits(Rans64State *r, uint32_t **pptr, uint32_t val, + uint32_t nbits) { + assert(nbits <= 16); + assert(val < (1u << nbits)); + + /* Re-normalize */ + uint64_t x = *r; + uint32_t freq = 1 << (16 - nbits); + uint64_t x_max = ((RANS64_L >> 16) << 32) * freq; + if (x >= x_max) { + *pptr -= 1; + **pptr = (uint32_t)x; + x >>= 32; + Rans64Assert(x < x_max); + } + + /* x = C(s, x) */ + *r = (x << nbits) | val; +} + +inline uint32_t Rans64DecGetBits(Rans64State *r, uint32_t **pptr, + uint32_t n_bits) { + uint64_t x = *r; + uint32_t val = x & ((1u << n_bits) - 1); + + /* Re-normalize */ + x = x >> n_bits; + if (x < RANS64_L) { + x = (x << 32) | **pptr; + *pptr += 1; + Rans64Assert(x >= RANS64_L); + } + + *r = x; + + return val; +} +} // namespace + +void BufferedRansEncoder::encode_with_indexes( + const py::array_t &symbols, const py::array_t &indexes, + const py::array_t &cdfs, const py::array_t &cdfs_sizes, + const py::array_t &offsets) { + // backward loop on symbols from the end; + py::buffer_info symbols_buf = symbols.request(); + py::buffer_info indexes_buf = indexes.request(); + py::buffer_info cdfs_sizes_buf = cdfs_sizes.request(); + py::buffer_info offsets_buf = offsets.request(); + int32_t *symbols_ptr = static_cast(symbols_buf.ptr); + int32_t *indexes_ptr = static_cast(indexes_buf.ptr); + int32_t *cdfs_sizes_ptr = static_cast(cdfs_sizes_buf.ptr); + int32_t *offsets_ptr = static_cast(offsets_buf.ptr); + auto cdfs_raw = cdfs.unchecked<2>(); + for (pybind11::ssize_t i = 0; i < symbols.size(); ++i) { + const int32_t cdf_idx = indexes_ptr[i]; + const int32_t *cdf = cdfs_raw.data(cdf_idx, 0); + const int32_t max_value = cdfs_sizes_ptr[cdf_idx] - 2; + int32_t value = symbols_ptr[i] - offsets_ptr[cdf_idx]; + + uint32_t raw_val = 0; + if (value < 0) { + raw_val = -2 * value - 1; + value = 
max_value; + } else if (value >= max_value) { + raw_val = 2 * (value - max_value); + value = max_value; + } + + _syms.push_back({static_cast(cdf[value]), + static_cast(cdf[value + 1] - cdf[value]), + false}); + + /* Bypass coding mode (value == max_value -> sentinel flag) */ + if (value == max_value) { + /* Determine the number of bypasses (in bypass_precision size) needed to + * encode the raw value. */ + int32_t n_bypass = 0; + while ((raw_val >> (n_bypass * bypass_precision)) != 0) { + ++n_bypass; + } + + /* Encode number of bypasses */ + int32_t val = n_bypass; + while (val >= max_bypass_val) { + _syms.push_back({max_bypass_val, max_bypass_val + 1, true}); + val -= max_bypass_val; + } + _syms.push_back( + {static_cast(val), static_cast(val + 1), true}); + + /* Encode raw value */ + for (int32_t j = 0; j < n_bypass; ++j) { + const int32_t val1 = + (raw_val >> (j * bypass_precision)) & max_bypass_val; + _syms.push_back({static_cast(val1), + static_cast(val1 + 1), true}); + } + } + } +} + +py::bytes BufferedRansEncoder::flush() { + Rans64State rans; + Rans64EncInit(&rans); + + std::vector output(_syms.size(), 0xCC); // too much space ? + uint32_t *ptr = output.data() + output.size(); + assert(ptr != nullptr); + + while (!_syms.empty()) { + const RansSymbol sym = _syms.back(); + + if (!sym.bypass) { + Rans64EncPut(&rans, &ptr, sym.start, sym.range, precision); + } else { + // unlikely... + Rans64EncPutBits(&rans, &ptr, sym.start, bypass_precision); + } + _syms.pop_back(); + } + + Rans64EncFlush(&rans, &ptr); + + const int nbytes = static_cast( + std::distance(ptr, output.data() + output.size()) * sizeof(uint32_t)); + return std::string(reinterpret_cast(ptr), nbytes); +} + +void BufferedRansEncoder::reset() { _syms.clear(); } + +void RansDecoder::set_stream(const std::string &encoded) { + _stream = encoded; + uint32_t *ptr = (uint32_t *)_stream.data(); + assert(ptr != nullptr); + _ptr = ptr; + Rans64DecInit(&_rans, &_ptr); +} + +py::array_t +RansDecoder::decode_stream(const py::array_t &indexes, + const py::array_t &cdfs, + const py::array_t &cdfs_sizes, + const py::array_t &offsets) { + py::array_t output(indexes.size()); + + py::buffer_info output_buf = output.request(); + py::buffer_info indexes_buf = indexes.request(); + py::buffer_info cdfs_sizes_buf = cdfs_sizes.request(); + py::buffer_info offsets_buf = offsets.request(); + int32_t *outout_ptr = static_cast(output_buf.ptr); + int32_t *indexes_ptr = static_cast(indexes_buf.ptr); + int32_t *cdfs_sizes_ptr = static_cast(cdfs_sizes_buf.ptr); + int32_t *offsets_ptr = static_cast(offsets_buf.ptr); + auto cdfs_raw = cdfs.unchecked<2>(); + for (pybind11::ssize_t i = 0; i < indexes.size(); ++i) { + const int32_t cdf_idx = indexes_ptr[i]; + const int32_t *cdf = cdfs_raw.data(cdf_idx, 0); + const int32_t max_value = cdfs_sizes_ptr[cdf_idx] - 2; + const int32_t offset = offsets_ptr[cdf_idx]; + const uint32_t cum_freq = Rans64DecGet(&_rans, precision); + + const auto cdf_end = cdf + cdfs_sizes_ptr[cdf_idx]; + const auto it = std::find_if(cdf, cdf_end, [cum_freq](int v) { + return static_cast(v) > cum_freq; + }); + const uint32_t s = static_cast(std::distance(cdf, it) - 1); + + Rans64DecAdvance(&_rans, &_ptr, cdf[s], cdf[s + 1] - cdf[s], precision); + + int32_t value = static_cast(s); + + if (value == max_value) { + /* Bypass decoding mode */ + int32_t val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); + int32_t n_bypass = val; + + while (val == max_bypass_val) { + val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); + n_bypass += val; + 
} + + int32_t raw_val = 0; + for (int j = 0; j < n_bypass; ++j) { + val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); + raw_val |= val << (j * bypass_precision); + } + value = raw_val >> 1; + if (raw_val & 1) { + value = -value - 1; + } else { + value += max_value; + } + } + + outout_ptr[i] = value + offset; + } + + return output; +} + +PYBIND11_MODULE(MLCodec_rans, m) { + m.attr("__name__") = "MLCodec_rans"; + + m.doc() = "range Asymmetric Numeral System python bindings"; + + py::class_(m, "BufferedRansEncoder") + .def(py::init<>()) + .def("encode_with_indexes", &BufferedRansEncoder::encode_with_indexes) + .def("flush", &BufferedRansEncoder::flush) + .def("reset", &BufferedRansEncoder::reset); + + py::class_(m, "RansDecoder") + .def(py::init<>()) + .def("set_stream", &RansDecoder::set_stream) + .def("decode_stream", &RansDecoder::decode_stream); +} diff --git a/ACMMM2022/src/cpp/rans/rans_interface.hpp b/ACMMM2022/src/cpp/rans/rans_interface.hpp new file mode 100644 index 0000000..af882c2 --- /dev/null +++ b/ACMMM2022/src/cpp/rans/rans_interface.hpp @@ -0,0 +1,91 @@ +/* Copyright 2020 InterDigital Communications, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpedantic" +#pragma GCC diagnostic ignored "-Wsign-compare" +#elif _MSC_VER +#pragma warning(push, 0) +#endif + +#include + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#elif _MSC_VER +#pragma warning(pop) +#endif + +namespace py = pybind11; + +struct RansSymbol { + uint16_t start; + uint16_t range; + bool bypass; // bypass flag to write raw bits to the stream +}; + +/* NOTE: Warning, we buffer everything for now... In case of large files we + * should split the bitstream into chunks... 
Or for a memory-bounded encoder + **/ +class BufferedRansEncoder { +public: + BufferedRansEncoder() = default; + + BufferedRansEncoder(const BufferedRansEncoder &) = delete; + BufferedRansEncoder(BufferedRansEncoder &&) = delete; + BufferedRansEncoder &operator=(const BufferedRansEncoder &) = delete; + BufferedRansEncoder &operator=(BufferedRansEncoder &&) = delete; + + void encode_with_indexes(const py::array_t &symbols, + const py::array_t &indexes, + const py::array_t &cdfs, + const py::array_t &cdfs_sizes, + const py::array_t &offsets); + py::bytes flush(); + void reset(); + +private: + std::vector _syms; +}; + +class RansDecoder { +public: + RansDecoder() = default; + + RansDecoder(const RansDecoder &) = delete; + RansDecoder(RansDecoder &&) = delete; + RansDecoder &operator=(const RansDecoder &) = delete; + RansDecoder &operator=(RansDecoder &&) = delete; + + void set_stream(const std::string &stream); + + py::array_t + decode_stream(const py::array_t &indexes, + const py::array_t &cdfs, + const py::array_t &cdfs_sizes, + const py::array_t &offsets); + +private: + Rans64State _rans; + std::string _stream; + uint32_t *_ptr; +}; diff --git a/ACMMM2022/src/entropy_models/entropy_models.py b/ACMMM2022/src/entropy_models/entropy_models.py new file mode 100644 index 0000000..c0970ee --- /dev/null +++ b/ACMMM2022/src/entropy_models/entropy_models.py @@ -0,0 +1,288 @@ +import math + +import torch +import numpy as np +from torch import nn +import torch.nn.functional as F + + +class EntropyCoder(): + def __init__(self): + super().__init__() + + from .MLCodec_rans import BufferedRansEncoder, RansDecoder + self.encoder = BufferedRansEncoder() + self.decoder = RansDecoder() + + @staticmethod + def pmf_to_quantized_cdf(pmf, precision=16): + from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf + cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision) + cdf = torch.IntTensor(cdf) + return cdf + + @staticmethod + def pmf_to_cdf(pmf, tail_mass, pmf_length, max_length): + entropy_coder_precision = 16 + cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32) + for i, p in enumerate(pmf): + prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0) + _cdf = EntropyCoder.pmf_to_quantized_cdf(prob, entropy_coder_precision) + cdf[i, : _cdf.size(0)] = _cdf + return cdf + + def set_stream(self, stream): + self.decoder.set_stream(stream) + + def encode_with_indexes(self, symbols_list, indexes_list, cdf, cdf_length, offset): + self.encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_length, offset) + return None + + def flush_encoder(self): + return self.encoder.flush() + + def reset_encoder(self): + self.encoder.reset() + + def decode_stream(self, indexes, cdf, cdf_length, offset): + rv = self.decoder.decode_stream(indexes, cdf, cdf_length, offset) + rv = np.array(rv) + rv = torch.Tensor(rv).reshape(1, -1, 1, 1) + return rv + + +class Bitparm(nn.Module): + def __init__(self, channel, final=False): + super().__init__() + self.final = final + self.h = nn.Parameter(torch.nn.init.normal_( + torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) + self.b = nn.Parameter(torch.nn.init.normal_( + torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) + if not final: + self.a = nn.Parameter(torch.nn.init.normal_( + torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) + else: + self.a = None + + def forward(self, x): + x = x * F.softplus(self.h) + self.b + if self.final: + return x + + return x + torch.tanh(x) * torch.tanh(self.a) + + +class CdfHelper(): + def __init__(self): + 
super().__init__() + self._offset = None + self._quantized_cdf = None + self._cdf_length = None + + def set_cdf(self, offset, quantized_cdf, cdf_length): + self._offset = offset.reshape(-1).int().cpu().numpy() + self._quantized_cdf = quantized_cdf.cpu().numpy() + self._cdf_length = cdf_length.reshape(-1).int().cpu().numpy() + + def get_cdf_info(self): + return self._quantized_cdf, \ + self._cdf_length, \ + self._offset + + +class BitEstimator(nn.Module): + def __init__(self, channel): + super().__init__() + self.f1 = Bitparm(channel) + self.f2 = Bitparm(channel) + self.f3 = Bitparm(channel) + self.f4 = Bitparm(channel, True) + self.channel = channel + + self.entropy_coder = None + self.cdf_helper = None + + def forward(self, x): + return self.get_cdf(x) + + def get_logits_cdf(self, x): + x = self.f1(x) + x = self.f2(x) + x = self.f3(x) + x = self.f4(x) + return x + + def get_cdf(self, x): + return torch.sigmoid(self.get_logits_cdf(x)) + + def update(self, force=False, entropy_coder=None): + # Check if we need to update the bottleneck parameters, the offsets are + # only computed and stored when the conditonal model is update()'d. + if self.entropy_coder is not None and not force: # pylint: disable=E0203 + return + + self.entropy_coder = entropy_coder + self.cdf_helper = CdfHelper() + with torch.no_grad(): + device = next(self.parameters()).device + medians = torch.zeros((self.channel), device=device) + + minima = medians + 50 + for i in range(50, 1, -1): + samples = torch.zeros_like(medians) - i + samples = samples[None, :, None, None] + probs = self.forward(samples) + probs = torch.squeeze(probs) + minima = torch.where(probs < torch.zeros_like(medians) + 0.0001, + torch.zeros_like(medians) + i, minima) + + maxima = medians + 50 + for i in range(50, 1, -1): + samples = torch.zeros_like(medians) + i + samples = samples[None, :, None, None] + probs = self.forward(samples) + probs = torch.squeeze(probs) + maxima = torch.where(probs > torch.zeros_like(medians) + 0.9999, + torch.zeros_like(medians) + i, maxima) + + minima = minima.int() + maxima = maxima.int() + + offset = -minima + + pmf_start = medians - minima + pmf_length = maxima + minima + 1 + + max_length = pmf_length.max() + device = pmf_start.device + samples = torch.arange(max_length, device=device) + + samples = samples[None, :] + pmf_start[:, None, None] + + half = float(0.5) + + lower = self.forward(samples - half).squeeze(0) + upper = self.forward(samples + half).squeeze(0) + pmf = upper - lower + + pmf = pmf[:, 0, :] + tail_mass = lower[:, 0, :1] + (1.0 - upper[:, 0, -1:]) + + quantized_cdf = EntropyCoder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) + cdf_length = pmf_length + 2 + self.cdf_helper.set_cdf(offset, quantized_cdf, cdf_length) + + @staticmethod + def build_indexes(size): + N, C, H, W = size + indexes = torch.arange(C).view(1, -1, 1, 1) + indexes = indexes.int() + return indexes.repeat(N, 1, H, W) + + def encode(self, x): + indexes = self.build_indexes(x.size()) + return self.entropy_coder.encode_with_indexes(x.reshape(-1).int().cpu().numpy(), + indexes[0].reshape(-1).int().cpu().numpy(), + *self.cdf_helper.get_cdf_info()) + + def decode_stream(self, size): + output_size = (1, self.channel, size[0], size[1]) + indexes = self.build_indexes(output_size) + val = self.entropy_coder.decode_stream(indexes.reshape(-1).int().cpu().numpy(), + *self.cdf_helper.get_cdf_info()) + val = val.reshape(indexes.shape) + return val.float() + + +class GaussianEncoder(): + def __init__(self, distribution='laplace'): + assert 
distribution in ['laplace', 'gaussian'] + self.distribution = distribution + if distribution == 'laplace': + self.cdf_distribution = torch.distributions.laplace.Laplace + self.scale_min = 0.01 + self.scale_max = 64.0 + self.scale_level = 256 + elif distribution == 'gaussian': + self.cdf_distribution = torch.distributions.normal.Normal + self.scale_min = 0.11 + self.scale_max = 64.0 + self.scale_level = 256 + self.scale_table = self.get_scale_table(self.scale_min, self.scale_max, self.scale_level) + + self.log_scale_min = math.log(self.scale_min) + self.log_scale_max = math.log(self.scale_max) + self.log_scale_step = (self.log_scale_max - self.log_scale_min) / (self.scale_level - 1) + self.entropy_coder = None + self.cdf_helper = None + + @staticmethod + def get_scale_table(min_val, max_val, levels): + return torch.exp(torch.linspace(math.log(min_val), math.log(max_val), levels)) + + def update(self, force=False, entropy_coder=None): + if self.entropy_coder is not None and not force: + return + self.entropy_coder = entropy_coder + self.cdf_helper = CdfHelper() + + pmf_center = torch.zeros_like(self.scale_table) + 50 + scales = torch.zeros_like(pmf_center) + self.scale_table + mu = torch.zeros_like(scales) + cdf_distribution = self.cdf_distribution(mu, scales) + for i in range(50, 1, -1): + samples = torch.zeros_like(pmf_center) + i + probs = cdf_distribution.cdf(samples) + probs = torch.squeeze(probs) + pmf_center = torch.where(probs > torch.zeros_like(pmf_center) + 0.9999, + torch.zeros_like(pmf_center) + i, pmf_center) + + pmf_center = pmf_center.int() + pmf_length = 2 * pmf_center + 1 + max_length = torch.max(pmf_length).item() + + device = pmf_center.device + samples = torch.arange(max_length, device=device) - pmf_center[:, None] + samples = samples.float() + + scales = torch.zeros_like(samples) + self.scale_table[:, None] + mu = torch.zeros_like(scales) + cdf_distribution = self.cdf_distribution(mu, scales) + + upper = cdf_distribution.cdf(samples + 0.5) + lower = cdf_distribution.cdf(samples - 0.5) + pmf = upper - lower + + tail_mass = 2 * lower[:, :1] + + quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2) + quantized_cdf = EntropyCoder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) + + self.cdf_helper.set_cdf(-pmf_center, quantized_cdf, pmf_length+2) + + def build_indexes(self, scales): + scales = torch.maximum(scales, torch.zeros_like(scales) + 1e-5) + indexes = (torch.log(scales) - self.log_scale_min) / self.log_scale_step + indexes = indexes.clamp_(0, self.scale_level - 1) + return indexes.int() + + def encode(self, x, scales): + indexes = self.build_indexes(scales) + return self.entropy_coder.encode_with_indexes(x.reshape(-1).int().cpu().numpy(), + indexes.reshape(-1).int().cpu().numpy(), + *self.cdf_helper.get_cdf_info()) + + def decode_stream(self, scales): + indexes = self.build_indexes(scales) + val = self.entropy_coder.decode_stream(indexes.reshape(-1).int().cpu().numpy(), + *self.cdf_helper.get_cdf_info()) + val = val.reshape(scales.shape) + return val.float() + + def set_decoder_cdf(self): + self.entropy_coder.set_decoder_cdf(*self.cdf_helper.get_cdf_info()) + + def encode_with_indexes(self, symbols_list, indexes_list): + return self.entropy_coder.encode_with_indexes(symbols_list, indexes_list, + *self.cdf_helper.get_cdf_info()) diff --git a/ACMMM2022/src/layers/layers.py b/ACMMM2022/src/layers/layers.py new file mode 100644 index 0000000..7a8165b --- /dev/null +++ b/ACMMM2022/src/layers/layers.py @@ -0,0 +1,127 @@ +# Copyright 2020 InterDigital 
Communications, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch import nn + + +def conv3x3(in_ch, out_ch, stride=1): + """3x3 convolution with padding.""" + return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1) + + +def subpel_conv3x3(in_ch, out_ch, r=1): + """3x3 sub-pixel convolution for up-sampling.""" + return nn.Sequential( + nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r) + ) + + +def subpel_conv1x1(in_ch, out_ch, r=1): + """1x1 sub-pixel convolution for up-sampling.""" + return nn.Sequential( + nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=1, padding=0), nn.PixelShuffle(r) + ) + + +def conv1x1(in_ch, out_ch, stride=1): + """1x1 convolution.""" + return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride) + + +class ResidualBlockWithStride(nn.Module): + """Residual block with a stride on the first convolution. + + Args: + in_ch (int): number of input channels + out_ch (int): number of output channels + stride (int): stride value (default: 2) + """ + + def __init__(self, in_ch, out_ch, stride=2): + super().__init__() + self.conv1 = conv3x3(in_ch, out_ch, stride=stride) + self.leaky_relu = nn.LeakyReLU() + self.conv2 = conv3x3(out_ch, out_ch) + self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1) + if stride != 1: + self.downsample = conv1x1(in_ch, out_ch, stride=stride) + else: + self.downsample = None + + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.leaky_relu(out) + out = self.conv2(out) + out = self.leaky_relu2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + return out + + +class ResidualBlockUpsample(nn.Module): + """Residual block with sub-pixel upsampling on the last convolution. + + Args: + in_ch (int): number of input channels + out_ch (int): number of output channels + upsample (int): upsampling factor (default: 2) + """ + + def __init__(self, in_ch, out_ch, upsample=2): + super().__init__() + self.subpel_conv = subpel_conv1x1(in_ch, out_ch, upsample) + self.leaky_relu = nn.LeakyReLU() + self.conv = conv3x3(out_ch, out_ch) + self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1) + self.upsample = subpel_conv1x1(in_ch, out_ch, upsample) + + def forward(self, x): + identity = x + out = self.subpel_conv(x) + out = self.leaky_relu(out) + out = self.conv(out) + out = self.leaky_relu2(out) + identity = self.upsample(x) + out += identity + return out + + +class ResidualBlock(nn.Module): + """Simple residual block with two 3x3 convolutions. 
+ + Args: + in_ch (int): number of input channels + out_ch (int): number of output channels + """ + + def __init__(self, in_ch, out_ch, leaky_relu_slope=0.01): + super().__init__() + self.conv1 = conv3x3(in_ch, out_ch) + self.leaky_relu = nn.LeakyReLU(negative_slope=leaky_relu_slope) + self.conv2 = conv3x3(out_ch, out_ch) + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.leaky_relu(out) + out = self.conv2(out) + out = self.leaky_relu(out) + + out = out + identity + return out diff --git a/ACMMM2022/src/models/common_model.py b/ACMMM2022/src/models/common_model.py new file mode 100644 index 0000000..defaa76 --- /dev/null +++ b/ACMMM2022/src/models/common_model.py @@ -0,0 +1,188 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import math + +import torch +from torch import nn + +from pytorch_msssim import MS_SSIM + +from .video_net import LowerBound +from ..entropy_models.entropy_models import BitEstimator, GaussianEncoder, EntropyCoder + + +class CompressionModel(nn.Module): + def __init__(self, y_distribution, z_channel, mv_z_channel=None): + super().__init__() + + self.y_distribution = y_distribution + self.z_channel = z_channel + self.mv_z_channel = mv_z_channel + self.entropy_coder = None + self.bit_estimator_z = BitEstimator(z_channel) + self.bit_estimator_z_mv = None + if mv_z_channel is not None: + self.bit_estimator_z_mv = BitEstimator(mv_z_channel) + self.gaussian_encoder = GaussianEncoder(distribution=y_distribution) + + self.mse = nn.MSELoss(reduction='none') + self.ssim = MS_SSIM(data_range=1.0, size_average=False) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + torch.nn.init.xavier_normal_(m.weight, math.sqrt(2)) + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0.01) + + def quant(self, x, force_detach=False): + if self.training or force_detach: + n = torch.round(x) - x + n = n.clone().detach() + return x + n + + return torch.round(x) + + def add_noise(self, x): + noise = torch.nn.init.uniform_(torch.zeros_like(x), -0.5, 0.5) + noise = noise.clone().detach() + return x + noise + + @staticmethod + def probs_to_bits(probs): + bits = -1.0 * torch.log(probs + 1e-5) / math.log(2.0) + bits = LowerBound.apply(bits, 0) + return bits + + def get_y_gaussian_bits(self, y, sigma): + mu = torch.zeros_like(sigma) + sigma = sigma.clamp(0.11, 1e10) + gaussian = torch.distributions.normal.Normal(mu, sigma) + probs = gaussian.cdf(y + 0.5) - gaussian.cdf(y - 0.5) + return CompressionModel.probs_to_bits(probs) + + def get_y_laplace_bits(self, y, sigma): + mu = torch.zeros_like(sigma) + sigma = sigma.clamp(1e-5, 1e10) + gaussian = torch.distributions.laplace.Laplace(mu, sigma) + probs = gaussian.cdf(y + 0.5) - gaussian.cdf(y - 0.5) + return CompressionModel.probs_to_bits(probs) + + def get_z_bits(self, z, bit_estimator): + probs = bit_estimator.get_cdf(z + 0.5) - bit_estimator.get_cdf(z - 0.5) + return CompressionModel.probs_to_bits(probs) + + def update(self, force=False): + self.entropy_coder = EntropyCoder() + self.gaussian_encoder.update(force=force, entropy_coder=self.entropy_coder) + self.bit_estimator_z.update(force=force, entropy_coder=self.entropy_coder) + if self.bit_estimator_z_mv is not None: + self.bit_estimator_z_mv.update(force=force, entropy_coder=self.entropy_coder) + + @staticmethod + def get_mask(height, width, device): + micro_mask = torch.tensor(((1, 0), (0, 1)), dtype=torch.float32, device=device) + mask_0 = micro_mask.repeat(height // 2, 
width // 2) + mask_0 = torch.unsqueeze(mask_0, 0) + mask_0 = torch.unsqueeze(mask_0, 0) + mask_1 = torch.ones_like(mask_0) - mask_0 + return mask_0, mask_1 + + def process_with_mask(self, y, scales, means, mask): + scales_hat = scales * mask + means_hat = means * mask + + y_res = (y - means_hat) * mask + y_q = self.quant(y_res) + y_hat = y_q + means_hat + + return y_res, y_q, y_hat, scales_hat + + def forward_dual_prior(self, y, means, scales, quant_step, y_spatial_prior, write=False): + ''' + y_0 means split in channel, the first half + y_1 means split in channel, the second half + y_?_0, means multiply with mask_0 + y_?_1, means multiply with mask_1 + ''' + device = y.device + _, _, H, W = y.size() + mask_0, mask_1 = self.get_mask(H, W, device) + + quant_step = LowerBound.apply(quant_step, 0.5) + y = y / quant_step + y_0, y_1 = y.chunk(2, 1) + + scales_0, scales_1 = scales.chunk(2, 1) + means_0, means_1 = means.chunk(2, 1) + + y_res_0_0, y_q_0_0, y_hat_0_0, scales_hat_0_0 = \ + self.process_with_mask(y_0, scales_0, means_0, mask_0) + y_res_1_1, y_q_1_1, y_hat_1_1, scales_hat_1_1 = \ + self.process_with_mask(y_1, scales_1, means_1, mask_1) + + params = torch.cat((y_hat_0_0, y_hat_1_1, means, scales, quant_step), dim=1) + scales_0, means_0, scales_1, means_1 = y_spatial_prior(params).chunk(4, 1) + + y_res_0_1, y_q_0_1, y_hat_0_1, scales_hat_0_1 = \ + self.process_with_mask(y_0, scales_0, means_0, mask_1) + y_res_1_0, y_q_1_0, y_hat_1_0, scales_hat_1_0 = \ + self.process_with_mask(y_1, scales_1, means_1, mask_0) + + y_res_0 = y_res_0_0 + y_res_0_1 + y_q_0 = y_q_0_0 + y_q_0_1 + y_hat_0 = y_hat_0_0 + y_hat_0_1 + scales_hat_0 = scales_hat_0_0 + scales_hat_0_1 + + y_res_1 = y_res_1_1 + y_res_1_0 + y_q_1 = y_q_1_1 + y_q_1_0 + y_hat_1 = y_hat_1_1 + y_hat_1_0 + scales_hat_1 = scales_hat_1_1 + scales_hat_1_0 + + y_res = torch.cat((y_res_0, y_res_1), dim=1) + y_q = torch.cat((y_q_0, y_q_1), dim=1) + y_hat = torch.cat((y_hat_0, y_hat_1), dim=1) + scales_hat = torch.cat((scales_hat_0, scales_hat_1), dim=1) + + y_hat = y_hat * quant_step + + if write: + y_q_w_0 = y_q_0_0 + y_q_1_1 + y_q_w_1 = y_q_0_1 + y_q_1_0 + scales_w_0 = scales_hat_0_0 + scales_hat_1_1 + scales_w_1 = scales_hat_0_1 + scales_hat_1_0 + return y_q_w_0, y_q_w_1, scales_w_0, scales_w_1, y_hat + return y_res, y_q, y_hat, scales_hat + + def compress_dual_prior(self, y, means, scales, quant_step, y_spatial_prior): + return self.forward_dual_prior(y, means, scales, quant_step, y_spatial_prior, write=True) + + def decompress_dual_prior(self, means, scales, quant_step, y_spatial_prior): + device = means.device + _, _, H, W = means.size() + mask_0, mask_1 = self.get_mask(H, W, device) + quant_step = torch.clamp_min(quant_step, 0.5) + + scales_0, scales_1 = scales.chunk(2, 1) + means_0, means_1 = means.chunk(2, 1) + + scales_r_0 = scales_0 * mask_0 + scales_1 * mask_1 + y_q_r_0 = self.gaussian_encoder.decode_stream(scales_r_0).to(device) + y_hat_0_0 = (y_q_r_0 + means_0) * mask_0 + y_hat_1_1 = (y_q_r_0 + means_1) * mask_1 + + params = torch.cat((y_hat_0_0, y_hat_1_1, means, scales, quant_step), dim=1) + scales_0, means_0, scales_1, means_1 = y_spatial_prior(params).chunk(4, 1) + + scales_r_1 = scales_0 * mask_1 + scales_1 * mask_0 + y_q_r_1 = self.gaussian_encoder.decode_stream(scales_r_1).to(device) + y_hat_0_1 = (y_q_r_1 + means_0) * mask_1 + y_hat_1_0 = (y_q_r_1 + means_1) * mask_0 + + y_hat_0 = y_hat_0_0 + y_hat_0_1 + y_hat_1 = y_hat_1_1 + y_hat_1_0 + y_hat = torch.cat((y_hat_0, y_hat_1), dim=1) + y_hat = y_hat * quant_step + + return y_hat 
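As an aside on the dual spatial prior implemented by get_mask and forward_dual_prior in common_model.py above: the latent is split into two channel halves, and each half is further partitioned by complementary 2x2 checkerboard masks, so its elements are coded in two passes. The standalone sketch below illustrates only this partitioning; the toy tensor sizes and the first/second-pass naming are illustrative assumptions, not part of the repository code.

```python
import torch

def get_mask(height, width, device='cpu'):
    # 2x2 checkerboard tiled over the latent spatial size, mirroring CompressionModel.get_mask.
    micro_mask = torch.tensor(((1, 0), (0, 1)), dtype=torch.float32, device=device)
    mask_0 = micro_mask.repeat(height // 2, width // 2).unsqueeze(0).unsqueeze(0)
    mask_1 = torch.ones_like(mask_0) - mask_0
    return mask_0, mask_1

# Toy latent: 4 channels, 8x8 spatial size; the model splits it into two channel halves.
y = torch.randn(1, 4, 8, 8)
mask_0, mask_1 = get_mask(8, 8)
y_0, y_1 = y.chunk(2, 1)

# First pass: y_0 at mask_0 positions and y_1 at mask_1 positions are coded with the
# fused prior; their reconstructions then condition the spatial prior network.
y_0_first, y_1_first = y_0 * mask_0, y_1 * mask_1
# Second pass: the complementary positions, coded with the refined means and scales.
y_0_second, y_1_second = y_0 * mask_1, y_1 * mask_0

# The two passes together cover every latent element exactly once.
recon = torch.cat((y_0_first + y_0_second, y_1_first + y_1_second), dim=1)
assert torch.allclose(recon, y)
```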
diff --git a/ACMMM2022/src/models/image_model.py b/ACMMM2022/src/models/image_model.py new file mode 100644 index 0000000..5575454 --- /dev/null +++ b/ACMMM2022/src/models/image_model.py @@ -0,0 +1,169 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +from torch import nn + +from src.layers.layers import conv3x3 + +from .common_model import CompressionModel +from .video_net import LowerBound, UNet, get_enc_dec_models, get_hyper_enc_dec_models +from ..utils.stream_helper import encode_i, decode_i, get_downsampled_shape, filesize, \ + get_rounded_q, get_state_dict + + +class IntraNoAR(CompressionModel): + def __init__(self, N=192, anchor_num=4): + super().__init__(y_distribution='gaussian', z_channel=N) + + self.enc, self.dec = get_enc_dec_models(3, 16, N) + self.refine = nn.Sequential( + UNet(16, 16), + conv3x3(16, 3), + ) + self.hyper_enc, self.hyper_dec = get_hyper_enc_dec_models(N, N) + self.y_prior_fusion = nn.Sequential( + nn.Conv2d(N * 2, N * 3, 3, stride=1, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(N * 3, N * 3, 3, stride=1, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(N * 3, N * 3, 3, stride=1, padding=1) + ) + + self.y_spatial_prior = nn.Sequential( + nn.Conv2d(N * 4, N * 3, 3, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(N * 3, N * 3, 3, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(N * 3, N * 2, 3, padding=1) + ) + + self.q_basic = nn.Parameter(torch.ones((1, N, 1, 1))) + self.q_scale = nn.Parameter(torch.ones((anchor_num, 1, 1, 1))) + # the exact q_step is q_basic * q_scale + self.N = int(N) + self.anchor_num = int(anchor_num) + + self._initialize_weights() + + def get_curr_q(self, q_scale): + q_basic = LowerBound.apply(self.q_basic, 0.5) + return q_basic * q_scale + + def forward(self, x, q_scale=None): + curr_q = self.get_curr_q(q_scale) + + y = self.enc(x) + y = y / curr_q + z = self.hyper_enc(y) + z_hat = self.quant(z) + + params = self.hyper_dec(z_hat) + q_step, scales, means = self.y_prior_fusion(params).chunk(3, 1) + y_res, y_q, y_hat, scales_hat = self.forward_dual_prior( + y, means, scales, q_step, self.y_spatial_prior) + + y_hat = y_hat * curr_q + x_hat = self.refine(self.dec(y_hat)) + + if self.training: + y_for_bit = self.add_noise(y_res) + z_for_bit = self.add_noise(z) + else: + y_for_bit = y_q + z_for_bit = z_hat + bits_y = self.get_y_gaussian_bits(y_for_bit, scales_hat) + bits_z = self.get_z_bits(z_for_bit, self.bit_estimator_z) + mse = self.mse(x, x_hat) + ssim = self.ssim(x, x_hat) + + _, _, H, W = x.size() + pixel_num = H * W + bpp_y = torch.sum(bits_y, dim=(1, 2, 3)) / pixel_num + bpp_z = torch.sum(bits_z, dim=(1, 2, 3)) / pixel_num + mse = torch.sum(mse, dim=(1, 2, 3)) / pixel_num + + bits = torch.sum(bpp_y + bpp_z) * pixel_num + bpp = bpp_y + bpp_z + + return { + "x_hat": x_hat, + "mse": mse, + "ssim": ssim, + "bit": bits.item(), + "bpp": bpp, + "bpp_y": bpp_y, + "bpp_z": bpp_z, + } + + @staticmethod + def get_q_scales_from_ckpt(ckpt_path): + ckpt = get_state_dict(ckpt_path) + q_scales = ckpt["q_scale"] + return q_scales.reshape(-1) + + def encode_decode(self, x, q_scale, output_path=None, pic_width=None, pic_height=None): + # pic_width and pic_height may be different from x's size. 
X here is after padding + # x_hat has the same size with x + if output_path is None: + return self.forward(x, q_scale) + + assert pic_height is not None + assert pic_width is not None + q_scale, q_index = get_rounded_q(q_scale) + compressed = self.compress(x, q_scale) + bit_stream = compressed['bit_stream'] + encode_i(pic_height, pic_width, q_index, bit_stream, output_path) + bit = filesize(output_path) * 8 + + height, width, q_index, bit_stream = decode_i(output_path) + decompressed = self.decompress(bit_stream, height, width, q_index / 100) + x_hat = decompressed['x_hat'] + + result = { + 'bit': bit, + 'x_hat': x_hat, + } + return result + + def compress(self, x, q_scale): + curr_q = self.get_curr_q(q_scale) + + y = self.enc(x) + y = y / curr_q + z = self.hyper_enc(y) + z_hat = torch.round(z) + + params = self.hyper_dec(z_hat) + q_step, scales, means = self.y_prior_fusion(params).chunk(3, 1) + y_q_w_0, y_q_w_1, scales_w_0, scales_w_1, y_hat = self.compress_dual_prior( + y, means, scales, q_step, self.y_spatial_prior) + y_hat = y_hat * curr_q + + self.entropy_coder.reset_encoder() + _ = self.bit_estimator_z.encode(z_hat) + _ = self.gaussian_encoder.encode(y_q_w_0, scales_w_0) + _ = self.gaussian_encoder.encode(y_q_w_1, scales_w_1) + bit_stream = self.entropy_coder.flush_encoder() + + result = { + "bit_stream": bit_stream, + } + return result + + def decompress(self, bit_stream, height, width, q_scale): + curr_q = self.get_curr_q(q_scale) + + self.entropy_coder.set_stream(bit_stream) + device = next(self.parameters()).device + z_size = get_downsampled_shape(height, width, 64) + z_hat = self.bit_estimator_z.decode_stream(z_size) + z_hat = z_hat.to(device) + + params = self.hyper_dec(z_hat) + q_step, scales, means = self.y_prior_fusion(params).chunk(3, 1) + y_hat = self.decompress_dual_prior(means, scales, q_step, self.y_spatial_prior) + + y_hat = y_hat * curr_q + x_hat = self.refine(self.dec(y_hat)).clamp_(0, 1) + return {"x_hat": x_hat} diff --git a/ACMMM2022/src/models/video_model.py b/ACMMM2022/src/models/video_model.py new file mode 100644 index 0000000..4c9997d --- /dev/null +++ b/ACMMM2022/src/models/video_model.py @@ -0,0 +1,513 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import time + +import torch +from torch import nn + +from .common_model import CompressionModel +from .video_net import ME_Spynet, flow_warp, ResBlock, bilineardownsacling, LowerBound, UNet, \ + get_enc_dec_models, get_hyper_enc_dec_models +from ..layers.layers import conv3x3, subpel_conv1x1, subpel_conv3x3 +from ..utils.stream_helper import get_downsampled_shape, encode_p, decode_p, filesize, \ + get_rounded_q, get_state_dict + + +class FeatureExtractor(nn.Module): + def __init__(self, channel=64): + super().__init__() + self.conv1 = nn.Conv2d(channel, channel, 3, stride=1, padding=1) + self.res_block1 = ResBlock(channel) + self.conv2 = nn.Conv2d(channel, channel, 3, stride=2, padding=1) + self.res_block2 = ResBlock(channel) + self.conv3 = nn.Conv2d(channel, channel, 3, stride=2, padding=1) + self.res_block3 = ResBlock(channel) + + def forward(self, feature): + layer1 = self.conv1(feature) + layer1 = self.res_block1(layer1) + + layer2 = self.conv2(layer1) + layer2 = self.res_block2(layer2) + + layer3 = self.conv3(layer2) + layer3 = self.res_block3(layer3) + + return layer1, layer2, layer3 + + +class MultiScaleContextFusion(nn.Module): + def __init__(self, channel_in=64, channel_out=64): + super().__init__() + self.conv3_up = subpel_conv3x3(channel_in, channel_out, 2) + self.res_block3_up = ResBlock(channel_out) + self.conv3_out = nn.Conv2d(channel_out, channel_out, 3, padding=1) + self.res_block3_out = ResBlock(channel_out) + self.conv2_up = subpel_conv3x3(channel_out * 2, channel_out, 2) + self.res_block2_up = ResBlock(channel_out) + self.conv2_out = nn.Conv2d(channel_out * 2, channel_out, 3, padding=1) + self.res_block2_out = ResBlock(channel_out) + self.conv1_out = nn.Conv2d(channel_out * 2, channel_out, 3, padding=1) + self.res_block1_out = ResBlock(channel_out) + + def forward(self, context1, context2, context3): + context3_up = self.conv3_up(context3) + context3_up = self.res_block3_up(context3_up) + context3_out = self.conv3_out(context3) + context3_out = self.res_block3_out(context3_out) + context2_up = self.conv2_up(torch.cat((context3_up, context2), dim=1)) + context2_up = self.res_block2_up(context2_up) + context2_out = self.conv2_out(torch.cat((context3_up, context2), dim=1)) + context2_out = self.res_block2_out(context2_out) + context1_out = self.conv1_out(torch.cat((context2_up, context1), dim=1)) + context1_out = self.res_block1_out(context1_out) + context1 = context1 + context1_out + context2 = context2 + context2_out + context3 = context3 + context3_out + return context1, context2, context3 + + +class ContextualEncoder(nn.Module): + def __init__(self, channel_N=64, channel_M=96): + super().__init__() + self.conv1 = nn.Conv2d(channel_N + 3, channel_N, 3, stride=2, padding=1) + self.res1 = ResBlock(channel_N * 2, bottleneck=True, slope=0.1, + start_from_relu=True, end_with_relu=True) + self.conv2 = nn.Conv2d(channel_N * 2, channel_N, 3, stride=2, padding=1) + self.res2 = ResBlock(channel_N * 2, bottleneck=True, slope=0.1, + start_from_relu=True, end_with_relu=True) + self.conv3 = nn.Conv2d(channel_N * 2, channel_N, 3, stride=2, padding=1) + self.conv4 = nn.Conv2d(channel_N, channel_M, 3, stride=2, padding=1) + + def forward(self, x, context1, context2, context3): + feature = self.conv1(torch.cat([x, context1], dim=1)) + feature = self.res1(torch.cat([feature, context2], dim=1)) + feature = self.conv2(feature) + feature = self.res2(torch.cat([feature, context3], dim=1)) + feature = self.conv3(feature) + feature = self.conv4(feature) + return feature + + +class 
ContextualDecoder(nn.Module): + def __init__(self, channel_N=64, channel_M=96): + super().__init__() + self.up1 = subpel_conv3x3(channel_M, channel_N, 2) + self.up2 = subpel_conv3x3(channel_N, channel_N, 2) + self.res1 = ResBlock(channel_N * 2, bottleneck=True, slope=0.1, + start_from_relu=True, end_with_relu=True) + self.up3 = subpel_conv3x3(channel_N * 2, channel_N, 2) + self.res2 = ResBlock(channel_N * 2, bottleneck=True, slope=0.1, + start_from_relu=True, end_with_relu=True) + self.up4 = subpel_conv3x3(channel_N * 2, 32, 2) + + def forward(self, x, context2, context3): + feature = self.up1(x) + feature = self.up2(feature) + feature = self.res1(torch.cat([feature, context3], dim=1)) + feature = self.up3(feature) + feature = self.res2(torch.cat([feature, context2], dim=1)) + feature = self.up4(feature) + return feature + + +class ReconGeneration(nn.Module): + def __init__(self, ctx_channel=64, res_channel=32, channel=64): + super().__init__() + self.first_conv = nn.Conv2d(ctx_channel + res_channel, channel, 3, stride=1, padding=1) + self.unet_1 = UNet(channel) + self.unet_2 = UNet(channel) + self.recon_conv = nn.Conv2d(channel, 3, 3, stride=1, padding=1) + + def forward(self, ctx, res): + feature = self.first_conv(torch.cat((ctx, res), dim=1)) + feature = self.unet_1(feature) + feature = self.unet_2(feature) + recon = self.recon_conv(feature) + return feature, recon + + +class DMC(CompressionModel): + def __init__(self, anchor_num=4): + super().__init__(y_distribution='laplace', z_channel=64, mv_z_channel=64) + self.DMC_version = '1.19' + + channel_mv = 64 + channel_N = 64 + channel_M = 96 + + self.channel_mv = channel_mv + self.channel_N = channel_N + self.channel_M = channel_M + + self.optic_flow = ME_Spynet() + + self.mv_encoder, self.mv_decoder = get_enc_dec_models(2, 2, channel_mv) + self.mv_hyper_prior_encoder, self.mv_hyper_prior_decoder = \ + get_hyper_enc_dec_models(channel_mv, channel_N) + + self.mv_y_prior_fusion = nn.Sequential( + nn.Conv2d(channel_mv * 3, channel_mv * 3, 3, stride=1, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(channel_mv * 3, channel_mv * 3, 3, stride=1, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(channel_mv * 3, channel_mv * 3, 3, stride=1, padding=1) + ) + + self.mv_y_spatial_prior = nn.Sequential( + nn.Conv2d(channel_mv * 4, channel_mv * 3, 3, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(channel_mv * 3, channel_mv * 3, 3, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(channel_mv * 3, channel_mv * 2, 3, padding=1) + ) + + self.feature_adaptor_I = nn.Conv2d(3, channel_N, 3, stride=1, padding=1) + self.feature_adaptor_P = nn.Conv2d(channel_N, channel_N, 1) + self.feature_extractor = FeatureExtractor() + self.context_fusion_net = MultiScaleContextFusion() + + self.contextual_encoder = ContextualEncoder(channel_N=channel_N, channel_M=channel_M) + + self.contextual_hyper_prior_encoder = nn.Sequential( + nn.Conv2d(channel_M, channel_N, 3, stride=1, padding=1), + nn.LeakyReLU(), + nn.Conv2d(channel_N, channel_N, 3, stride=2, padding=1), + nn.LeakyReLU(), + nn.Conv2d(channel_N, channel_N, 3, stride=2, padding=1), + ) + + self.contextual_hyper_prior_decoder = nn.Sequential( + conv3x3(channel_N, channel_M), + nn.LeakyReLU(), + subpel_conv1x1(channel_M, channel_M, 2), + nn.LeakyReLU(), + conv3x3(channel_M, channel_M * 3 // 2), + nn.LeakyReLU(), + subpel_conv1x1(channel_M * 3 // 2, channel_M * 3 // 2, 2), + nn.LeakyReLU(), + conv3x3(channel_M * 3 // 2, channel_M * 2), + ) + + self.temporal_prior_encoder = nn.Sequential( + nn.Conv2d(channel_N, channel_M * 3 // 2, 3, 
stride=2, padding=1), + nn.LeakyReLU(0.1), + nn.Conv2d(channel_M * 3 // 2, channel_M * 2, 3, stride=2, padding=1), + ) + + self.y_prior_fusion = nn.Sequential( + nn.Conv2d(channel_M * 5, channel_M * 4, 3, stride=1, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(channel_M * 4, channel_M * 3, 3, stride=1, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(channel_M * 3, channel_M * 3, 3, stride=1, padding=1) + ) + + self.y_spatial_prior = nn.Sequential( + nn.Conv2d(channel_M * 4, channel_M * 3, 3, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(channel_M * 3, channel_M * 3, 3, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(channel_M * 3, channel_M * 2, 3, padding=1) + ) + + self.contextual_decoder = ContextualDecoder(channel_N=channel_N, channel_M=channel_M) + self.recon_generation_net = ReconGeneration() + + self.mv_y_q_basic = nn.Parameter(torch.ones((1, channel_mv, 1, 1))) + self.mv_y_q_scale = nn.Parameter(torch.ones((anchor_num, 1, 1, 1))) + self.y_q_basic = nn.Parameter(torch.ones((1, channel_M, 1, 1))) + self.y_q_scale = nn.Parameter(torch.ones((anchor_num, 1, 1, 1))) + self.anchor_num = int(anchor_num) + + self._initialize_weights() + + def multi_scale_feature_extractor(self, dpb): + if dpb["ref_feature"] is None: + feature = self.feature_adaptor_I(dpb["ref_frame"]) + else: + feature = self.feature_adaptor_P(dpb["ref_feature"]) + return self.feature_extractor(feature) + + def motion_compensation(self, dpb, mv): + warpframe = flow_warp(dpb["ref_frame"], mv) + mv2 = bilineardownsacling(mv) / 2 + mv3 = bilineardownsacling(mv2) / 2 + ref_feature1, ref_feature2, ref_feature3 = self.multi_scale_feature_extractor(dpb) + context1 = flow_warp(ref_feature1, mv) + context2 = flow_warp(ref_feature2, mv2) + context3 = flow_warp(ref_feature3, mv3) + context1, context2, context3 = self.context_fusion_net(context1, context2, context3) + return context1, context2, context3, warpframe + + @staticmethod + def get_q_scales_from_ckpt(ckpt_path): + ckpt = get_state_dict(ckpt_path) + y_q_scales = ckpt["y_q_scale"] + mv_y_q_scales = ckpt["mv_y_q_scale"] + return y_q_scales.reshape(-1), mv_y_q_scales.reshape(-1) + + def get_curr_mv_y_q(self, q_scale): + q_basic = LowerBound.apply(self.mv_y_q_basic, 0.5) + return q_basic * q_scale + + def get_curr_y_q(self, q_scale): + q_basic = LowerBound.apply(self.y_q_basic, 0.5) + return q_basic * q_scale + + def compress(self, x, dpb, mv_y_q_scale, y_q_scale): + # pic_width and pic_height may be different from x's size. 
x here is after padding + # x_hat has the same size with x + curr_mv_y_q = self.get_curr_mv_y_q(mv_y_q_scale) + curr_y_q = self.get_curr_y_q(y_q_scale) + + est_mv = self.optic_flow(x, dpb["ref_frame"]) + mv_y = self.mv_encoder(est_mv) + mv_y = mv_y / curr_mv_y_q + mv_z = self.mv_hyper_prior_encoder(mv_y) + mv_z_hat = torch.round(mv_z) + mv_params = self.mv_hyper_prior_decoder(mv_z_hat) + ref_mv_y = dpb["ref_mv_y"] + if ref_mv_y is None: + ref_mv_y = torch.zeros_like(mv_y) + mv_params = torch.cat((mv_params, ref_mv_y), dim=1) + mv_q_step, mv_scales, mv_means = self.mv_y_prior_fusion(mv_params).chunk(3, 1) + mv_y_q_w_0, mv_y_q_w_1, mv_scales_w_0, mv_scales_w_1, mv_y_hat = self.compress_dual_prior( + mv_y, mv_means, mv_scales, mv_q_step, self.mv_y_spatial_prior) + mv_y_hat = mv_y_hat * curr_mv_y_q + + mv_hat = self.mv_decoder(mv_y_hat) + context1, context2, context3, _ = self.motion_compensation(dpb, mv_hat) + + y = self.contextual_encoder(x, context1, context2, context3) + y = y / curr_y_q + z = self.contextual_hyper_prior_encoder(y) + z_hat = torch.round(z) + hierarchical_params = self.contextual_hyper_prior_decoder(z_hat) + temporal_params = self.temporal_prior_encoder(context3) + ref_y = dpb["ref_y"] + if ref_y is None: + ref_y = torch.zeros_like(y) + params = torch.cat((temporal_params, hierarchical_params, ref_y), dim=1) + + q_step, scales, means = self.y_prior_fusion(params).chunk(3, 1) + y_q_w_0, y_q_w_1, scales_w_0, scales_w_1, y_hat = self.compress_dual_prior( + y, means, scales, q_step, self.y_spatial_prior) + y_hat = y_hat * curr_y_q + + recon_image_feature = self.contextual_decoder(y_hat, context2, context3) + feature, x_hat = self.recon_generation_net(recon_image_feature, context1) + + self.entropy_coder.reset_encoder() + _ = self.bit_estimator_z_mv.encode(mv_z_hat) + _ = self.gaussian_encoder.encode(mv_y_q_w_0, mv_scales_w_0) + _ = self.gaussian_encoder.encode(mv_y_q_w_1, mv_scales_w_1) + _ = self.bit_estimator_z.encode(z_hat) + _ = self.gaussian_encoder.encode(y_q_w_0, scales_w_0) + _ = self.gaussian_encoder.encode(y_q_w_1, scales_w_1) + bit_stream = self.entropy_coder.flush_encoder() + + result = { + "dbp": { + "ref_frame": x_hat, + "ref_feature": feature, + "ref_y": y_hat, + "ref_mv_y": mv_y_hat, + }, + "bit_stream": bit_stream, + } + return result + + def decompress(self, dpb, string, height, width, + mv_y_q_scale, y_q_scale): + curr_mv_y_q = self.get_curr_mv_y_q(mv_y_q_scale) + curr_y_q = self.get_curr_y_q(y_q_scale) + + self.entropy_coder.set_stream(string) + device = next(self.parameters()).device + mv_z_size = get_downsampled_shape(height, width, 64) + mv_z_hat = self.bit_estimator_z_mv.decode_stream(mv_z_size) + mv_z_hat = mv_z_hat.to(device) + mv_params = self.mv_hyper_prior_decoder(mv_z_hat) + ref_mv_y = dpb["ref_mv_y"] + if ref_mv_y is None: + _, C, H, W = mv_params.size() + ref_mv_y = torch.zeros((1, C // 2, H, W), device=mv_params.device) + mv_params = torch.cat((mv_params, ref_mv_y), dim=1) + mv_q_step, mv_scales, mv_means = self.mv_y_prior_fusion(mv_params).chunk(3, 1) + mv_y_hat = self.decompress_dual_prior(mv_means, mv_scales, mv_q_step, + self.mv_y_spatial_prior) + mv_y_hat = mv_y_hat * curr_mv_y_q + + mv_hat = self.mv_decoder(mv_y_hat) + context1, context2, context3, _ = self.motion_compensation(dpb, mv_hat) + + z_size = get_downsampled_shape(height, width, 64) + z_hat = self.bit_estimator_z.decode_stream(z_size) + z_hat = z_hat.to(device) + hierarchical_params = self.contextual_hyper_prior_decoder(z_hat) + temporal_params = 
self.temporal_prior_encoder(context3) + ref_y = dpb["ref_y"] + if ref_y is None: + _, C, H, W = temporal_params.size() + ref_y = torch.zeros((1, C // 2, H, W), device=temporal_params.device) + params = torch.cat((temporal_params, hierarchical_params, ref_y), dim=1) + q_step, scales, means = self.y_prior_fusion(params).chunk(3, 1) + y_hat = self.decompress_dual_prior(means, scales, q_step, self.y_spatial_prior) + y_hat = y_hat * curr_y_q + + recon_image_feature = self.contextual_decoder(y_hat, context2, context3) + feature, recon_image = self.recon_generation_net(recon_image_feature, context1) + recon_image = recon_image.clamp(0, 1) + + return { + "dpb": { + "ref_frame": recon_image, + "ref_feature": feature, + "ref_y": y_hat, + "ref_mv_y": mv_y_hat, + }, + } + + def encode_decode(self, x, dpb, output_path=None, + pic_width=None, pic_height=None, + mv_y_q_scale=None, y_q_scale=None): + # pic_width and pic_height may be different from x's size. x here is after padding + # x_hat has the same size with x + if output_path is not None: + mv_y_q_scale, mv_y_q_index = get_rounded_q(mv_y_q_scale) + y_q_scale, y_q_index = get_rounded_q(y_q_scale) + + encoded = self.compress(x, dpb, + mv_y_q_scale, y_q_scale) + encode_p(encoded['bit_stream'], mv_y_q_index, y_q_index, output_path) + bits = filesize(output_path) * 8 + mv_y_q_index, y_q_index, string = decode_p(output_path) + + start = time.time() + decoded = self.decompress(dpb, string, + pic_height, pic_width, + mv_y_q_index / 100, y_q_index / 100) + decoding_time = time.time() - start + result = { + "dpb": decoded["dpb"], + "bit": bits, + "decoding_time": decoding_time, + } + return result + + encoded = self.forward_one_frame(x, dpb, + mv_y_q_scale=mv_y_q_scale, y_q_scale=y_q_scale) + result = { + "dpb": encoded['dpb'], + "bit_y": encoded['bit_y'].item(), + "bit_z": encoded['bit_z'].item(), + "bit_mv_y": encoded['bit_mv_y'].item(), + "bit_mv_z": encoded['bit_mv_z'].item(), + "bit": encoded['bit'].item(), + "decoding_time": 0, + } + return result + + def forward_one_frame(self, x, dpb, mv_y_q_scale=None, y_q_scale=None): + ref_frame = dpb["ref_frame"] + curr_mv_y_q = self.get_curr_mv_y_q(mv_y_q_scale) + curr_y_q = self.get_curr_y_q(y_q_scale) + + est_mv = self.optic_flow(x, ref_frame) + mv_y = self.mv_encoder(est_mv) + mv_y = mv_y / curr_mv_y_q + mv_z = self.mv_hyper_prior_encoder(mv_y) + mv_z_hat = self.quant(mv_z) + mv_params = self.mv_hyper_prior_decoder(mv_z_hat) + ref_mv_y = dpb["ref_mv_y"] + if ref_mv_y is None: + ref_mv_y = torch.zeros_like(mv_y) + mv_params = torch.cat((mv_params, ref_mv_y), dim=1) + mv_q_step, mv_scales, mv_means = self.mv_y_prior_fusion(mv_params).chunk(3, 1) + mv_y_res, mv_y_q, mv_y_hat, mv_scales_hat = self.forward_dual_prior( + mv_y, mv_means, mv_scales, mv_q_step, self.mv_y_spatial_prior) + mv_y_hat = mv_y_hat * curr_mv_y_q + + mv_hat = self.mv_decoder(mv_y_hat) + context1, context2, context3, warp_frame = self.motion_compensation(dpb, mv_hat) + + y = self.contextual_encoder(x, context1, context2, context3) + y = y / curr_y_q + z = self.contextual_hyper_prior_encoder(y) + z_hat = self.quant(z) + hierarchical_params = self.contextual_hyper_prior_decoder(z_hat) + temporal_params = self.temporal_prior_encoder(context3) + + ref_y = dpb["ref_y"] + if ref_y is None: + ref_y = torch.zeros_like(y) + params = torch.cat((temporal_params, hierarchical_params, ref_y), dim=1) + q_step, scales, means = self.y_prior_fusion(params).chunk(3, 1) + y_res, y_q, y_hat, scales_hat = self.forward_dual_prior( + y, means, scales, q_step, 
self.y_spatial_prior) + y_hat = y_hat * curr_y_q + + recon_image_feature = self.contextual_decoder(y_hat, context2, context3) + feature, recon_image = self.recon_generation_net(recon_image_feature, context1) + + B, _, H, W = x.size() + pixel_num = H * W + mse = self.mse(x, recon_image) + ssim = self.ssim(x, recon_image) + me_mse = self.mse(x, warp_frame) + mse = torch.sum(mse, dim=(1, 2, 3)) / pixel_num + me_mse = torch.sum(me_mse, dim=(1, 2, 3)) / pixel_num + + if self.training: + y_for_bit = self.add_noise(y_res) + mv_y_for_bit = self.add_noise(mv_y_res) + z_for_bit = self.add_noise(z) + mv_z_for_bit = self.add_noise(mv_z) + else: + y_for_bit = y_q + mv_y_for_bit = mv_y_q + z_for_bit = z_hat + mv_z_for_bit = mv_z_hat + bits_y = self.get_y_laplace_bits(y_for_bit, scales_hat) + bits_mv_y = self.get_y_laplace_bits(mv_y_for_bit, mv_scales_hat) + bits_z = self.get_z_bits(z_for_bit, self.bit_estimator_z) + bits_mv_z = self.get_z_bits(mv_z_for_bit, self.bit_estimator_z_mv) + + bpp_y = torch.sum(bits_y, dim=(1, 2, 3)) / pixel_num + bpp_z = torch.sum(bits_z, dim=(1, 2, 3)) / pixel_num + bpp_mv_y = torch.sum(bits_mv_y, dim=(1, 2, 3)) / pixel_num + bpp_mv_z = torch.sum(bits_mv_z, dim=(1, 2, 3)) / pixel_num + + bpp = bpp_y + bpp_z + bpp_mv_y + bpp_mv_z + bit = torch.sum(bpp) * pixel_num + bit_y = torch.sum(bpp_y) * pixel_num + bit_z = torch.sum(bpp_z) * pixel_num + bit_mv_y = torch.sum(bpp_mv_y) * pixel_num + bit_mv_z = torch.sum(bpp_mv_z) * pixel_num + + return {"bpp_mv_y": bpp_mv_y, + "bpp_mv_z": bpp_mv_z, + "bpp_y": bpp_y, + "bpp_z": bpp_z, + "bpp": bpp, + "me_mse": me_mse, + "mse": mse, + "ssim": ssim, + "dpb": { + "ref_frame": recon_image, + "ref_feature": feature, + "ref_y": y_hat, + "ref_mv_y": mv_y_hat, + }, + "bit": bit, + "bit_y": bit_y, + "bit_z": bit_z, + "bit_mv_y": bit_mv_y, + "bit_mv_z": bit_mv_z, + } + + def forward(self, x, dpb, mv_y_q_scale=None, y_q_scale=None): + return self.forward_one_frame(x, dpb, + mv_y_q_scale=mv_y_q_scale, y_q_scale=y_q_scale) diff --git a/ACMMM2022/src/models/video_net.py b/ACMMM2022/src/models/video_net.py new file mode 100644 index 0000000..3ab4f6d --- /dev/null +++ b/ACMMM2022/src/models/video_net.py @@ -0,0 +1,276 @@ +import torch +from torch import nn +import torch.nn.functional as F +from torch.autograd import Function + +from ..layers.layers import subpel_conv1x1, conv3x3, \ + ResidualBlock, ResidualBlockWithStride, ResidualBlockUpsample + + +backward_grid = [{} for _ in range(9)] # 0~7 for GPU, -1 for CPU + + +# pylint: disable=W0221 +class LowerBound(Function): + @staticmethod + def forward(ctx, inputs, bound): + b = torch.ones_like(inputs) * bound + ctx.save_for_backward(inputs, b) + return torch.max(inputs, b) + + @staticmethod + def backward(ctx, grad_output): + inputs, b = ctx.saved_tensors + pass_through_1 = inputs >= b + pass_through_2 = grad_output < 0 + + pass_through = pass_through_1 | pass_through_2 + return pass_through.type(grad_output.dtype) * grad_output, None +# pylint: enable=W0221 + + +def torch_warp(feature, flow): + device_id = -1 if feature.device == torch.device('cpu') else feature.device.index + if str(flow.size()) not in backward_grid[device_id]: + N, _, H, W = flow.size() + tensor_hor = torch.linspace(-1.0, 1.0, W, device=feature.device, dtype=feature.dtype).view( + 1, 1, 1, W).expand(N, -1, H, -1) + tensor_ver = torch.linspace(-1.0, 1.0, H, device=feature.device, dtype=feature.dtype).view( + 1, 1, H, 1).expand(N, -1, -1, W) + backward_grid[device_id][str(flow.size())] = torch.cat([tensor_hor, tensor_ver], 1) + + flow = 
torch.cat([flow[:, 0:1, :, :] / ((feature.size(3) - 1.0) / 2.0), + flow[:, 1:2, :, :] / ((feature.size(2) - 1.0) / 2.0)], 1) + + grid = (backward_grid[device_id][str(flow.size())] + flow) + return torch.nn.functional.grid_sample(input=feature, + grid=grid.permute(0, 2, 3, 1), + mode='bilinear', + padding_mode='border', + align_corners=True) + + +def flow_warp(im, flow): + warp = torch_warp(im, flow) + return warp + + +def bilinearupsacling(inputfeature): + inputheight = inputfeature.size()[2] + inputwidth = inputfeature.size()[3] + outfeature = F.interpolate( + inputfeature, (inputheight * 2, inputwidth * 2), mode='bilinear', align_corners=False) + return outfeature + + +def bilineardownsacling(inputfeature): + inputheight = inputfeature.size()[2] + inputwidth = inputfeature.size()[3] + outfeature = F.interpolate( + inputfeature, (inputheight // 2, inputwidth // 2), mode='bilinear', align_corners=False) + return outfeature + + +class ResBlock(nn.Module): + def __init__(self, channel, slope=0.01, start_from_relu=True, end_with_relu=False, + bottleneck=False): + super().__init__() + self.relu = nn.LeakyReLU(negative_slope=slope) + if slope < 0.0001: + self.relu = nn.ReLU() + if bottleneck: + self.conv1 = nn.Conv2d(channel, channel // 2, 3, padding=1) + self.conv2 = nn.Conv2d(channel // 2, channel, 3, padding=1) + else: + self.conv1 = nn.Conv2d(channel, channel, 3, padding=1) + self.conv2 = nn.Conv2d(channel, channel, 3, padding=1) + self.first_layer = self.relu if start_from_relu else nn.Identity() + self.last_layer = self.relu if end_with_relu else nn.Identity() + + def forward(self, x): + out = self.first_layer(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.last_layer(out) + return x + out + + +class MEBasic(nn.Module): + def __init__(self): + super().__init__() + self.relu = nn.ReLU() + self.conv1 = nn.Conv2d(8, 32, 7, 1, padding=3) + self.conv2 = nn.Conv2d(32, 64, 7, 1, padding=3) + self.conv3 = nn.Conv2d(64, 32, 7, 1, padding=3) + self.conv4 = nn.Conv2d(32, 16, 7, 1, padding=3) + self.conv5 = nn.Conv2d(16, 2, 7, 1, padding=3) + + def forward(self, x): + x = self.relu(self.conv1(x)) + x = self.relu(self.conv2(x)) + x = self.relu(self.conv3(x)) + x = self.relu(self.conv4(x)) + x = self.conv5(x) + return x + + +class ME_Spynet(nn.Module): + def __init__(self): + super().__init__() + self.L = 4 + self.moduleBasic = torch.nn.ModuleList([MEBasic() for _ in range(self.L)]) + + def forward(self, im1, im2): + batchsize = im1.size()[0] + im1_pre = im1 + im2_pre = im2 + + im1_list = [im1_pre] + im2_list = [im2_pre] + for level in range(self.L - 1): + im1_list.append(F.avg_pool2d(im1_list[level], kernel_size=2, stride=2)) + im2_list.append(F.avg_pool2d(im2_list[level], kernel_size=2, stride=2)) + + shape_fine = im2_list[self.L - 1].size() + zero_shape = [batchsize, 2, shape_fine[2] // 2, shape_fine[3] // 2] + flow = torch.zeros(zero_shape, dtype=im1.dtype, device=im1.device) + for level in range(self.L): + flow_up = bilinearupsacling(flow) * 2.0 + img_index = self.L - 1 - level + flow = flow_up + \ + self.moduleBasic[level](torch.cat([im1_list[img_index], + flow_warp(im2_list[img_index], flow_up), + flow_up], 1)) + + return flow + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=16): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + y = torch.mean(x, dim=(-1, 
-2)) + y = self.fc(y) + return x * y[:, :, None, None] + + +class ConvBlockResidual(nn.Module): + def __init__(self, ch_in, ch_out, se_layer=True): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.01), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1), + SELayer(ch_out) if se_layer else nn.Identity(), + ) + self.up_dim = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + x1 = self.conv(x) + x2 = self.up_dim(x) + return x2 + x1 + + +class UNet(nn.Module): + def __init__(self, in_ch=64, out_ch=64): + super().__init__() + self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) + + self.conv1 = ConvBlockResidual(ch_in=in_ch, ch_out=32) + self.conv2 = ConvBlockResidual(ch_in=32, ch_out=64) + self.conv3 = ConvBlockResidual(ch_in=64, ch_out=128) + + self.context_refine = nn.Sequential( + ResBlock(128, 0), + ResBlock(128, 0), + ResBlock(128, 0), + ResBlock(128, 0), + ) + + self.up3 = subpel_conv1x1(128, 64, 2) + self.up_conv3 = ConvBlockResidual(ch_in=128, ch_out=64) + + self.up2 = subpel_conv1x1(64, 32, 2) + self.up_conv2 = ConvBlockResidual(ch_in=64, ch_out=out_ch) + + def forward(self, x): + # encoding path + x1 = self.conv1(x) + x2 = self.max_pool(x1) + + x2 = self.conv2(x2) + x3 = self.max_pool(x2) + + x3 = self.conv3(x3) + x3 = self.context_refine(x3) + + # decoding + concat path + d3 = self.up3(x3) + d3 = torch.cat((x2, d3), dim=1) + d3 = self.up_conv3(d3) + + d2 = self.up2(d3) + d2 = torch.cat((x1, d2), dim=1) + d2 = self.up_conv2(d2) + return d2 + + +def get_enc_dec_models(input_channel, output_channel, channel): + enc = nn.Sequential( + ResidualBlockWithStride(input_channel, channel, stride=2), + ResidualBlock(channel, channel), + ResidualBlockWithStride(channel, channel, stride=2), + ResidualBlock(channel, channel), + ResidualBlockWithStride(channel, channel, stride=2), + ResidualBlock(channel, channel), + conv3x3(channel, channel, stride=2), + ) + + dec = nn.Sequential( + ResidualBlock(channel, channel), + ResidualBlockUpsample(channel, channel, 2), + ResidualBlock(channel, channel), + ResidualBlockUpsample(channel, channel, 2), + ResidualBlock(channel, channel), + ResidualBlockUpsample(channel, channel, 2), + ResidualBlock(channel, channel), + subpel_conv1x1(channel, output_channel, 2), + ) + + return enc, dec + + +def get_hyper_enc_dec_models(y_channel, z_channel): + enc = nn.Sequential( + conv3x3(y_channel, z_channel), + nn.LeakyReLU(), + conv3x3(z_channel, z_channel), + nn.LeakyReLU(), + conv3x3(z_channel, z_channel, stride=2), + nn.LeakyReLU(), + conv3x3(z_channel, z_channel), + nn.LeakyReLU(), + conv3x3(z_channel, z_channel, stride=2), + ) + + dec = nn.Sequential( + conv3x3(z_channel, y_channel), + nn.LeakyReLU(), + subpel_conv1x1(y_channel, y_channel, 2), + nn.LeakyReLU(), + conv3x3(y_channel, y_channel * 3 // 2), + nn.LeakyReLU(), + subpel_conv1x1(y_channel * 3 // 2, y_channel * 3 // 2, 2), + nn.LeakyReLU(), + conv3x3(y_channel * 3 // 2, y_channel * 2), + ) + + return enc, dec diff --git a/ACMMM2022/src/utils/common.py b/ACMMM2022/src/utils/common.py new file mode 100644 index 0000000..4582755 --- /dev/null +++ b/ACMMM2022/src/utils/common.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
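A hedged usage sketch (not part of the patch) for the rate-point helpers defined below in this file; it assumes the `ACMMM2022` folder is the working directory, as when running `test_video.py`, and the numeric values are examples only:

```python
from src.utils.common import interpolate_log, scale_list_to_str

# Four log-spaced q_scale values between a minimum and a maximum anchor value,
# returned in descending order (decending=True is the default).
q_scales = interpolate_log(0.5, 2.0, 4)
print(scale_list_to_str(q_scales))  # -> "2.00 1.26 0.79 0.50 "
```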
+ +import argparse +import json +import os +from unittest.mock import patch + +import torch +import numpy as np + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + if v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def interpolate_log(min_val, max_val, num, decending=True): + assert max_val > min_val + assert min_val > 0 + if decending: + values = np.linspace(np.log(max_val), np.log(min_val), num) + else: + values = np.linspace(np.log(min_val), np.log(max_val), num) + values = np.exp(values) + return values + + +def scale_list_to_str(scales): + s = '' + for scale in scales: + s += f'{scale:.2f} ' + + return s + + +def create_folder(path, print_if_create=False): + if not os.path.exists(path): + os.makedirs(path) + if print_if_create: + print(f"created folder: {path}") + + +@patch('json.encoder.c_make_encoder', None) +def dump_json(obj, fid, float_digits=-1, **kwargs): + of = json.encoder._make_iterencode # pylint: disable=W0212 + + def inner(*args, **kwargs): + args = list(args) + # fifth argument is float formater which we will replace + args[4] = lambda o: format(o, '.%df' % float_digits) + return of(*args, **kwargs) + + with patch('json.encoder._make_iterencode', wraps=inner): + json.dump(obj, fid, **kwargs) + + +def generate_log_json(frame_num, frame_types, bits, psnrs, ssims, + frame_pixel_num, test_time): + cur_ave_i_frame_bit = 0 + cur_ave_i_frame_psnr = 0 + cur_ave_i_frame_msssim = 0 + cur_ave_p_frame_bit = 0 + cur_ave_p_frame_psnr = 0 + cur_ave_p_frame_msssim = 0 + cur_i_frame_num = 0 + cur_p_frame_num = 0 + for idx in range(frame_num): + if frame_types[idx] == 0: + cur_ave_i_frame_bit += bits[idx] + cur_ave_i_frame_psnr += psnrs[idx] + cur_ave_i_frame_msssim += ssims[idx] + cur_i_frame_num += 1 + else: + cur_ave_p_frame_bit += bits[idx] + cur_ave_p_frame_psnr += psnrs[idx] + cur_ave_p_frame_msssim += ssims[idx] + cur_p_frame_num += 1 + + log_result = {} + log_result['frame_pixel_num'] = frame_pixel_num + log_result['i_frame_num'] = cur_i_frame_num + log_result['p_frame_num'] = cur_p_frame_num + log_result['ave_i_frame_bpp'] = cur_ave_i_frame_bit / cur_i_frame_num / frame_pixel_num + log_result['ave_i_frame_psnr'] = cur_ave_i_frame_psnr / cur_i_frame_num + log_result['ave_i_frame_msssim'] = cur_ave_i_frame_msssim / cur_i_frame_num + log_result['frame_bpp'] = list(np.array(bits) / frame_pixel_num) + log_result['frame_psnr'] = psnrs + log_result['frame_msssim'] = ssims + log_result['frame_type'] = frame_types + log_result['test_time'] = test_time + if cur_p_frame_num > 0: + total_p_pixel_num = cur_p_frame_num * frame_pixel_num + log_result['ave_p_frame_bpp'] = cur_ave_p_frame_bit / total_p_pixel_num + log_result['ave_p_frame_psnr'] = cur_ave_p_frame_psnr / cur_p_frame_num + log_result['ave_p_frame_msssim'] = cur_ave_p_frame_msssim / cur_p_frame_num + else: + log_result['ave_p_frame_bpp'] = 0 + log_result['ave_p_frame_psnr'] = 0 + log_result['ave_p_frame_msssim'] = 0 + log_result['ave_all_frame_bpp'] = (cur_ave_i_frame_bit + cur_ave_p_frame_bit) / \ + (frame_num * frame_pixel_num) + log_result['ave_all_frame_psnr'] = (cur_ave_i_frame_psnr + cur_ave_p_frame_psnr) / frame_num + log_result['ave_all_frame_msssim'] = (cur_ave_i_frame_msssim + cur_ave_p_frame_msssim) / \ + frame_num + + return log_result diff --git a/ACMMM2022/src/utils/png_reader.py b/ACMMM2022/src/utils/png_reader.py new file mode 100644 index 0000000..8664c72 --- 
/dev/null +++ b/ACMMM2022/src/utils/png_reader.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os + +import numpy as np +from PIL import Image + + +class PNGReader(): + def __init__(self, src_folder, width, height): + self.src_folder = src_folder + pngs = os.listdir(self.src_folder) + self.width = width + self.height = height + if 'im1.png' in pngs: + self.padding = 1 + elif 'im00001.png' in pngs: + self.padding = 5 + else: + raise ValueError('unknown image naming convention; please specify') + self.current_frame_index = 1 + self.eof = False + + def read_one_frame(self, src_format="rgb"): + def _none_exist_frame(): + if src_format == "rgb": + return None + return None, None, None + if self.eof: + return _none_exist_frame() + + png_path = os.path.join(self.src_folder, + f"im{str(self.current_frame_index).zfill(self.padding)}.png" + ) + if not os.path.exists(png_path): + self.eof = True + return _none_exist_frame() + + rgb = Image.open(png_path).convert('RGB') + rgb = np.asarray(rgb).astype('float32').transpose(2, 0, 1) + rgb = rgb / 255. + _, height, width = rgb.shape + assert height == self.height + assert width == self.width + + self.current_frame_index += 1 + return rgb + + def close(self): + self.current_frame_index = 1 diff --git a/ACMMM2022/src/utils/stream_helper.py b/ACMMM2022/src/utils/stream_helper.py new file mode 100644 index 0000000..56509ef --- /dev/null +++ b/ACMMM2022/src/utils/stream_helper.py @@ -0,0 +1,143 @@ +# Copyright 2020 InterDigital Communications, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import struct +from pathlib import Path + +import numpy as np +import torch +from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present + + +def get_padding_size(height, width, p=64): + new_h = (height + p - 1) // p * p + new_w = (width + p - 1) // p * p + # padding_left = (new_w - width) // 2 + padding_left = 0 + padding_right = new_w - width - padding_left + # padding_top = (new_h - height) // 2 + padding_top = 0 + padding_bottom = new_h - height - padding_top + return padding_left, padding_right, padding_top, padding_bottom + + +def get_downsampled_shape(height, width, p): + new_h = (height + p - 1) // p * p + new_w = (width + p - 1) // p * p + return int(new_h / p + 0.5), int(new_w / p + 0.5) + + +def get_rounded_q(q_scale): + q_scale = np.clip(q_scale, 0.01, 655.) 
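+    # the rounded index below (q_scale * 100, at most 65500) fits the unsigned 16-bit field written by write_ushorts in encode_i / encode_p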
+ q_index = int(np.round(q_scale * 100)) + q_scale = q_index / 100 + return q_scale, q_index + + +def get_state_dict(ckpt_path): + ckpt = torch.load(ckpt_path, map_location=torch.device('cpu')) + if "state_dict" in ckpt: + ckpt = ckpt['state_dict'] + if "net" in ckpt: + ckpt = ckpt["net"] + consume_prefix_in_state_dict_if_present(ckpt, prefix="module.") + return ckpt + + +def filesize(filepath: str) -> int: + if not Path(filepath).is_file(): + raise ValueError(f'Invalid file "{filepath}".') + return Path(filepath).stat().st_size + + +def write_uints(fd, values, fmt=">{:d}I"): + fd.write(struct.pack(fmt.format(len(values)), *values)) + + +def write_uchars(fd, values, fmt=">{:d}B"): + fd.write(struct.pack(fmt.format(len(values)), *values)) + + +def read_uints(fd, n, fmt=">{:d}I"): + sz = struct.calcsize("I") + return struct.unpack(fmt.format(n), fd.read(n * sz)) + + +def read_uchars(fd, n, fmt=">{:d}B"): + sz = struct.calcsize("B") + return struct.unpack(fmt.format(n), fd.read(n * sz)) + + +def write_bytes(fd, values, fmt=">{:d}s"): + if len(values) == 0: + return + fd.write(struct.pack(fmt.format(len(values)), values)) + + +def read_bytes(fd, n, fmt=">{:d}s"): + sz = struct.calcsize("s") + return struct.unpack(fmt.format(n), fd.read(n * sz))[0] + + +def write_ushorts(fd, values, fmt=">{:d}H"): + fd.write(struct.pack(fmt.format(len(values)), *values)) + + +def read_ushorts(fd, n, fmt=">{:d}H"): + sz = struct.calcsize("H") + return struct.unpack(fmt.format(n), fd.read(n * sz)) + + +def encode_i(height, width, q_index, bit_stream, output): + with Path(output).open("wb") as f: + stream_length = len(bit_stream) + + write_uints(f, (height, width)) + write_ushorts(f, (q_index,)) + write_uints(f, (stream_length,)) + write_bytes(f, bit_stream) + + +def decode_i(inputpath): + with Path(inputpath).open("rb") as f: + header = read_uints(f, 2) + height = header[0] + width = header[1] + q_index = read_ushorts(f, 1)[0] + stream_length = read_uints(f, 1)[0] + + bit_stream = read_bytes(f, stream_length) + + return height, width, q_index, bit_stream + + +def encode_p(string, mv_y_q_index, y_q_index, output): + with Path(output).open("wb") as f: + string_length = len(string) + write_ushorts(f, (mv_y_q_index, y_q_index)) + write_uints(f, (string_length,)) + write_bytes(f, string) + + +def decode_p(inputpath): + with Path(inputpath).open("rb") as f: + header = read_ushorts(f, 2) + mv_y_q_index = header[0] + y_q_index = header[1] + + header = read_uints(f, 1) + string_length = header[0] + string = read_bytes(f, string_length) + + return mv_y_q_index, y_q_index, string diff --git a/ACMMM2022/test_video.py b/ACMMM2022/test_video.py new file mode 100644 index 0000000..15577ea --- /dev/null +++ b/ACMMM2022/test_video.py @@ -0,0 +1,406 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
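A small illustrative sketch (not part of the patch) of the pad-to-64 / crop-back convention that `run_test` below applies around the networks; the 1080p size is just an example and the `src.*` import assumes the `ACMMM2022` folder is the working directory:

```python
import torch
import torch.nn.functional as F

from src.utils.stream_helper import get_padding_size

# 1080 is not a multiple of 64, so 8 rows are padded at the bottom; 1920 already is.
x = torch.rand(1, 3, 1080, 1920)
pl, pr, pt, pb = get_padding_size(1080, 1920)       # -> (0, 0, 0, 8)
x_padded = F.pad(x, (pl, pr, pt, pb), mode="constant", value=0)
x_cropped = F.pad(x_padded, (-pl, -pr, -pt, -pb))   # negative padding crops back
assert x_cropped.shape == x.shape
```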
+ +import argparse +import os +import concurrent.futures +import json +import multiprocessing +import time + +import torch +import torch.nn.functional as F +import numpy as np +from PIL import Image +from src.models.video_model import DMC +from src.models.image_model import IntraNoAR +from src.utils.common import str2bool, interpolate_log, create_folder, generate_log_json, dump_json +from src.utils.stream_helper import get_padding_size, get_state_dict +from src.utils.png_reader import PNGReader +from tqdm import tqdm +from pytorch_msssim import ms_ssim + + +def parse_args(): + parser = argparse.ArgumentParser(description="Example testing script") + + parser.add_argument('--i_frame_model_path', type=str) + parser.add_argument('--i_frame_q_scales', type=float, nargs="+") + parser.add_argument("--force_intra", type=str2bool, nargs='?', const=True, default=False) + parser.add_argument("--force_frame_num", type=int, default=-1) + parser.add_argument("--force_intra_period", type=int, default=-1) + parser.add_argument('--model_path', type=str) + parser.add_argument('--p_frame_y_q_scales', type=float, nargs="+") + parser.add_argument('--p_frame_mv_y_q_scales', type=float, nargs="+") + parser.add_argument('--rate_num', type=int, default=4) + parser.add_argument('--test_config', type=str, required=True) + parser.add_argument('--force_root_path', type=str, default=None, required=False) + parser.add_argument("--worker", "-w", type=int, default=1, help="worker number") + parser.add_argument("--cuda", type=str2bool, nargs='?', const=True, default=False) + parser.add_argument("--cuda_device", default=None, + help="the cuda device used, e.g., 0; 0,1; 1,2,3; etc.") + parser.add_argument('--write_stream', type=str2bool, nargs='?', + const=True, default=False) + parser.add_argument('--stream_path', type=str, default="out_bin") + parser.add_argument('--save_decoded_frame', type=str2bool, default=False) + parser.add_argument('--decoded_frame_path', type=str, default='decoded_frames') + parser.add_argument('--output_path', type=str, required=True) + parser.add_argument('--verbose', type=int, default=0) + + args = parser.parse_args() + return args + + +def read_image_to_torch(path): + input_image = Image.open(path).convert('RGB') + input_image = np.asarray(input_image).astype('float64').transpose(2, 0, 1) + input_image = torch.from_numpy(input_image).type(torch.FloatTensor) + input_image = input_image.unsqueeze(0)/255 + return input_image + + +def np_image_to_tensor(img): + image = torch.from_numpy(img).type(torch.FloatTensor) + image = image.unsqueeze(0) + return image + + +def save_torch_image(img, save_path): + img = img.squeeze(0).permute(1, 2, 0).detach().cpu().numpy() + img = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8) + Image.fromarray(img).save(save_path) + + +def PSNR(input1, input2): + mse = torch.mean((input1 - input2) ** 2) + psnr = 20 * torch.log10(1 / torch.sqrt(mse)) + return psnr.item() + + +def run_test(video_net, i_frame_net, args, device): + frame_num = args['frame_num'] + gop_size = args['gop_size'] + write_stream = 'write_stream' in args and args['write_stream'] + save_decoded_frame = 'save_decoded_frame' in args and args['save_decoded_frame'] + verbose = args['verbose'] if 'verbose' in args else 0 + + if args['src_type'] == 'png': + src_reader = PNGReader(args['img_path'], args['src_width'], args['src_height']) + + frame_types = [] + psnrs = [] + msssims = [] + bits = [] + frame_pixel_num = 0 + + start_time = time.time() + p_frame_number = 0 + overall_p_decoding_time = 0 + with 
torch.no_grad(): + for frame_idx in range(frame_num): + frame_start_time = time.time() + rgb = src_reader.read_one_frame(src_format="rgb") + x = np_image_to_tensor(rgb) + x = x.to(device) + pic_height = x.shape[2] + pic_width = x.shape[3] + + if frame_pixel_num == 0: + frame_pixel_num = x.shape[2] * x.shape[3] + else: + assert frame_pixel_num == x.shape[2] * x.shape[3] + + # pad if necessary + padding_l, padding_r, padding_t, padding_b = get_padding_size(pic_height, pic_width) + x_padded = torch.nn.functional.pad( + x, + (padding_l, padding_r, padding_t, padding_b), + mode="constant", + value=0, + ) + + bin_path = os.path.join(args['bin_folder'], f"{frame_idx}.bin") \ + if write_stream else None + + if frame_idx % gop_size == 0: + result = i_frame_net.encode_decode(x_padded, args['i_frame_q_scale'], bin_path, + pic_height=pic_height, pic_width=pic_width) + dpb = { + "ref_frame": result["x_hat"], + "ref_feature": None, + "ref_y": None, + "ref_mv_y": None, + } + recon_frame = result["x_hat"] + frame_types.append(0) + bits.append(result["bit"]) + else: + result = video_net.encode_decode(x_padded, dpb, bin_path, + pic_height=pic_height, pic_width=pic_width, + mv_y_q_scale=args['p_frame_mv_y_q_scale'], + y_q_scale=args['p_frame_y_q_scale']) + dpb = result["dpb"] + recon_frame = dpb["ref_frame"] + frame_types.append(1) + bits.append(result['bit']) + p_frame_number += 1 + overall_p_decoding_time += result['decoding_time'] + + recon_frame = recon_frame.clamp_(0, 1) + x_hat = F.pad(recon_frame, (-padding_l, -padding_r, -padding_t, -padding_b)) + psnr = PSNR(x_hat, x) + msssim = ms_ssim(x_hat, x, data_range=1).item() + psnrs.append(psnr) + msssims.append(msssim) + frame_end_time = time.time() + + if verbose >= 2: + print(f"frame {frame_idx}, {frame_end_time - frame_start_time:.3f} seconds,", + f"bits: {bits[-1]:.3f}, PSNR: {psnrs[-1]:.4f}, MS-SSIM: {msssims[-1]:.4f} ") + + if save_decoded_frame: + save_path = os.path.join(args['decoded_frame_folder'], f'{frame_idx}.png') + save_torch_image(x_hat, save_path) + + test_time = time.time() - start_time + if verbose >= 1 and p_frame_number > 0: + print(f"decoding {p_frame_number} P frames, " + f"average {overall_p_decoding_time/p_frame_number * 1000:.0f} ms.") + + log_result = generate_log_json(frame_num, frame_types, bits, psnrs, msssims, + frame_pixel_num, test_time) + return log_result + + +def encode_one(args, device): + i_state_dict = get_state_dict(args['i_frame_model_path']) + i_frame_net = IntraNoAR() + i_frame_net.load_state_dict(i_state_dict) + i_frame_net = i_frame_net.to(device) + i_frame_net.eval() + + if args['force_intra']: + video_net = None + else: + p_state_dict = get_state_dict(args['model_path']) + video_net = DMC() + video_net.load_state_dict(p_state_dict) + video_net = video_net.to(device) + video_net.eval() + + if args['write_stream']: + if video_net is not None: + video_net.update(force=True) + i_frame_net.update(force=True) + + sub_dir_name = args['video_path'] + gop_size = args['gop'] + frame_num = args['frame_num'] + + bin_folder = os.path.join(args['stream_path'], sub_dir_name, str(args['rate_idx'])) + if args['write_stream']: + create_folder(bin_folder, True) + + if args['save_decoded_frame']: + decoded_frame_folder = os.path.join(args['decoded_frame_path'], sub_dir_name, + str(args['rate_idx'])) + create_folder(decoded_frame_folder) + else: + decoded_frame_folder = None + + args['img_path'] = os.path.join(args['dataset_path'], sub_dir_name) + args['gop_size'] = gop_size + args['frame_num'] = frame_num + args['bin_folder'] = 
bin_folder + args['decoded_frame_folder'] = decoded_frame_folder + + result = run_test(video_net, i_frame_net, args, device=device) + + result['ds_name'] = args['ds_name'] + result['video_path'] = args['video_path'] + result['rate_idx'] = args['rate_idx'] + + return result + + +def worker(use_cuda, args): + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + torch.manual_seed(0) + torch.set_num_threads(1) + np.random.seed(seed=0) + gpu_num = 0 + if use_cuda: + gpu_num = torch.cuda.device_count() + + process_name = multiprocessing.current_process().name + process_idx = int(process_name[process_name.rfind('-') + 1:]) + gpu_id = -1 + if gpu_num > 0: + gpu_id = process_idx % gpu_num + if gpu_id >= 0: + device = f"cuda:{gpu_id}" + else: + device = "cpu" + + result = encode_one(args, device) + return result + + +def main(): + begin_time = time.time() + + torch.backends.cudnn.enabled = True + args = parse_args() + + if args.cuda_device is not None and args.cuda_device != '': + os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_device + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + + worker_num = args.worker + assert worker_num >= 1 + + with open(args.test_config) as f: + config = json.load(f) + + multiprocessing.set_start_method("spawn") + threadpool_executor = concurrent.futures.ProcessPoolExecutor(max_workers=worker_num) + objs = [] + + count_frames = 0 + count_sequences = 0 + + rate_num = args.rate_num + i_frame_q_scales = IntraNoAR.get_q_scales_from_ckpt(args.i_frame_model_path) + print("q_scales in intra ckpt: ", end='') + for q in i_frame_q_scales: + print(f"{q:.3f}, ", end='') + print() + if args.i_frame_q_scales is not None: + assert len(args.i_frame_q_scales) == rate_num + i_frame_q_scales = args.i_frame_q_scales + print(f"testing {rate_num} rate points with pre-defined intra y q_scales: ", end='') + elif len(i_frame_q_scales) == rate_num: + print(f"testing {rate_num} rate points with intra y q_scales in ckpt: ", end='') + else: + max_q_scale = i_frame_q_scales[0] + min_q_scale = i_frame_q_scales[-1] + i_frame_q_scales = interpolate_log(min_q_scale, max_q_scale, rate_num) + print(f"testing {rate_num} rates, using intra y q_scales: ", end='') + + for q in i_frame_q_scales: + print(f"{q:.3f}, ", end='') + print() + + if not args.force_intra: + p_frame_y_q_scales, p_frame_mv_y_q_scales = DMC.get_q_scales_from_ckpt(args.model_path) + print("y_q_scales in inter ckpt: ", end='') + for q in p_frame_y_q_scales: + print(f"{q:.3f}, ", end='') + print() + print("mv_y_q_scales in inter ckpt: ", end='') + for q in p_frame_mv_y_q_scales: + print(f"{q:.3f}, ", end='') + print() + if args.p_frame_y_q_scales is not None: + assert len(args.p_frame_y_q_scales) == rate_num + assert len(args.p_frame_mv_y_q_scales) == rate_num + p_frame_y_q_scales = args.p_frame_y_q_scales + p_frame_mv_y_q_scales = args.p_frame_mv_y_q_scales + print(f"testing {rate_num} rate points with pre-defined inter q_scales") + elif len(p_frame_y_q_scales) == rate_num: + print(f"testing {rate_num} rate points with inter q_scales in ckpt") + else: + max_y_q_scale = p_frame_y_q_scales[0] + min_y_q_scale = p_frame_y_q_scales[-1] + p_frame_y_q_scales = interpolate_log(min_y_q_scale, max_y_q_scale, rate_num) + + max_mv_y_q_scale = p_frame_mv_y_q_scales[0] + min_mv_y_q_scale = p_frame_mv_y_q_scales[-1] + p_frame_mv_y_q_scales = interpolate_log(min_mv_y_q_scale, max_mv_y_q_scale, rate_num) + print("y_q_scales for testing: ", end='') + for q in p_frame_y_q_scales: + print(f"{q:.3f}, ", end='') + print() + 
print("mv_y_q_scales for testing: ", end='') + for q in p_frame_mv_y_q_scales: + print(f"{q:.3f}, ", end='') + print() + + root_path = args.force_root_path if args.force_root_path is not None else config['root_path'] + config = config['test_classes'] + for ds_name in config: + if config[ds_name]['test'] == 0: + continue + for seq_name in config[ds_name]['sequences']: + count_sequences += 1 + for rate_idx in range(rate_num): + cur_args = {} + cur_args['rate_idx'] = rate_idx + cur_args['i_frame_model_path'] = args.i_frame_model_path + cur_args['i_frame_q_scale'] = i_frame_q_scales[rate_idx] + if not args.force_intra: + cur_args['model_path'] = args.model_path + cur_args['p_frame_y_q_scale'] = p_frame_y_q_scales[rate_idx] + cur_args['p_frame_mv_y_q_scale'] = p_frame_mv_y_q_scales[rate_idx] + cur_args['force_intra'] = args.force_intra + cur_args['video_path'] = seq_name + cur_args['src_type'] = config[ds_name]['src_type'] + cur_args['src_height'] = config[ds_name]['sequences'][seq_name]['height'] + cur_args['src_width'] = config[ds_name]['sequences'][seq_name]['width'] + cur_args['gop'] = config[ds_name]['sequences'][seq_name]['gop'] + if args.force_intra: + cur_args['gop'] = 1 + if args.force_intra_period > 0: + cur_args['gop'] = args.force_intra_period + cur_args['frame_num'] = config[ds_name]['sequences'][seq_name]['frames'] + if args.force_frame_num > 0: + cur_args['frame_num'] = args.force_frame_num + cur_args['dataset_path'] = os.path.join(root_path, config[ds_name]['base_path']) + cur_args['write_stream'] = args.write_stream + cur_args['stream_path'] = args.stream_path + cur_args['save_decoded_frame'] = args.save_decoded_frame + cur_args['decoded_frame_path'] = f'{args.decoded_frame_path}_DMC_{rate_idx}' + cur_args['ds_name'] = ds_name + cur_args['verbose'] = args.verbose + + count_frames += cur_args['frame_num'] + + obj = threadpool_executor.submit( + worker, + args.cuda, + cur_args) + objs.append(obj) + + results = [] + for obj in tqdm(objs): + result = obj.result() + results.append(result) + + log_result = {} + for ds_name in config: + if config[ds_name]['test'] == 0: + continue + log_result[ds_name] = {} + for seq in config[ds_name]['sequences']: + log_result[ds_name][seq] = {} + for rate in range(rate_num): + for res in results: + if res['rate_idx'] == rate and ds_name == res['ds_name'] \ + and seq == res['video_path']: + log_result[ds_name][seq][f"{rate:03d}"] = res + + out_json_dir = os.path.dirname(args.output_path) + if len(out_json_dir) > 0: + create_folder(out_json_dir, True) + with open(args.output_path, 'w') as fp: + dump_json(log_result, fp, float_digits=6, indent=2) + + total_minutes = (time.time() - begin_time) / 60 + print('Test finished') + print(f'Tested {count_frames} frames from {count_sequences} sequences') + print(f'Total elapsed time: {total_minutes:.1f} min') + + +if __name__ == "__main__": + main() diff --git a/README.md b/README.md index 157ddfa..1645e43 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,26 @@ # Introduction Official Pytorch implementation for Neural Video Compression including: -* [Deep Contextual Video Compression](https://proceedings.neurips.cc/paper/2021/file/96b250a90d3cf0868c83f8c965142d2a-Paper.pdf), NeurIPS 2021, in [this foloder](./NeurIPS2021/). 
-* [Hybrid Spatial-Temporal Entropy Modelling for Neural Video Compression], ACM MM 2022, in [this folder](./ACMMM2022/) +* [Deep Contextual Video Compression](https://proceedings.neurips.cc/paper/2021/file/96b250a90d3cf0868c83f8c965142d2a-Paper.pdf), NeurIPS 2021, in [this folder](./NeurIPS2021/). +* [Hybrid Spatial-Temporal Entropy Modelling for Neural Video Compression](https://arxiv.org/abs/2207.05894), ACM MM 2022, in [this folder](./ACMMM2022/). + - The first end-to-end neural video codec to exceed H.266 (VTM) using the highest compression ratio configuration, in terms of both PSNR and MS-SSIM. + - The first end-to-end neural video codec to achieve rate adjustment in single model. + +# On the comparison + +Please note that different methods may use different configurations to test different models, such as +* Source video may be different, e.g., cropped or padded to the desired resolution. +* Intra period may be different, e.g., 32, 12, or 10. +* Number of encoded frames may be different. + +So, it does not make sense to compare the numbers in different methods directly, unless making sure they are using same test conditions. + +# Command line to generate VTM results + +Get VTM from https://vcgit.hhi.fraunhofer.de/jvet/VVCSoftware_VTM and build the project. +```bash +EncoderApp -c encoder_lowdelay_vtm.cfg --InputFile={input file name} --BitstreamFile={bitstream file name} --DecodingRefreshType=2 --InputBitDepth=8 --OutputBitDepth=8 --OutputBitDepthC=8 --InputChromaFormat=444 --FrameRate={frame rate} --FramesToBeEncoded={frame number} --SourceWidth={width} --SourceHeight={height} --IntraPeriod=32 --QP={qp} --Level=6.2 +``` # Acknowledgement The implementation is based on [CompressAI](https://github.com/InterDigitalInc/CompressAI) and [PyTorchVideoCompression](https://github.com/ZhihaoHu/PyTorchVideoCompression). @@ -18,6 +36,13 @@ If you find this work useful for your research, please cite: volume={34}, year={2021} } + +@inproceedings{li2022hybrid, + title={Hybrid Spatial-Temporal Entropy Modelling for Neural Video Compression}, + author={Li, Jiahao and Li, Bin and Lu, Yan}, + booktitle={Proceedings of the 30th ACM International Conference on Multimedia}, + year={2022} +} ``` # Trademarks