diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5ac5364 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +*.pyc +.vscode/ +*.bin +*.png +*.so +build/ \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..f9ba8cf --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,9 @@ +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..ebf23ac --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,14 @@ +# Contributing + +This project welcomes contributions and suggestions. Most contributions require you to +agree to a Contributor License Agreement (CLA) declaring that you have the right to, +and actually do, grant us the rights to use your contribution. For details, visit +https://cla.microsoft.com. + +When you submit a pull request, a CLA-bot will automatically determine whether you need +to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the +instructions provided by the bot. You will only need to do this once across all repositories using our CLA. + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). +For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..b2f52a2 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,21 @@ +Copyright (c) Microsoft Corporation. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/NOTICE .txt b/NOTICE .txt new file mode 100644 index 0000000..27a0897 --- /dev/null +++ b/NOTICE .txt @@ -0,0 +1,237 @@ +The implementation is based on [CompressAI](https://github.com/InterDigitalInc/CompressAI). +The license when we obtained the code was Apache License, Version 2.0. 
+ Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+
+The current license is BSD 3-Clause Clear License.
+
+Copyright (c) 2021-2022, InterDigital Communications, Inc
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted (subject to the limitations in the disclaimer
+below) provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer.
+* Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+* Neither the name of InterDigital Communications, Inc nor the names of its
+contributors may be used to endorse or promote products derived from this
+software without specific prior written permission.
+
+NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
+THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
+NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER
+OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+[PyTorchVideoCompression](https://github.com/ZhihaoHu/PyTorchVideoCompression).
+There is no explicit license in this repo.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f455619
--- /dev/null
+++ b/README.md
@@ -0,0 +1,99 @@
+# Introduction
+
+Official PyTorch implementation for [Deep Contextual Video Compression](https://proceedings.neurips.cc/paper/2021/file/96b250a90d3cf0868c83f8c965142d2a-Paper.pdf), NeurIPS 2021
+
+# Prerequisites
+* Python 3.8 and conda, get [Conda](https://www.anaconda.com/)
+* CUDA 11.0
+* Environment
+    ```
+    conda create -n $YOUR_PY38_ENV_NAME python=3.8
+    conda activate $YOUR_PY38_ENV_NAME
+
+    pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+    python -m pip install -r requirements.txt
+    ```
+
+# Test dataset
+Currently, the spatial resolution of the video needs to be cropped to an integral multiple of 64.
+
+The dataset format can be seen in dataset_config_example.json.
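Cropping "to a multiple of 64" just floors each dimension; a tiny helper (hypothetical, not shipped with the repo) makes the arithmetic explicit and reproduces the 1920x1080 to 1920x1024 crop in the ffmpeg example below:

```python
def crop_to_multiple_of_64(width: int, height: int) -> tuple:
    """Hypothetical helper: floor both dimensions to a multiple of 64."""
    return (width // 64) * 64, (height // 64) * 64

# 1080p floors to 1920x1024, matching the ffmpeg crop used below
assert crop_to_multiple_of_64(1920, 1080) == (1920, 1024)
```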
+
+For example, one video of HEVC Class B can be prepared as:
+* Crop the original YUV via ffmpeg:
+    ```
+    ffmpeg -pix_fmt yuv420p -s 1920x1080 -i BasketballDrive_1920x1080_50.yuv -vf crop=1920:1024:0:0 BasketballDrive_1920x1024_50.yuv
+    ```
+* Make the video folder:
+    ```
+    mkdir BasketballDrive_1920x1024_50
+    ```
+* Convert YUV to PNG:
+    ```
+    ffmpeg -pix_fmt yuv420p -s 1920x1024 -i BasketballDrive_1920x1024_50.yuv -f image2 BasketballDrive_1920x1024_50/im%05d.png
+    ```
+Finally, the folder structure of the dataset looks like:
+
+    /media/data/HEVC_B/
+        * BQTerrace_1920x1024_60/
+            - im00001.png
+            - im00002.png
+            - im00003.png
+            - ...
+        * BasketballDrive_1920x1024_50/
+            - im00001.png
+            - im00002.png
+            - im00003.png
+            - ...
+        * ...
+    /media/data/HEVC_D
+    /media/data/HEVC_C/
+    ...
+
+# Pretrained models
+
+* Download CompressAI models
+    ```
+    cd checkpoints/
+    python download_compressai_models.py
+    cd ..
+    ```
+
+* Download [DCVC models](https://1drv.ms/u/s!AozfVVwtWWYoiS5mcGX320bFXI0k?e=iMeykH) and put them into the /checkpoints folder.
+
+# Test DCVC
+
+Example of testing the PSNR model:
+```bash
+python test_video.py --i_frame_model_name cheng2020-anchor --i_frame_model_path checkpoints/cheng2020-anchor-3-e49be189.pth.tar checkpoints/cheng2020-anchor-4-98b0b468.pth.tar checkpoints/cheng2020-anchor-5-23852949.pth.tar checkpoints/cheng2020-anchor-6-4c052b1a.pth.tar --test_config dataset_config_example.json --cuda true --cuda_device 0,1,2,3 --worker 4 --output_json_result_path DCVC_result_psnr.json --model_type psnr --recon_bin_path recon_bin_folder_psnr --model_path checkpoints/model_dcvc_quality_0_psnr.pth checkpoints/model_dcvc_quality_1_psnr.pth checkpoints/model_dcvc_quality_2_psnr.pth checkpoints/model_dcvc_quality_3_psnr.pth
+```
+
+Example of testing the MSSSIM model:
+```bash
+python test_video.py --i_frame_model_name bmshj2018-hyperprior --i_frame_model_path checkpoints/bmshj2018-hyperprior-ms-ssim-3-92dd7878.pth.tar checkpoints/bmshj2018-hyperprior-ms-ssim-4-4377354e.pth.tar checkpoints/bmshj2018-hyperprior-ms-ssim-5-c34afc8d.pth.tar checkpoints/bmshj2018-hyperprior-ms-ssim-6-3a6d8229.pth.tar --test_config dataset_config_example.json --cuda true --cuda_device 0,1,2,3 --worker 4 --output_json_result_path DCVC_result_msssim.json --model_type msssim --recon_bin_path recon_bin_folder_msssim --model_path checkpoints/model_dcvc_quality_0_msssim.pth checkpoints/model_dcvc_quality_1_msssim.pth checkpoints/model_dcvc_quality_2_msssim.pth checkpoints/model_dcvc_quality_3_msssim.pth
+```
+It is recommended to set the ```--worker``` number equal to the number of GPUs.
+
+# R-D Curve of DCVC
+![PSNR RD Curve](assets/rd_curve_psnr.png)
+
+# Acknowledgement
+The implementation is based on [CompressAI](https://github.com/InterDigitalInc/CompressAI) and [PyTorchVideoCompression](https://github.com/ZhihaoHu/PyTorchVideoCompression). The model weights of intra coding come from [CompressAI](https://github.com/InterDigitalInc/CompressAI).
+
+# Citation
+If you find this work useful for your research, please cite:
+
+```
+@article{li2021deep,
+  title={Deep Contextual Video Compression},
+  author={Li, Jiahao and Li, Bin and Lu, Yan},
+  journal={Advances in Neural Information Processing Systems},
+  volume={34},
+  year={2021}
+}
+```
+
+# Trademarks
+This project may contain trademarks or logos for projects, products, or services.
Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft’s Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party’s policies. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..926b8ae --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 
+ + diff --git a/assets/rd_curve_psnr.png b/assets/rd_curve_psnr.png new file mode 100644 index 0000000..157f497 Binary files /dev/null and b/assets/rd_curve_psnr.png differ diff --git a/checkpoints/download_compressai_models.py b/checkpoints/download_compressai_models.py new file mode 100644 index 0000000..c942195 --- /dev/null +++ b/checkpoints/download_compressai_models.py @@ -0,0 +1,19 @@ +import urllib.request + +# The model weights of intra coding come from CompressAI. +root_url = "https://compressai.s3.amazonaws.com/models/v1/" + +model_names = [ + "bmshj2018-hyperprior-ms-ssim-3-92dd7878.pth.tar", + "bmshj2018-hyperprior-ms-ssim-4-4377354e.pth.tar", + "bmshj2018-hyperprior-ms-ssim-5-c34afc8d.pth.tar", + "bmshj2018-hyperprior-ms-ssim-6-3a6d8229.pth.tar", + "cheng2020-anchor-3-e49be189.pth.tar", + "cheng2020-anchor-4-98b0b468.pth.tar", + "cheng2020-anchor-5-23852949.pth.tar", + "cheng2020-anchor-6-4c052b1a.pth.tar", +] + +for model in model_names: + print(f"downloading {model}") + urllib.request.urlretrieve(root_url+model, model) \ No newline at end of file diff --git a/dataset_config_example.json b/dataset_config_example.json new file mode 100644 index 0000000..d39a9e9 --- /dev/null +++ b/dataset_config_example.json @@ -0,0 +1,85 @@ +{ + "HEVC_B": { + "base_path": "/media/data/HEVC_B", + "sequences": { + "BQTerrace_1920x1024_60": {"frames": 100, "gop": 10}, + "BasketballDrive_1920x1024_50": {"frames": 100, "gop": 10}, + "Cactus_1920x1024_50": {"frames": 100, "gop": 10}, + "Kimono1_1920x1024_24": {"frames": 100, "gop": 10}, + "ParkScene_1920x1024_24": {"frames": 100, "gop": 10} + } + }, + "HEVC_C": { + "base_path": "/media/data/HEVC_C", + "sequences": { + "BQMall_832x448_60": {"frames": 100, "gop": 10}, + "BasketballDrill_832x448_50": {"frames": 100, "gop": 10}, + "PartyScene_832x448_50": {"frames": 100, "gop": 10}, + "RaceHorses_832x448_30": {"frames": 100, "gop": 10} + } + }, + "HEVC_D": { + "base_path": "/media/data/HEVC_D", + "sequences": { + "BasketballPass_384x192_50": {"frames": 100, "gop": 10}, + "BlowingBubbles_384x192_50": {"frames": 100, "gop": 10}, + "BQSquare_384x192_50": {"frames": 100, "gop": 10}, + "RaceHorses_384x192_50": {"frames": 100, "gop": 10} + } + }, + "HEVC_E": { + "base_path": "/media/data/HEVC_E", + "sequences": { + "FourPeople_1280x704_60": {"frames": 100, "gop": 10}, + "Johnny_1280x704_60": {"frames": 100, "gop": 10}, + "KristenAndSara_1280x704_60": {"frames": 100, "gop": 10} + } + }, + "UVG": { + "base_path": "/media/data/UVGDataSet_crop", + "sequences": { + "Beauty_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12}, + "Bosphorus_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12}, + "HoneyBee_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12}, + "Jockey_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12}, + "ReadySteadyGo_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12}, + "ShakeNDry_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12}, + "YachtRide_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12} + } + }, + "MCL-JCV": { + "base_path": "/media/data/MCL-JCV", + "sequences": { + "videoSRC01_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC02_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC03_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC04_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC05_1920x1024_25": {"frames": 120, "gop": 12}, + "videoSRC06_1920x1024_25": {"frames": 120, "gop": 12}, + "videoSRC07_1920x1024_25": {"frames": 120, "gop": 12}, + "videoSRC08_1920x1024_25": 
{"frames": 120, "gop": 12}, + "videoSRC09_1920x1024_25": {"frames": 120, "gop": 12}, + "videoSRC10_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC11_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC12_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC13_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC14_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC15_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC16_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC17_1920x1024_24": {"frames": 120, "gop": 12}, + "videoSRC18_1920x1024_25": {"frames": 120, "gop": 12}, + "videoSRC19_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC20_1920x1024_25": {"frames": 120, "gop": 12}, + "videoSRC21_1920x1024_24": {"frames": 120, "gop": 12}, + "videoSRC22_1920x1024_24": {"frames": 120, "gop": 12}, + "videoSRC23_1920x1024_24": {"frames": 120, "gop": 12}, + "videoSRC24_1920x1024_24": {"frames": 120, "gop": 12}, + "videoSRC25_1920x1024_24": {"frames": 120, "gop": 12}, + "videoSRC26_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC27_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC28_1920x1024_30": {"frames": 120, "gop": 12}, + "videoSRC29_1920x1024_24": {"frames": 120, "gop": 12}, + "videoSRC30_1920x1024_30": {"frames": 120, "gop": 12} + } + } +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fad3090 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +numpy +scipy +matplotlib +Pillow +pytorch-msssim +tqdm \ No newline at end of file diff --git a/src/cpp/3rdparty/CMakeLists.txt b/src/cpp/3rdparty/CMakeLists.txt new file mode 100644 index 0000000..8f63698 --- /dev/null +++ b/src/cpp/3rdparty/CMakeLists.txt @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +add_subdirectory(pybind11) +add_subdirectory(ryg_rans) \ No newline at end of file diff --git a/src/cpp/3rdparty/pybind11/CMakeLists.txt b/src/cpp/3rdparty/pybind11/CMakeLists.txt new file mode 100644 index 0000000..3c88809 --- /dev/null +++ b/src/cpp/3rdparty/pybind11/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +configure_file(CMakeLists.txt.in pybind11-download/CMakeLists.txt) +execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) +if(result) + message(FATAL_ERROR "CMake step for pybind11 failed: ${result}") +endif() +execute_process(COMMAND ${CMAKE_COMMAND} --build . 
+ RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) +if(result) + message(FATAL_ERROR "Build step for pybind11 failed: ${result}") +endif() + +add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/ + ${CMAKE_CURRENT_BINARY_DIR}/pybind11-build/ + EXCLUDE_FROM_ALL) + +set(PYBIND11_INCLUDE + ${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/include/ + CACHE INTERNAL "") diff --git a/src/cpp/3rdparty/pybind11/CMakeLists.txt.in b/src/cpp/3rdparty/pybind11/CMakeLists.txt.in new file mode 100644 index 0000000..f0b4565 --- /dev/null +++ b/src/cpp/3rdparty/pybind11/CMakeLists.txt.in @@ -0,0 +1,33 @@ +cmake_minimum_required(VERSION 3.6.3) + +project(pybind11-download NONE) + +include(ExternalProject) +if(IS_DIRECTORY "${PROJECT_BINARY_DIR}/3rdparty/pybind11/pybind11-src/include") + ExternalProject_Add(pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.6.1 + GIT_SHALLOW 1 + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" + DOWNLOAD_COMMAND "" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +else() + ExternalProject_Add(pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.6.1 + GIT_SHALLOW 1 + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +endif() diff --git a/src/cpp/3rdparty/ryg_rans/CMakeLists.txt b/src/cpp/3rdparty/ryg_rans/CMakeLists.txt new file mode 100644 index 0000000..577eebe --- /dev/null +++ b/src/cpp/3rdparty/ryg_rans/CMakeLists.txt @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +configure_file(CMakeLists.txt.in ryg_rans-download/CMakeLists.txt) +execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) +if(result) + message(FATAL_ERROR "CMake step for ryg_rans failed: ${result}") +endif() +execute_process(COMMAND ${CMAKE_COMMAND} --build . 
+                RESULT_VARIABLE result
+                WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download )
+if(result)
+  message(FATAL_ERROR "Build step for ryg_rans failed: ${result}")
+endif()
+
+set(RYG_RANS_INCLUDE
+    ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/
+    CACHE INTERNAL "")
diff --git a/src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in b/src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in
new file mode 100644
index 0000000..3c62451
--- /dev/null
+++ b/src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in
@@ -0,0 +1,33 @@
+cmake_minimum_required(VERSION 3.6.3)
+
+project(ryg_rans-download NONE)
+
+include(ExternalProject)
+if(EXISTS "${PROJECT_BINARY_DIR}/3rdparty/ryg_rans/ryg_rans-src/rans64.h")
+  ExternalProject_Add(ryg_rans
+    GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git
+    GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d
+    GIT_SHALLOW 1
+    SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src"
+    BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build"
+    DOWNLOAD_COMMAND ""
+    UPDATE_COMMAND ""
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND ""
+    INSTALL_COMMAND ""
+    TEST_COMMAND ""
+    )
+else()
+  ExternalProject_Add(ryg_rans
+    GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git
+    GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d
+    GIT_SHALLOW 1
+    SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src"
+    BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build"
+    UPDATE_COMMAND ""
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND ""
+    INSTALL_COMMAND ""
+    TEST_COMMAND ""
+    )
+endif()
diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt
new file mode 100644
index 0000000..069e920
--- /dev/null
+++ b/src/cpp/CMakeLists.txt
@@ -0,0 +1,23 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+cmake_minimum_required (VERSION 3.6.3)
+project (ErrorRecovery)
+
+set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo" CACHE STRING "" FORCE)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_EXTENSIONS OFF)
+
+# treat warnings as errors
+if (MSVC)
+    add_compile_options(/W4 /WX)
+else()
+    add_compile_options(-Wall -Wextra -pedantic -Werror)
+endif()
+
+# The order matters: the 3rd-party projects must be added first
+add_subdirectory(3rdparty)
+add_subdirectory (ops)
+add_subdirectory (rans)
diff --git a/src/cpp/ops/CMakeLists.txt b/src/cpp/ops/CMakeLists.txt
new file mode 100644
index 0000000..03b72c4
--- /dev/null
+++ b/src/cpp/ops/CMakeLists.txt
@@ -0,0 +1,28 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+cmake_minimum_required(VERSION 3.7)
+set(PROJECT_NAME MLCodec_CXX)
+project(${PROJECT_NAME})
+
+set(cxx_source
+    ops.cpp
+    )
+
+set(include_dirs
+    ${CMAKE_CURRENT_SOURCE_DIR}
+    ${PYBIND11_INCLUDE}
+    )
+
+pybind11_add_module(${PROJECT_NAME} ${cxx_source})
+
+target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs})
+
+# The POST_BUILD command is executed after the target is built
+add_custom_command(
+    TARGET ${PROJECT_NAME} POST_BUILD
+    COMMAND
+        "${CMAKE_COMMAND}" -E copy
+        "$<TARGET_FILE:${PROJECT_NAME}>"
+        "${CMAKE_CURRENT_SOURCE_DIR}/../../entropy_models/"
+)
diff --git a/src/cpp/ops/ops.cpp b/src/cpp/ops/ops.cpp
new file mode 100644
index 0000000..9463ab7
--- /dev/null
+++ b/src/cpp/ops/ops.cpp
@@ -0,0 +1,91 @@
+/* Copyright 2020 InterDigital Communications, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <numeric>
+
+std::vector<uint32_t> pmf_to_quantized_cdf(const std::vector<float> &pmf,
+                                           int precision) {
+  /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal
+   * although it's only run once per model after training. See TF/compression
+   * implementation for an optimized version. */
+
+  std::vector<uint32_t> cdf(pmf.size() + 1);
+  cdf[0] = 0; /* freq 0 */
+
+  std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) {
+    return static_cast<uint32_t>(std::round(p * (1 << precision)) + 0.5);
+  });
+
+  const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0);
+
+  std::transform(
+      cdf.begin(), cdf.end(), cdf.begin(), [precision, total](uint32_t p) {
+        return static_cast<uint32_t>((((1ull << precision) * p) / total));
+      });
+
+  std::partial_sum(cdf.begin(), cdf.end(), cdf.begin());
+  cdf.back() = 1 << precision;
+
+  for (int i = 0; i < static_cast<int>(cdf.size() - 1); ++i) {
+    if (cdf[i] == cdf[i + 1]) {
+      /* Try to steal frequency from low-frequency symbols */
+      uint32_t best_freq = ~0u;
+      int best_steal = -1;
+      for (int j = 0; j < static_cast<int>(cdf.size()) - 1; ++j) {
+        uint32_t freq = cdf[j + 1] - cdf[j];
+        if (freq > 1 && freq < best_freq) {
+          best_freq = freq;
+          best_steal = j;
+        }
+      }
+
+      assert(best_steal != -1);
+
+      if (best_steal < i) {
+        for (int j = best_steal + 1; j <= i; ++j) {
+          cdf[j]--;
+        }
+      } else {
+        assert(best_steal > i);
+        for (int j = i + 1; j <= best_steal; ++j) {
+          cdf[j]++;
+        }
+      }
+    }
+  }
+
+  assert(cdf[0] == 0);
+  assert(cdf.back() == (1u << precision));
+  for (int i = 0; i < static_cast<int>(cdf.size()) - 1; ++i) {
+    assert(cdf[i + 1] > cdf[i]);
+  }
+
+  return cdf;
+}
+
+PYBIND11_MODULE(MLCodec_CXX, m) {
+  m.attr("__name__") = "MLCodec_CXX";
+
+  m.doc() = "C++ utils";
+
+  m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf,
+        "Return quantized CDF for a given PMF");
+}
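The slot-stealing normalization above is easier to follow outside C++. Below is an illustrative Python re-expression of the same idea (my sketch, not the shipped binding; the shipped one is exposed to Python as `MLCodec_CXX.pmf_to_quantized_cdf`):

```python
def quantize_pmf(pmf, precision=16):
    """Illustrative Python mirror of pmf_to_quantized_cdf (not the binding)."""
    total_slots = 1 << precision
    # scale probabilities to integer frequencies
    freqs = [max(round(p * total_slots), 0) for p in pmf]
    total = sum(freqs)
    cdf = [0]
    for f in freqs:
        cdf.append(cdf[-1] + (total_slots * f) // total)
    cdf[-1] = total_slots
    # repair zero-width slots by stealing width from the narrowest slot > 1
    for i in range(len(cdf) - 1):
        if cdf[i] == cdf[i + 1]:
            donor = min(
                (j for j in range(len(cdf) - 1) if cdf[j + 1] - cdf[j] > 1),
                key=lambda j: cdf[j + 1] - cdf[j],
            )
            if donor < i:
                for j in range(donor + 1, i + 1):
                    cdf[j] -= 1
            else:
                for j in range(i + 1, donor + 1):
                    cdf[j] += 1
    return cdf

print(quantize_pmf([0.9, 0.1, 0.0]))  # zero-probability slot still gets width 1
```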
diff --git a/src/cpp/rans/CMakeLists.txt b/src/cpp/rans/CMakeLists.txt
new file mode 100644
index 0000000..6b443f0
--- /dev/null
+++ b/src/cpp/rans/CMakeLists.txt
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+cmake_minimum_required(VERSION 3.7)
+set(PROJECT_NAME MLCodec_rans)
+project(${PROJECT_NAME})
+
+set(rans_source
+    rans_interface.hpp
+    rans_interface.cpp
+    )
+
+set(include_dirs
+    ${CMAKE_CURRENT_SOURCE_DIR}
+    ${PYBIND11_INCLUDE}
+    ${RYG_RANS_INCLUDE}
+    )
+
+pybind11_add_module(${PROJECT_NAME} ${rans_source})
+
+target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs})
+
+add_custom_command(
+    TARGET ${PROJECT_NAME} POST_BUILD
+    COMMAND
+        "${CMAKE_COMMAND}" -E copy
+        "$<TARGET_FILE:${PROJECT_NAME}>"
+        "${CMAKE_CURRENT_SOURCE_DIR}/../../entropy_models/"
+)
diff --git a/src/cpp/rans/rans_interface.cpp b/src/cpp/rans/rans_interface.cpp
new file mode 100644
index 0000000..b75f5e0
--- /dev/null
+++ b/src/cpp/rans/rans_interface.cpp
@@ -0,0 +1,375 @@
+/* Copyright 2020 InterDigital Communications, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Rans64 extensions from:
+ * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/
+ * Unbounded range coding from:
+ * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc
+ **/
+
+#include "rans_interface.hpp"
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cstdint>
+#include <iostream>
+#include <string>
+#include <vector>
+
+namespace py = pybind11;
+
+/* probability range, this could be a parameter... */
+constexpr int precision = 16;
+
+constexpr uint16_t bypass_precision = 4; /* number of bits in bypass mode */
+constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1;
+
+namespace {
+
+/* We only run this in debug mode as it's costly... */
+void assert_cdfs(const std::vector<std::vector<int32_t>> &cdfs,
+                 const std::vector<int32_t> &cdfs_sizes) {
+  for (int i = 0; i < static_cast<int>(cdfs.size()); ++i) {
+    assert(cdfs[i][0] == 0);
+    assert(cdfs[i][cdfs_sizes[i] - 1] == (1 << precision));
+    for (int j = 0; j < cdfs_sizes[i] - 1; ++j) {
+      assert(cdfs[i][j + 1] > cdfs[i][j]);
+    }
+  }
+}
+
+/* Support only 16 bits word max */
+inline void Rans64EncPutBits(Rans64State *r, uint32_t **pptr, uint32_t val,
+                             uint32_t nbits) {
+  assert(nbits <= 16);
+  assert(val < (1u << nbits));
+
+  /* Re-normalize */
+  uint64_t x = *r;
+  uint32_t freq = 1 << (16 - nbits);
+  uint64_t x_max = ((RANS64_L >> 16) << 32) * freq;
+  if (x >= x_max) {
+    *pptr -= 1;
+    **pptr = (uint32_t)x;
+    x >>= 32;
+    Rans64Assert(x < x_max);
+  }
+
+  /* x = C(s, x) */
+  *r = (x << nbits) | val;
+}
+
+inline uint32_t Rans64DecGetBits(Rans64State *r, uint32_t **pptr,
+                                 uint32_t n_bits) {
+  uint64_t x = *r;
+  uint32_t val = x & ((1u << n_bits) - 1);
+
+  /* Re-normalize */
+  x = x >> n_bits;
+  if (x < RANS64_L) {
+    x = (x << 32) | **pptr;
+    *pptr += 1;
+    Rans64Assert(x >= RANS64_L);
+  }
+
+  *r = x;
+
+  return val;
+}
+} // namespace
+
+void BufferedRansEncoder::encode_with_indexes(
+    const std::vector<int32_t> &symbols, const std::vector<int32_t> &indexes,
+    const std::vector<std::vector<int32_t>> &cdfs,
+    const std::vector<int32_t> &cdfs_sizes,
+    const std::vector<int32_t> &offsets) {
+  assert(cdfs.size() == cdfs_sizes.size());
+  assert_cdfs(cdfs, cdfs_sizes);
+
+  // Symbols are buffered in forward order here; flush() pops them from the
+  // back, so the bitstream is effectively written in reverse.
+  for (size_t i = 0; i < symbols.size(); ++i) {
+    const int32_t cdf_idx = indexes[i];
+    assert(cdf_idx >= 0);
+    assert(cdf_idx < static_cast<int>(cdfs.size()));
+
+    const auto &cdf = cdfs[cdf_idx];
+
+    const int32_t max_value = cdfs_sizes[cdf_idx] - 2;
+    assert(max_value >= 0);
+    assert((max_value + 1) < static_cast<int>(cdf.size()));
+
+    int32_t value = symbols[i] - offsets[cdf_idx];
+
+    uint32_t raw_val = 0;
+    if (value < 0) {
+      raw_val = -2 * value - 1;
+      value = max_value;
+    } else if (value >= max_value) {
+      raw_val = 2 * (value - max_value);
+      value = max_value;
+    }
+
+    assert(value >= 0);
+    assert(value < cdfs_sizes[cdf_idx] - 1);
+
+    _syms.push_back({static_cast<uint16_t>(cdf[value]),
+                     static_cast<uint16_t>(cdf[value + 1] - cdf[value]),
+                     false});
+
+    /* Bypass coding mode (value == max_value -> sentinel flag) */
+    if (value == max_value) {
+      /* Determine the number of bypasses (in bypass_precision size) needed to
+       * encode the raw value. */
+      int32_t n_bypass = 0;
+      while ((raw_val >> (n_bypass * bypass_precision)) != 0) {
+        ++n_bypass;
+      }
+
+      /* Encode number of bypasses */
+      int32_t val = n_bypass;
+      while (val >= max_bypass_val) {
+        _syms.push_back({max_bypass_val, max_bypass_val + 1, true});
+        val -= max_bypass_val;
+      }
+      _syms.push_back(
+          {static_cast<uint16_t>(val), static_cast<uint16_t>(val + 1), true});
+
+      /* Encode raw value */
+      for (int32_t j = 0; j < n_bypass; ++j) {
+        const int32_t val1 =
+            (raw_val >> (j * bypass_precision)) & max_bypass_val;
+        _syms.push_back({static_cast<uint16_t>(val1),
+                         static_cast<uint16_t>(val1 + 1), true});
+      }
+    }
+  }
+}
+
+py::bytes BufferedRansEncoder::flush() {
+  Rans64State rans;
+  Rans64EncInit(&rans);
+
+  std::vector<uint32_t> output(_syms.size(), 0xCC); // too much space ?
+  uint32_t *ptr = output.data() + output.size();
+  assert(ptr != nullptr);
+
+  while (!_syms.empty()) {
+    const RansSymbol sym = _syms.back();
+
+    if (!sym.bypass) {
+      Rans64EncPut(&rans, &ptr, sym.start, sym.range, precision);
+    } else {
+      // unlikely...
+      Rans64EncPutBits(&rans, &ptr, sym.start, bypass_precision);
+    }
+    _syms.pop_back();
+  }
+
+  Rans64EncFlush(&rans, &ptr);
+
+  const int nbytes = static_cast<int>(
+      std::distance(ptr, output.data() + output.size()) * sizeof(uint32_t));
+  return std::string(reinterpret_cast<char *>(ptr), nbytes);
+}
+
+py::bytes
+RansEncoder::encode_with_indexes(const std::vector<int32_t> &symbols,
+                                 const std::vector<int32_t> &indexes,
+                                 const std::vector<std::vector<int32_t>> &cdfs,
+                                 const std::vector<int32_t> &cdfs_sizes,
+                                 const std::vector<int32_t> &offsets) {
+
+  BufferedRansEncoder buffered_rans_enc;
+  buffered_rans_enc.encode_with_indexes(symbols, indexes, cdfs, cdfs_sizes,
+                                        offsets);
+  return buffered_rans_enc.flush();
+}
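The escape path above deserves a gloss: any symbol outside the CDF's range is coded as the last (sentinel) slot, followed by its magnitude in 4-bit raw chunks. A Python model of just that symbol-splitting step (illustrative only; the names here are invented):

```python
BYPASS_PRECISION = 4                       # raw bits per bypass chunk, as in the C++
MAX_BYPASS_VAL = (1 << BYPASS_PRECISION) - 1

def split_symbol(value, max_value):
    """Map a value to (cdf_symbol, bypass_chunks), mirroring the C++ escape logic."""
    if 0 <= value < max_value:
        return value, []                   # in range: entropy-coded directly
    # fold the sign into an unsigned magnitude, exactly like the C++ branches
    raw_val = -2 * value - 1 if value < 0 else 2 * (value - max_value)
    n_bypass = 0
    while (raw_val >> (n_bypass * BYPASS_PRECISION)) != 0:
        n_bypass += 1
    # chunk count is itself sent in saturating MAX_BYPASS_VAL-sized pieces
    counts, count = [], n_bypass
    while count >= MAX_BYPASS_VAL:
        counts.append(MAX_BYPASS_VAL)
        count -= MAX_BYPASS_VAL
    counts.append(count)
    chunks = [(raw_val >> (j * BYPASS_PRECISION)) & MAX_BYPASS_VAL
              for j in range(n_bypass)]
    return max_value, counts + chunks      # sentinel symbol + raw bit payload

print(split_symbol(2, 16))    # in range: (2, [])
print(split_symbol(-5, 16))   # escaped:  (16, [1, 9])
```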
+
+std::vector<int32_t>
+RansDecoder::decode_with_indexes(const std::string &encoded,
+                                 const std::vector<int32_t> &indexes,
+                                 const std::vector<std::vector<int32_t>> &cdfs,
+                                 const std::vector<int32_t> &cdfs_sizes,
+                                 const std::vector<int32_t> &offsets) {
+  assert(cdfs.size() == cdfs_sizes.size());
+  assert_cdfs(cdfs, cdfs_sizes);
+
+  std::vector<int32_t> output(indexes.size());
+
+  Rans64State rans;
+  uint32_t *ptr = (uint32_t *)encoded.data();
+  assert(ptr != nullptr);
+  Rans64DecInit(&rans, &ptr);
+
+  for (int i = 0; i < static_cast<int>(indexes.size()); ++i) {
+    const int32_t cdf_idx = indexes[i];
+    assert(cdf_idx >= 0);
+    assert(cdf_idx < static_cast<int>(cdfs.size()));
+
+    const auto &cdf = cdfs[cdf_idx];
+
+    const int32_t max_value = cdfs_sizes[cdf_idx] - 2;
+    assert(max_value >= 0);
+    assert((max_value + 1) < static_cast<int>(cdf.size()));
+
+    const int32_t offset = offsets[cdf_idx];
+
+    const uint32_t cum_freq = Rans64DecGet(&rans, precision);
+
+    const auto cdf_end = cdf.begin() + cdfs_sizes[cdf_idx];
+    const auto it = std::find_if(cdf.begin(), cdf_end, [cum_freq](int v) {
+      return static_cast<uint32_t>(v) > cum_freq;
+    });
+    assert(it != cdf_end + 1);
+    const uint32_t s =
+        static_cast<uint32_t>(std::distance(cdf.begin(), it) - 1);
+
+    Rans64DecAdvance(&rans, &ptr, cdf[s], cdf[s + 1] - cdf[s], precision);
+
+    int32_t value = static_cast<int32_t>(s);
+
+    if (value == max_value) {
+      /* Bypass decoding mode */
+      int32_t val = Rans64DecGetBits(&rans, &ptr, bypass_precision);
+      int32_t n_bypass = val;
+
+      while (val == max_bypass_val) {
+        val = Rans64DecGetBits(&rans, &ptr, bypass_precision);
+        n_bypass += val;
+      }
+
+      int32_t raw_val = 0;
+      for (int j = 0; j < n_bypass; ++j) {
+        val = Rans64DecGetBits(&rans, &ptr, bypass_precision);
+        assert(val <= max_bypass_val);
+        raw_val |= val << (j * bypass_precision);
+      }
+      value = raw_val >> 1;
+      if (raw_val & 1) {
+        value = -value - 1;
+      } else {
+        value += max_value;
+      }
+    }
+
+    output[i] = value + offset;
+  }
+
+  return output;
+}
+
+void RansDecoder::set_stream(const std::string &encoded) {
+  _stream = encoded;
+  uint32_t *ptr = (uint32_t *)_stream.data();
+  assert(ptr != nullptr);
+  _ptr = ptr;
+  Rans64DecInit(&_rans, &_ptr);
+}
+
+std::vector<int32_t>
+RansDecoder::decode_stream(const std::vector<int32_t> &indexes,
+                           const std::vector<std::vector<int32_t>> &cdfs,
+                           const std::vector<int32_t> &cdfs_sizes,
+                           const std::vector<int32_t> &offsets) {
+  assert(cdfs.size() == cdfs_sizes.size());
+  assert_cdfs(cdfs, cdfs_sizes);
+
+  std::vector<int32_t> output(indexes.size());
+
+  assert(_ptr != nullptr);
+
+  for (int i = 0; i < static_cast<int>(indexes.size()); ++i) {
+    const int32_t cdf_idx = indexes[i];
+    assert(cdf_idx >= 0);
+    assert(cdf_idx < static_cast<int>(cdfs.size()));
+
+    const auto &cdf = cdfs[cdf_idx];
+
+    const int32_t max_value = cdfs_sizes[cdf_idx] - 2;
+    assert(max_value >= 0);
+    assert((max_value + 1) < static_cast<int>(cdf.size()));
+
+    const int32_t offset = offsets[cdf_idx];
+
+    const uint32_t cum_freq = Rans64DecGet(&_rans, precision);
+
+    const auto cdf_end = cdf.begin() + cdfs_sizes[cdf_idx];
+    const auto it = std::find_if(cdf.begin(), cdf_end, [cum_freq](int v) {
+      return static_cast<uint32_t>(v) > cum_freq;
+    });
+    assert(it != cdf_end + 1);
+    const uint32_t s =
+        static_cast<uint32_t>(std::distance(cdf.begin(), it) - 1);
+
+    Rans64DecAdvance(&_rans, &_ptr, cdf[s], cdf[s + 1] - cdf[s], precision);
+
+    int32_t value = static_cast<int32_t>(s);
+
+    if (value == max_value) {
+      /* Bypass decoding mode */
+      int32_t val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision);
+      int32_t n_bypass = val;
+
+      while (val == max_bypass_val) {
+        val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision);
+        n_bypass += val;
+      }
+
+      int32_t raw_val = 0;
+      for (int j = 0; j < n_bypass; ++j) {
+        val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision);
+        assert(val <= max_bypass_val);
+        raw_val |= val << (j * bypass_precision);
+      }
+      value = raw_val >> 1;
+      if (raw_val & 1) {
+        value = -value - 1;
+      } else {
+        value += max_value;
+      }
+    }
+
+    output[i] = value + offset;
+  }
+
+  return output;
+}
+
+PYBIND11_MODULE(MLCodec_rans, m) {
+  m.attr("__name__") = "MLCodec_rans";
+
+  m.doc() = "range Asymmetric Numeral System python bindings";
+
+  py::class_<BufferedRansEncoder>(m, "BufferedRansEncoder")
+      .def(py::init<>())
+      .def("encode_with_indexes", &BufferedRansEncoder::encode_with_indexes)
+      .def("flush", &BufferedRansEncoder::flush);
+
+  py::class_<RansEncoder>(m, "RansEncoder")
+      .def(py::init<>())
+      .def("encode_with_indexes", &RansEncoder::encode_with_indexes);
+
+  py::class_<RansDecoder>(m, "RansDecoder")
+      .def(py::init<>())
+      .def("set_stream", &RansDecoder::set_stream)
+      .def("decode_stream", &RansDecoder::decode_stream)
+      .def("decode_with_indexes", &RansDecoder::decode_with_indexes,
+           "Decode a string to a list of symbols");
+}
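Taken together, a round trip through the resulting `MLCodec_rans` module looks roughly like this (a sketch with a hand-made toy CDF; it must satisfy the invariants `assert_cdfs` checks — starts at 0, ends at `1 << 16`, strictly increasing):

```python
# Sketch: round trip through the MLCodec_rans bindings (the module is copied
# into src/entropy_models/ by the CMake POST_BUILD step). Toy CDF, not repo data.
from MLCodec_rans import RansEncoder, RansDecoder

precision = 16
# 4 real symbols + 1 escape slot, quantized to sum to 2**16, strictly increasing
cdf = [0, 30000, 50000, 60000, 65000, 1 << precision]
cdfs, cdfs_sizes, offsets = [cdf], [len(cdf)], [0]

symbols = [0, 3, 2, 1, 0]
indexes = [0] * len(symbols)        # every symbol uses cdf 0

bitstream = RansEncoder().encode_with_indexes(
    symbols, indexes, cdfs, cdfs_sizes, offsets)

dec = RansDecoder()
decoded = dec.decode_with_indexes(bitstream, indexes, cdfs, cdfs_sizes, offsets)
assert decoded == symbols
```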
diff --git a/src/cpp/rans/rans_interface.hpp b/src/cpp/rans/rans_interface.hpp
new file mode 100644
index 0000000..49c35fb
--- /dev/null
+++ b/src/cpp/rans/rans_interface.hpp
@@ -0,0 +1,113 @@
+/* Copyright 2020 InterDigital Communications, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <pybind11/pybind11.h>
+#include <vector>
+
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wpedantic"
+#pragma GCC diagnostic ignored "-Wsign-compare"
+#elif _MSC_VER
+#pragma warning(push, 0)
+#endif
+
+#include <rans64.h>
+
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#elif _MSC_VER
+#pragma warning(pop)
+#endif
+
+namespace py = pybind11;
+
+struct RansSymbol {
+  uint16_t start;
+  uint16_t range;
+  bool bypass; // bypass flag to write raw bits to the stream
+};
+
+/* NOTE: Warning, we buffer everything for now... In case of large files we
+ * should split the bitstream into chunks... Or for a memory-bounded encoder
+ **/
+class BufferedRansEncoder {
+public:
+  BufferedRansEncoder() = default;
+
+  BufferedRansEncoder(const BufferedRansEncoder &) = delete;
+  BufferedRansEncoder(BufferedRansEncoder &&) = delete;
+  BufferedRansEncoder &operator=(const BufferedRansEncoder &) = delete;
+  BufferedRansEncoder &operator=(BufferedRansEncoder &&) = delete;
+
+  void encode_with_indexes(const std::vector<int32_t> &symbols,
+                           const std::vector<int32_t> &indexes,
+                           const std::vector<std::vector<int32_t>> &cdfs,
+                           const std::vector<int32_t> &cdfs_sizes,
+                           const std::vector<int32_t> &offsets);
+  py::bytes flush();
+
+private:
+  std::vector<RansSymbol> _syms;
+};
+
+class RansEncoder {
+public:
+  RansEncoder() = default;
+
+  RansEncoder(const RansEncoder &) = delete;
+  RansEncoder(RansEncoder &&) = delete;
+  RansEncoder &operator=(const RansEncoder &) = delete;
+  RansEncoder &operator=(RansEncoder &&) = delete;
+
+  py::bytes encode_with_indexes(const std::vector<int32_t> &symbols,
+                                const std::vector<int32_t> &indexes,
+                                const std::vector<std::vector<int32_t>> &cdfs,
+                                const std::vector<int32_t> &cdfs_sizes,
+                                const std::vector<int32_t> &offsets);
+};
+
+class RansDecoder {
+public:
+  RansDecoder() = default;
+
+  RansDecoder(const RansDecoder &) = delete;
+  RansDecoder(RansDecoder &&) = delete;
+  RansDecoder &operator=(const RansDecoder &) = delete;
+  RansDecoder &operator=(RansDecoder &&) = delete;
+
+  std::vector<int32_t>
+  decode_with_indexes(const std::string &encoded,
+                      const std::vector<int32_t> &indexes,
+                      const std::vector<std::vector<int32_t>> &cdfs,
+                      const std::vector<int32_t> &cdfs_sizes,
+                      const std::vector<int32_t> &offsets);
+
+  void set_stream(const std::string &stream);
+
+  std::vector<int32_t>
+  decode_stream(const std::vector<int32_t> &indexes,
+                const std::vector<std::vector<int32_t>> &cdfs,
+                const std::vector<int32_t> &cdfs_sizes,
+                const std::vector<int32_t> &offsets);
+
+private:
+  Rans64State _rans;
+  std::string _stream;
+  uint32_t *_ptr;
+};
diff --git a/src/entropy_models/entropy_models.py b/src/entropy_models/entropy_models.py
new file mode 100644
index 0000000..6111cac
--- /dev/null
+++ b/src/entropy_models/entropy_models.py
@@ -0,0 +1,516 @@
+import numpy as np
+import scipy.stats
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# isort: off; pylint: disable=E0611,E0401
+from ..ops.bound_ops import LowerBound
+
+# isort: on; pylint: enable=E0611,E0401
+
+
+class _EntropyCoder:
+    """Proxy class to an actual entropy coder class."""
+
+    def __init__(self):
+        from .MLCodec_rans import RansEncoder, RansDecoder
+
+        encoder = RansEncoder()
+        decoder = RansDecoder()
+        self._encoder = encoder
+        self._decoder = decoder
+
+    def encode_with_indexes(self, *args, **kwargs):
+        return self._encoder.encode_with_indexes(*args, **kwargs)
+
+    def decode_with_indexes(self, *args, **kwargs):
+        return self._decoder.decode_with_indexes(*args, **kwargs)
+
+
+def pmf_to_quantized_cdf(pmf, precision=16):
+    from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf
+    cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision)
+    cdf = torch.IntTensor(cdf)
+    return cdf
+
+
+class EntropyModel(nn.Module):
+    r"""Entropy model base class.
+
+    Args:
+        likelihood_bound (float): minimum likelihood bound
+        entropy_coder (str, optional): set the entropy coder to use, use default
+            one if None
+        entropy_coder_precision (int): set the entropy coder precision
+    """
+
+    def __init__(
+        self, likelihood_bound=1e-9, entropy_coder=None, entropy_coder_precision=16
+    ):
+        super().__init__()
+        self.entropy_coder = None
+        self.entropy_coder_precision = int(entropy_coder_precision)
+
+        self.use_likelihood_bound = likelihood_bound > 0
+        if self.use_likelihood_bound:
+            self.likelihood_lower_bound = LowerBound(likelihood_bound)
+
+        # to be filled on update()
+        self.register_buffer("_offset", torch.IntTensor())
+        self.register_buffer("_quantized_cdf", torch.IntTensor())
+        self.register_buffer("_cdf_length", torch.IntTensor())
+
+    def forward(self, *args):
+        raise NotImplementedError()
+
+    def _check_entropy_coder(self):
+        if self.entropy_coder is None:
+            self.entropy_coder = _EntropyCoder()
+
+    def _quantize(self, inputs, mode, means=None):
+        if mode not in ("dequantize", "symbols"):
+            raise ValueError(f'Invalid quantization mode: "{mode}"')
+
+        outputs = inputs.clone()
+        if means is not None:
+            outputs -= means
+
+        outputs = torch.round(outputs)
+
+        if mode == "dequantize":
+            if means is not None:
+                outputs += means
+            return outputs
+
+        assert mode == "symbols", mode
+        outputs = outputs.int()
+        return outputs
+
+    @staticmethod
+    def _dequantize(inputs, means=None):
+        if means is not None:
+            outputs = inputs.type_as(means)
+            outputs += means
+        else:
+            outputs = inputs.float()
+        return outputs
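The two `_quantize` modes above are the hinge of the class: `"symbols"` produces the integer residuals fed to the rANS coder, while `"dequantize"` rounds and shifts back for the reconstruction path. An illustrative snippet (toy tensors, not repo code):

```python
import torch

# Illustrative only: mirrors EntropyModel._quantize's two modes.
x = torch.tensor([[1.4, -0.6, 2.1]])
means = torch.tensor([[1.0, 0.0, 2.0]])

# "symbols": integer residuals relative to the means, fed to the rANS coder
symbols = torch.round(x - means).int()    # tensor([[ 0, -1,  0]], dtype=torch.int32)

# "dequantize": the rounded residuals shifted back around the means
dequant = torch.round(x - means) + means  # tensor([[ 1., -1.,  2.]])
```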
+    def _pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length):
+        cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32)
+        for i, p in enumerate(pmf):
+            prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0)
+            _cdf = pmf_to_quantized_cdf(prob, self.entropy_coder_precision)
+            cdf[i, : _cdf.size(0)] = _cdf
+        return cdf
+
+    def _check_cdf_size(self):
+        if self._quantized_cdf.numel() == 0:
+            raise ValueError("Uninitialized CDFs. Run update() first")
+
+        if len(self._quantized_cdf.size()) != 2:
+            raise ValueError(f"Invalid CDF size {self._quantized_cdf.size()}")
+
+    def _check_offsets_size(self):
+        if self._offset.numel() == 0:
+            raise ValueError("Uninitialized offsets. Run update() first")
+
+        if len(self._offset.size()) != 1:
+            raise ValueError(f"Invalid offsets size {self._offset.size()}")
+
+    def _check_cdf_length(self):
+        if self._cdf_length.numel() == 0:
+            raise ValueError("Uninitialized CDF lengths. Run update() first")
+
+        if len(self._cdf_length.size()) != 1:
+            raise ValueError(f"Invalid CDF lengths size {self._cdf_length.size()}")
+
+    def compress(self, inputs, indexes, means=None):
+        """
+        Compress input tensors to char strings.
+
+        Args:
+            inputs (torch.Tensor): input tensors
+            indexes (torch.IntTensor): tensors CDF indexes
+            means (torch.Tensor, optional): optional tensor means
+        """
+        symbols = self._quantize(inputs, "symbols", means)
+
+        if len(inputs.size()) != 4:
+            raise ValueError("Invalid `inputs` size. Expected a 4-D tensor.")
+
+        if inputs.size() != indexes.size():
+            raise ValueError("`inputs` and `indexes` should have the same size.")
+
+        self._check_cdf_size()
+        self._check_cdf_length()
+        self._check_offsets_size()
+
+        strings = []
+        self._check_entropy_coder()
+        for i in range(symbols.size(0)):
+            rv = self.entropy_coder.encode_with_indexes(
+                symbols[i].reshape(-1).int().tolist(),
+                indexes[i].reshape(-1).int().tolist(),
+                self._quantized_cdf.tolist(),
+                self._cdf_length.reshape(-1).int().tolist(),
+                self._offset.reshape(-1).int().tolist(),
+            )
+            strings.append(rv)
+        return strings
+
+    def decompress(self, strings, indexes, means=None):
+        """
+        Decompress char strings to tensors.
+
+        Args:
+            strings (str): compressed tensors
+            indexes (torch.IntTensor): tensors CDF indexes
+            means (torch.Tensor, optional): optional tensor means
+        """
+
+        if not isinstance(strings, (tuple, list)):
+            raise ValueError("Invalid `strings` parameter type.")
+
+        if not len(strings) == indexes.size(0):
+            raise ValueError("Invalid strings or indexes parameters")
+
+        if len(indexes.size()) != 4:
+            raise ValueError("Invalid `indexes` size. Expected a 4-D tensor.")
+
+        self._check_cdf_size()
+        self._check_cdf_length()
+        self._check_offsets_size()
+
+        if means is not None:
+            if means.size()[:-2] != indexes.size()[:-2]:
+                raise ValueError("Invalid means or indexes parameters")
+            if means.size() != indexes.size() and (
+                means.size(2) != 1 or means.size(3) != 1
+            ):
+                raise ValueError("Invalid means parameters")
+
+        cdf = self._quantized_cdf
+        outputs = cdf.new(indexes.size())
+        self._check_entropy_coder()
+        for i, s in enumerate(strings):
+            values = self.entropy_coder.decode_with_indexes(
+                s,
+                indexes[i].reshape(-1).int().tolist(),
+                cdf.tolist(),
+                self._cdf_length.reshape(-1).int().tolist(),
+                self._offset.reshape(-1).int().tolist(),
+            )
+            outputs[i] = torch.Tensor(values).reshape(outputs[i].size())
+        outputs = self._dequantize(outputs, means)
+        return outputs
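End to end, `compress`/`decompress` are meant to round-trip once `update()` has populated the CDF buffers. A usage sketch with hypothetical shapes, using the `EntropyBottleneck` subclass defined next:

```python
import torch

# Sketch: round trip through an EntropyModel subclass after update().
model = EntropyBottleneck(channels=128)
model.update(force=True)                    # fills _quantized_cdf/_cdf_length/_offset

y = torch.randn(1, 128, 16, 16)
strings = model.compress(y)                 # one rANS string per batch element
y_hat = model.decompress(strings, size=(16, 16))
assert y_hat.shape == y.shape               # values are y rounded around the medians
```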
+ """ + + def __init__( + self, + channels, + *args, + tail_mass=1e-9, + init_scale=10, + filters=(3, 3, 3, 3), + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.channels = int(channels) + self.filters = tuple(int(f) for f in filters) + self.init_scale = float(init_scale) + self.tail_mass = float(tail_mass) + + # Create parameters + self._biases = nn.ParameterList() + self._factors = nn.ParameterList() + self._matrices = nn.ParameterList() + + filters = (1,) + self.filters + (1,) + scale = self.init_scale ** (1 / (len(self.filters) + 1)) + channels = self.channels + + for i in range(len(self.filters) + 1): + init = np.log(np.expm1(1 / scale / filters[i + 1])) + matrix = torch.Tensor(channels, filters[i + 1], filters[i]) + matrix.data.fill_(init) + self._matrices.append(nn.Parameter(matrix)) + + bias = torch.Tensor(channels, filters[i + 1], 1) + nn.init.uniform_(bias, -0.5, 0.5) + self._biases.append(nn.Parameter(bias)) + + if i < len(self.filters): + factor = torch.Tensor(channels, filters[i + 1], 1) + nn.init.zeros_(factor) + self._factors.append(nn.Parameter(factor)) + + self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3)) + init = torch.Tensor([-self.init_scale, 0, self.init_scale]) + self.quantiles.data = init.repeat(self.quantiles.size(0), 1, 1) + + target = np.log(2 / self.tail_mass - 1) + self.register_buffer("target", torch.Tensor([-target, 0, target])) + + def _medians(self): + medians = self.quantiles[:, :, 1:2] + return medians + + def update(self, force=False): + # Check if we need to update the bottleneck parameters, the offsets are + # only computed and stored when the conditonal model is update()'d. + if self._offset.numel() > 0 and not force: # pylint: disable=E0203 + return + + medians = self.quantiles[:, 0, 1] + + minima = medians - self.quantiles[:, 0, 0] + minima = torch.ceil(minima).int() + minima = torch.clamp(minima, min=0) + + maxima = self.quantiles[:, 0, 2] - medians + maxima = torch.ceil(maxima).int() + maxima = torch.clamp(maxima, min=0) + + self._offset = -minima + + pmf_start = medians - minima + pmf_length = maxima + minima + 1 + + max_length = pmf_length.max() + device = pmf_start.device + samples = torch.arange(max_length, device=device) + + samples = samples[None, :] + pmf_start[:, None, None] + + half = float(0.5) + + lower = self._logits_cumulative(samples - half, stop_gradient=True) + upper = self._logits_cumulative(samples + half, stop_gradient=True) + sign = -torch.sign(lower + upper) + pmf = torch.abs(torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)) + + pmf = pmf[:, 0, :] + tail_mass = torch.sigmoid(lower[:, 0, :1]) + torch.sigmoid(-upper[:, 0, -1:]) + + quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) + self._quantized_cdf = quantized_cdf + self._cdf_length = pmf_length + 2 + + + def _logits_cumulative(self, inputs, stop_gradient): + # TorchScript not yet working (nn.Mmodule indexing not supported) + logits = inputs + for i in range(len(self.filters) + 1): + matrix = self._matrices[i] + if stop_gradient: + matrix = matrix.detach() + logits = torch.matmul(F.softplus(matrix), logits) + + bias = self._biases[i] + if stop_gradient: + bias = bias.detach() + logits += bias + + if i < len(self._factors): + factor = self._factors[i] + if stop_gradient: + factor = factor.detach() + logits += torch.tanh(factor) * torch.tanh(logits) + return logits + + @torch.jit.unused + def _likelihood(self, inputs): + half = float(0.5) + v0 = inputs - half + v1 = inputs + half + lower = self._logits_cumulative(v0, 
+        upper = self._logits_cumulative(v1, stop_gradient=False)
+        sign = -torch.sign(lower + upper)
+        sign = sign.detach()
+        likelihood = torch.abs(
+            torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)
+        )
+        return likelihood
+
+    def forward(self, x):
+        # Convert to (channels, ... , batch) format
+        x = x.permute(1, 2, 3, 0).contiguous()
+        shape = x.size()
+        values = x.reshape(x.size(0), 1, -1)
+
+        # Quantize around the learned medians (the noise-based training path
+        # is not used in this implementation)
+        outputs = self._quantize(
+            values, "dequantize", self._medians()
+        )
+
+        likelihood = self._likelihood(outputs)
+        if self.use_likelihood_bound:
+            likelihood = self.likelihood_lower_bound(likelihood)
+
+        # Convert back to input tensor shape
+        outputs = outputs.reshape(shape)
+        outputs = outputs.permute(3, 0, 1, 2).contiguous()
+
+        likelihood = likelihood.reshape(shape)
+        likelihood = likelihood.permute(3, 0, 1, 2).contiguous()
+
+        return outputs, likelihood
+
+    @staticmethod
+    def _build_indexes(size):
+        N, C, H, W = size
+        indexes = torch.arange(C).view(1, -1, 1, 1)
+        indexes = indexes.int()
+        return indexes.repeat(N, 1, H, W)
+
+    def compress(self, x):
+        indexes = self._build_indexes(x.size())
+        medians = self._medians().detach().view(1, -1, 1, 1)
+        return super().compress(x, indexes, medians)
+
+    def decompress(self, strings, size):
+        output_size = (len(strings), self._quantized_cdf.size(0), size[0], size[1])
+        indexes = self._build_indexes(output_size)
+        medians = self._medians().detach().view(1, -1, 1, 1)
+        return super().decompress(strings, indexes, medians)
+
+
+class GaussianConditional(EntropyModel):
+    r"""Gaussian conditional layer, introduced by J. Ballé, D. Minnen, S. Singh,
+    S. J. Hwang, N. Johnston, in `"Variational image compression with a scale
+    hyperprior" <https://arxiv.org/abs/1802.01436>`_.
+
+    This is a re-implementation of the Gaussian conditional layer in
+    *tensorflow/compression*. See the `tensorflow documentation
+    <https://tensorflow.github.io/compression/docs/api_docs/python/tfc/GaussianConditional.html>`__
+    for more information.
+    """
+
+    def __init__(self, scale_table, *args, scale_bound=0.11, tail_mass=1e-9, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if not isinstance(scale_table, (type(None), list, tuple)):
+            raise ValueError(f'Invalid type for scale_table "{type(scale_table)}"')
+
+        if isinstance(scale_table, (list, tuple)) and len(scale_table) < 1:
+            raise ValueError(f'Invalid scale_table length "{len(scale_table)}"')
+
+        if scale_table and (
+            scale_table != sorted(scale_table) or any(s <= 0 for s in scale_table)
+        ):
+            raise ValueError(f'Invalid scale_table "({scale_table})"')
+
+        self.register_buffer(
+            "scale_table",
+            self._prepare_scale_table(scale_table) if scale_table else torch.Tensor(),
+        )
+
+        self.register_buffer(
+            "scale_bound",
+            torch.Tensor([float(scale_bound)]) if scale_bound is not None else None,
+        )
+
+        self.tail_mass = float(tail_mass)
+        if scale_bound is None and scale_table:
+            self.lower_bound_scale = LowerBound(self.scale_table[0])
+        elif scale_bound > 0:
+            self.lower_bound_scale = LowerBound(scale_bound)
+        else:
+            raise ValueError("Invalid parameters")
+
+    @staticmethod
+    def _prepare_scale_table(scale_table):
+        return torch.Tensor(tuple(float(s) for s in scale_table))
+
+    def _standardized_cumulative(self, inputs):
+        half = float(0.5)
+        const = float(-(2 ** -0.5))
+        # Using the complementary error function maximizes numerical precision.
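+        # With const = -1 / sqrt(2), the expression below is the standard
+        # normal CDF: Phi(x) = 0.5 * erfc(-x / sqrt(2)).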
+        return half * torch.erfc(const * inputs)
+
+    @staticmethod
+    def _standardized_quantile(quantile):
+        return scipy.stats.norm.ppf(quantile)
+
+    def update_scale_table(self, scale_table, force=False):
+        # Check if we need to update the Gaussian conditional parameters; the
+        # offsets are only computed and stored when the conditional model is
+        # updated.
+        if self._offset.numel() > 0 and not force:
+            return
+        self.scale_table = self._prepare_scale_table(scale_table)
+        self.update()
+
+    def update(self):
+        multiplier = -self._standardized_quantile(self.tail_mass / 2)
+        pmf_center = torch.ceil(self.scale_table * multiplier).int()
+        pmf_length = 2 * pmf_center + 1
+        max_length = torch.max(pmf_length).item()
+
+        device = pmf_center.device
+        samples = torch.abs(
+            torch.arange(max_length, device=device).int() - pmf_center[:, None]
+        )
+        samples_scale = self.scale_table.unsqueeze(1)
+        samples = samples.float()
+        samples_scale = samples_scale.float()
+        upper = self._standardized_cumulative((0.5 - samples) / samples_scale)
+        lower = self._standardized_cumulative((-0.5 - samples) / samples_scale)
+        pmf = upper - lower
+
+        tail_mass = 2 * lower[:, :1]
+
+        quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
+        self._quantized_cdf = quantized_cdf
+        self._offset = -pmf_center
+        self._cdf_length = pmf_length + 2
+
+    def _likelihood(self, inputs, scales, means=None):
+        half = float(0.5)
+
+        if means is not None:
+            values = inputs - means
+        else:
+            values = inputs
+
+        scales = self.lower_bound_scale(scales)
+
+        values = torch.abs(values)
+        upper = self._standardized_cumulative((half - values) / scales)
+        lower = self._standardized_cumulative((-half - values) / scales)
+        likelihood = upper - lower
+
+        return likelihood
+
+    def forward(self, inputs, scales, means=None):
+        outputs = self._quantize(inputs, "dequantize", means)
+        likelihood = self._likelihood(outputs, scales, means)
+        if self.use_likelihood_bound:
+            likelihood = self.likelihood_lower_bound(likelihood)
+        return outputs, likelihood
+
+    def build_indexes(self, scales):
+        scales = self.lower_bound_scale(scales)
+        indexes = scales.new_full(scales.size(), len(self.scale_table) - 1).int()
+        for s in self.scale_table[:-1]:
+            indexes -= (scales <= s).int()
+        return indexes
diff --git a/src/entropy_models/video_entropy_models.py b/src/entropy_models/video_entropy_models.py
new file mode 100644
index 0000000..b4cc0b1
--- /dev/null
+++ b/src/entropy_models/video_entropy_models.py
@@ -0,0 +1,316 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
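+
+# Entropy-coding utilities for the video codec: a rANS-backed EntropyCoder,
+# a learned factorized BitEstimator for the hyper-latents, and a
+# GaussianEncoder which, despite its name, models symbols with a Laplace
+# distribution in this implementation.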
+ +import torch +import math +import torch.nn as nn +import torch.nn.functional as F + + +class EntropyCoder(object): + def __init__(self, entropy_coder_precision=16): + super().__init__() + + from .MLCodec_rans import RansEncoder, RansDecoder + self.encoder = RansEncoder() + self.decoder = RansDecoder() + self.entropy_coder_precision = int(entropy_coder_precision) + self._offset = None + self._quantized_cdf = None + self._cdf_length = None + + def encode_with_indexes(self, *args, **kwargs): + return self.encoder.encode_with_indexes(*args, **kwargs) + + def decode_with_indexes(self, *args, **kwargs): + return self.decoder.decode_with_indexes(*args, **kwargs) + + def set_cdf_states(self, offset, quantized_cdf, cdf_length): + self._offset = offset + self._quantized_cdf = quantized_cdf + self._cdf_length = cdf_length + + @staticmethod + def pmf_to_quantized_cdf(pmf, precision=16): + from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf + cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision) + cdf = torch.IntTensor(cdf) + return cdf + + def pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length): + cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32) + for i, p in enumerate(pmf): + prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0) + _cdf = self.pmf_to_quantized_cdf(prob, self.entropy_coder_precision) + cdf[i, : _cdf.size(0)] = _cdf + return cdf + + def _check_cdf_size(self): + if self._quantized_cdf.numel() == 0: + raise ValueError("Uninitialized CDFs. Run update() first") + + if len(self._quantized_cdf.size()) != 2: + raise ValueError(f"Invalid CDF size {self._quantized_cdf.size()}") + + def _check_offsets_size(self): + if self._offset.numel() == 0: + raise ValueError("Uninitialized offsets. Run update() first") + + if len(self._offset.size()) != 1: + raise ValueError(f"Invalid offsets size {self._offset.size()}") + + def _check_cdf_length(self): + if self._cdf_length.numel() == 0: + raise ValueError("Uninitialized CDF lengths. Run update() first") + + if len(self._cdf_length.size()) != 1: + raise ValueError(f"Invalid offsets size {self._cdf_length.size()}") + + def compress(self, inputs, indexes): + """ + """ + if len(inputs.size()) != 4: + raise ValueError("Invalid `inputs` size. Expected a 4-D tensor.") + + if inputs.size() != indexes.size(): + raise ValueError("`inputs` and `indexes` should have the same size.") + symbols = inputs.int() + + self._check_cdf_size() + self._check_cdf_length() + self._check_offsets_size() + + assert symbols.size(0) == 1 + rv = self.encode_with_indexes( + symbols[0].reshape(-1).int().tolist(), + indexes[0].reshape(-1).int().tolist(), + self._quantized_cdf.tolist(), + self._cdf_length.reshape(-1).int().tolist(), + self._offset.reshape(-1).int().tolist(), + ) + return rv + + def decompress(self, strings, indexes): + """ + Decompress char strings to tensors. + + Args: + strings (str): compressed tensors + indexes (torch.IntTensor): tensors CDF indexes + """ + + assert indexes.size(0) == 1 + + if len(indexes.size()) != 4: + raise ValueError("Invalid `indexes` size. 
Expected a 4-D tensor.") + + self._check_cdf_size() + self._check_cdf_length() + self._check_offsets_size() + + cdf = self._quantized_cdf + outputs = cdf.new(indexes.size()) + + values = self.decode_with_indexes( + strings, + indexes[0].reshape(-1).int().tolist(), + self._quantized_cdf.tolist(), + self._cdf_length.reshape(-1).int().tolist(), + self._offset.reshape(-1).int().tolist(), + ) + outputs[0] = torch.Tensor(values).reshape(outputs[0].size()) + return outputs.float() + + def set_stream(self, stream): + self.decoder.set_stream(stream) + + def decode_stream(self, indexes): + rv = self.decoder.decode_stream( + indexes.squeeze().int().tolist(), + self._quantized_cdf.tolist(), + self._cdf_length.reshape(-1).int().tolist(), + self._offset.reshape(-1).int().tolist(), + ) + rv = torch.Tensor(rv).reshape(1, -1, 1, 1) + return rv + + +class Bitparm(nn.Module): + def __init__(self, channel, final=False): + super(Bitparm, self).__init__() + self.final = final + self.h = nn.Parameter(torch.nn.init.normal_( + torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) + self.b = nn.Parameter(torch.nn.init.normal_( + torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) + if not final: + self.a = nn.Parameter(torch.nn.init.normal_( + torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) + else: + self.a = None + + def forward(self, x): + if self.final: + return torch.sigmoid(x * F.softplus(self.h) + self.b) + else: + x = x * F.softplus(self.h) + self.b + return x + torch.tanh(x) * torch.tanh(self.a) + + +class BitEstimator(nn.Module): + def __init__(self, channel): + super(BitEstimator, self).__init__() + self.f1 = Bitparm(channel) + self.f2 = Bitparm(channel) + self.f3 = Bitparm(channel) + self.f4 = Bitparm(channel, True) + self.channel = channel + self.entropy_coder = None + + def forward(self, x): + x = self.f1(x) + x = self.f2(x) + x = self.f3(x) + return self.f4(x) + + def update(self, force=False): + # Check if we need to update the bottleneck parameters, the offsets are + # only computed and stored when the conditonal model is update()'d. 
+ if self.entropy_coder is not None and not force: # pylint: disable=E0203 + return + + self.entropy_coder = EntropyCoder() + with torch.no_grad(): + device = next(self.parameters()).device + medians = torch.zeros((self.channel), device=device) + + minima = medians + 50 + for i in range(50, 1, -1): + samples = torch.zeros_like(medians) - i + samples = samples[None, :, None, None] + probs = self.forward(samples) + probs = torch.squeeze(probs) + minima = torch.where(probs < torch.zeros_like(medians) + 0.0001, + torch.zeros_like(medians) + i, minima) + + maxima = medians + 50 + for i in range(50, 1, -1): + samples = torch.zeros_like(medians) + i + samples = samples[None, :, None, None] + probs = self.forward(samples) + probs = torch.squeeze(probs) + maxima = torch.where(probs > torch.zeros_like(medians) + 0.9999, + torch.zeros_like(medians) + i, maxima) + + minima = minima.int() + maxima = maxima.int() + + offset = -minima + + pmf_start = medians - minima + pmf_length = maxima + minima + 1 + + max_length = pmf_length.max() + device = pmf_start.device + samples = torch.arange(max_length, device=device) + + samples = samples[None, :] + pmf_start[:, None, None] + + half = float(0.5) + + lower = self.forward(samples - half).squeeze(0) + upper = self.forward(samples + half).squeeze(0) + pmf = upper - lower + + pmf = pmf[:, 0, :] + tail_mass = lower[:, 0, :1] + (1.0 - upper[:, 0, -1:]) + + quantized_cdf = self.entropy_coder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) + cdf_length = pmf_length + 2 + self.entropy_coder.set_cdf_states(offset, quantized_cdf, cdf_length) + + @staticmethod + def build_indexes(size): + N, C, H, W = size + indexes = torch.arange(C).view(1, -1, 1, 1) + indexes = indexes.int() + return indexes.repeat(N, 1, H, W) + + def compress(self, x): + indexes = self.build_indexes(x.size()) + return self.entropy_coder.compress(x, indexes) + + def decompress(self, strings, size): + output_size = (1, self.entropy_coder._quantized_cdf.size(0), size[0], size[1]) + indexes = self.build_indexes(output_size) + return self.entropy_coder.decompress(strings, indexes) + + +class GaussianEncoder(object): + def __init__(self): + self.scale_table = self.get_scale_table() + self.entropy_coder = None + + @staticmethod + def get_scale_table(min=0.01, max=16, levels=64): # pylint: disable=W0622 + return torch.exp(torch.linspace(math.log(min), math.log(max), levels)) + + def update(self, force=False): + if self.entropy_coder is not None and not force: + return + self.entropy_coder = EntropyCoder() + + pmf_center = torch.zeros_like(self.scale_table) + 50 + scales = torch.zeros_like(pmf_center) + self.scale_table + mu = torch.zeros_like(scales) + gaussian = torch.distributions.laplace.Laplace(mu, scales) + for i in range(50, 1, -1): + samples = torch.zeros_like(pmf_center) + i + probs = gaussian.cdf(samples) + probs = torch.squeeze(probs) + pmf_center = torch.where(probs > torch.zeros_like(pmf_center) + 0.9999, + torch.zeros_like(pmf_center) + i, pmf_center) + + pmf_center = pmf_center.int() + pmf_length = 2 * pmf_center + 1 + max_length = torch.max(pmf_length).item() + + device = pmf_center.device + samples = torch.arange(max_length, device=device) - pmf_center[:, None] + samples = samples.float() + + scales = torch.zeros_like(samples) + self.scale_table[:, None] + mu = torch.zeros_like(scales) + gaussian = torch.distributions.laplace.Laplace(mu, scales) + + upper = gaussian.cdf(samples + 0.5) + lower = gaussian.cdf(samples - 0.5) + pmf = upper - lower + + tail_mass = 2 * lower[:, :1] + + 
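+        # Note: the tensor allocated on the next line is immediately
+        # overwritten by the quantized CDF returned from pmf_to_cdf; only the
+        # second assignment is live.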
quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2) + quantized_cdf = self.entropy_coder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) + self.entropy_coder.set_cdf_states(-pmf_center, quantized_cdf, pmf_length+2) + + def build_indexes(self, scales): + scales = torch.maximum(scales, torch.zeros_like(scales) + 1e-5) + indexes = scales.new_full(scales.size(), len(self.scale_table) - 1).int() + for s in self.scale_table[:-1]: + indexes -= (scales <= s).int() + return indexes + + def compress(self, x, scales): + indexes = self.build_indexes(scales) + return self.entropy_coder.compress(x, indexes) + + def decompress(self, strings, scales): + indexes = self.build_indexes(scales) + return self.entropy_coder.decompress(strings, indexes) + + def set_stream(self, stream): + self.entropy_coder.set_stream(stream) + + def decode_stream(self, scales): + indexes = self.build_indexes(scales) + return self.entropy_coder.decode_stream(indexes) diff --git a/src/layers/gdn.py b/src/layers/gdn.py new file mode 100644 index 0000000..b053cb2 --- /dev/null +++ b/src/layers/gdn.py @@ -0,0 +1,67 @@ +# Copyright 2020 InterDigital Communications, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..ops.parametrizers import NonNegativeParametrizer + + +class GDN(nn.Module): + r"""Generalized Divisive Normalization layer. + + Introduced in `"Density Modeling of Images Using a Generalized Normalization + Transformation" `_, + by Balle Johannes, Valero Laparra, and Eero P. Simoncelli, (2016). + + .. math:: + + y[i] = \frac{x[i]}{\sqrt{\beta[i] + \sum_j(\gamma[j, i] * x[j]^2)}} + + """ + + def __init__(self, in_channels, inverse=False, beta_min=1e-6, gamma_init=0.1): + super().__init__() + + beta_min = float(beta_min) + gamma_init = float(gamma_init) + self.inverse = bool(inverse) + + self.beta_reparam = NonNegativeParametrizer(minimum=beta_min) + beta = torch.ones(in_channels) + beta = self.beta_reparam.init(beta) + self.beta = nn.Parameter(beta) + + self.gamma_reparam = NonNegativeParametrizer() + gamma = gamma_init * torch.eye(in_channels) + gamma = self.gamma_reparam.init(gamma) + self.gamma = nn.Parameter(gamma) + + def forward(self, x): + _, C, _, _ = x.size() + + beta = self.beta_reparam(self.beta) + gamma = self.gamma_reparam(self.gamma) + gamma = gamma.reshape(C, C, 1, 1) + norm = F.conv2d(x ** 2, gamma, beta) + + if self.inverse: + norm = torch.sqrt(norm) + else: + norm = torch.rsqrt(norm) + + out = x * norm + + return out diff --git a/src/layers/layers.py b/src/layers/layers.py new file mode 100644 index 0000000..4640364 --- /dev/null +++ b/src/layers/layers.py @@ -0,0 +1,152 @@ +# Copyright 2020 InterDigital Communications, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +from .gdn import GDN + + +class MaskedConv2d(nn.Conv2d): + r"""Masked 2D convolution implementation, mask future "unseen" pixels. + Useful for building auto-regressive network components. + + Introduced in `"Conditional Image Generation with PixelCNN Decoders" + `_. + + Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the + first layer (which also masks the "current pixel"), `mask_type='B'` for the + following layers. + """ + + def __init__(self, *args, mask_type="A", **kwargs): + super().__init__(*args, **kwargs) + + if mask_type not in ("A", "B"): + raise ValueError(f'Invalid "mask_type" value "{mask_type}"') + + self.register_buffer("mask", torch.ones_like(self.weight.data)) + _, _, h, w = self.mask.size() + self.mask[:, :, h // 2, w // 2 + (mask_type == "B"):] = 0 + self.mask[:, :, h // 2 + 1:] = 0 + + def forward(self, x): + # TODO(begaintj): weight assigment is not supported by torchscript + self.weight.data *= self.mask + return super().forward(x) + + +def conv3x3(in_ch, out_ch, stride=1): + """3x3 convolution with padding.""" + return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1) + + +def subpel_conv3x3(in_ch, out_ch, r=1): + """3x3 sub-pixel convolution for up-sampling.""" + return nn.Sequential( + nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r) + ) + + +def conv1x1(in_ch, out_ch, stride=1): + """1x1 convolution.""" + return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride) + + +class ResidualBlockWithStride(nn.Module): + """Residual block with a stride on the first convolution. + + Args: + in_ch (int): number of input channels + out_ch (int): number of output channels + stride (int): stride value (default: 2) + """ + + def __init__(self, in_ch, out_ch, stride=2): + super().__init__() + self.conv1 = conv3x3(in_ch, out_ch, stride=stride) + self.leaky_relu = nn.LeakyReLU(inplace=True) + self.conv2 = conv3x3(out_ch, out_ch) + self.gdn = GDN(out_ch) + if stride != 1: + self.downsample = conv1x1(in_ch, out_ch, stride=stride) + else: + self.downsample = None + + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.leaky_relu(out) + out = self.conv2(out) + out = self.gdn(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + return out + + +class ResidualBlockUpsample(nn.Module): + """Residual block with sub-pixel upsampling on the last convolution. 
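+
+    The identity branch applies its own sub-pixel convolution so that the
+    shortcut matches the upsampled spatial size of the main path.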
+ + Args: + in_ch (int): number of input channels + out_ch (int): number of output channels + upsample (int): upsampling factor (default: 2) + """ + + def __init__(self, in_ch, out_ch, upsample=2): + super().__init__() + self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample) + self.leaky_relu = nn.LeakyReLU(inplace=True) + self.conv = conv3x3(out_ch, out_ch) + self.igdn = GDN(out_ch, inverse=True) + self.upsample = subpel_conv3x3(in_ch, out_ch, upsample) + + def forward(self, x): + identity = x + out = self.subpel_conv(x) + out = self.leaky_relu(out) + out = self.conv(out) + out = self.igdn(out) + identity = self.upsample(x) + out += identity + return out + + +class ResidualBlock(nn.Module): + """Simple residual block with two 3x3 convolutions. + + Args: + in_ch (int): number of input channels + out_ch (int): number of output channels + """ + + def __init__(self, in_ch, out_ch): + super().__init__() + self.conv1 = conv3x3(in_ch, out_ch) + self.leaky_relu = nn.LeakyReLU(inplace=True) + self.conv2 = conv3x3(out_ch, out_ch) + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.leaky_relu(out) + out = self.conv2(out) + out = self.leaky_relu(out) + + out = out + identity + return out \ No newline at end of file diff --git a/src/models/DCVC_net.py b/src/models/DCVC_net.py new file mode 100644 index 0000000..4e58468 --- /dev/null +++ b/src/models/DCVC_net.py @@ -0,0 +1,488 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import math +import torch.nn as nn +import torch.nn.functional as F + +from .video_net import ME_Spynet, GDN, flow_warp, ResBlock, ResBlock_LeakyReLU_0_Point_1 +from ..entropy_models.video_entropy_models import BitEstimator, GaussianEncoder +from ..utils.stream_helper import get_downsampled_shape +from ..layers.layers import MaskedConv2d, subpel_conv3x3 + + +class DCVC_net(nn.Module): + def __init__(self): + super().__init__() + out_channel_mv = 128 + out_channel_N = 64 + out_channel_M = 96 + + self.out_channel_mv = out_channel_mv + self.out_channel_N = out_channel_N + self.out_channel_M = out_channel_M + + self.bitEstimator_z = BitEstimator(out_channel_N) + self.bitEstimator_z_mv = BitEstimator(out_channel_N) + + self.feature_extract = nn.Sequential( + nn.Conv2d(3, out_channel_N, 3, stride=1, padding=1), + ResBlock(out_channel_N, out_channel_N, 3), + ) + + self.context_refine = nn.Sequential( + ResBlock(out_channel_N, out_channel_N, 3), + nn.Conv2d(out_channel_N, out_channel_N, 3, stride=1, padding=1), + ) + + self.gaussian_encoder = GaussianEncoder() + + self.mvEncoder = nn.Sequential( + nn.Conv2d(2, out_channel_mv, 3, stride=2, padding=1), + GDN(out_channel_mv), + nn.Conv2d(out_channel_mv, out_channel_mv, 3, stride=2, padding=1), + GDN(out_channel_mv), + nn.Conv2d(out_channel_mv, out_channel_mv, 3, stride=2, padding=1), + GDN(out_channel_mv), + nn.Conv2d(out_channel_mv, out_channel_mv, 3, stride=2, padding=1), + ) + + self.mvDecoder_part1 = nn.Sequential( + nn.ConvTranspose2d(out_channel_mv, out_channel_mv, 3, + stride=2, padding=1, output_padding=1), + GDN(out_channel_mv, inverse=True), + nn.ConvTranspose2d(out_channel_mv, out_channel_mv, 3, + stride=2, padding=1, output_padding=1), + GDN(out_channel_mv, inverse=True), + nn.ConvTranspose2d(out_channel_mv, out_channel_mv, 3, + stride=2, padding=1, output_padding=1), + GDN(out_channel_mv, inverse=True), + nn.ConvTranspose2d(out_channel_mv, 2, 3, stride=2, padding=1, output_padding=1), + ) + + self.mvDecoder_part2 = nn.Sequential( + nn.Conv2d(5, 64, 3, 
stride=1, padding=1), + nn.LeakyReLU(negative_slope=0.1), + nn.Conv2d(64, 64, 3, stride=1, padding=1), + nn.LeakyReLU(negative_slope=0.1), + nn.Conv2d(64, 64, 3, stride=1, padding=1), + nn.LeakyReLU(negative_slope=0.1), + nn.Conv2d(64, 64, 3, stride=1, padding=1), + nn.LeakyReLU(negative_slope=0.1), + nn.Conv2d(64, 64, 3, stride=1, padding=1), + nn.LeakyReLU(negative_slope=0.1), + nn.Conv2d(64, 64, 3, stride=1, padding=1), + nn.LeakyReLU(negative_slope=0.1), + nn.Conv2d(64, 2, 3, stride=1, padding=1), + ) + + self.contextualEncoder = nn.Sequential( + nn.Conv2d(out_channel_N+3, out_channel_N, 5, stride=2, padding=2), + GDN(out_channel_N), + ResBlock_LeakyReLU_0_Point_1(out_channel_N), + nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), + GDN(out_channel_N), + ResBlock_LeakyReLU_0_Point_1(out_channel_N), + nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), + GDN(out_channel_N), + nn.Conv2d(out_channel_N, out_channel_M, 5, stride=2, padding=2), + ) + + self.contextualDecoder_part1 = nn.Sequential( + subpel_conv3x3(out_channel_M, out_channel_N, 2), + GDN(out_channel_N, inverse=True), + subpel_conv3x3(out_channel_N, out_channel_N, 2), + GDN(out_channel_N, inverse=True), + ResBlock_LeakyReLU_0_Point_1(out_channel_N), + subpel_conv3x3(out_channel_N, out_channel_N, 2), + GDN(out_channel_N, inverse=True), + ResBlock_LeakyReLU_0_Point_1(out_channel_N), + subpel_conv3x3(out_channel_N, out_channel_N, 2), + ) + + self.contextualDecoder_part2 = nn.Sequential( + nn.Conv2d(out_channel_N*2, out_channel_N, 3, stride=1, padding=1), + ResBlock(out_channel_N, out_channel_N, 3), + ResBlock(out_channel_N, out_channel_N, 3), + nn.Conv2d(out_channel_N, 3, 3, stride=1, padding=1), + ) + + self.priorEncoder = nn.Sequential( + nn.Conv2d(out_channel_M, out_channel_N, 3, stride=1, padding=1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), + nn.LeakyReLU(inplace=True), + nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), + ) + + self.priorDecoder = nn.Sequential( + nn.ConvTranspose2d(out_channel_N, out_channel_M, 5, + stride=2, padding=2, output_padding=1), + nn.LeakyReLU(inplace=True), + nn.ConvTranspose2d(out_channel_M, out_channel_M, 5, + stride=2, padding=2, output_padding=1), + nn.LeakyReLU(inplace=True), + nn.ConvTranspose2d(out_channel_M, out_channel_M, 3, stride=1, padding=1) + ) + + self.mvpriorEncoder = nn.Sequential( + nn.Conv2d(out_channel_mv, out_channel_N, 3, stride=1, padding=1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), + nn.LeakyReLU(inplace=True), + nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), + ) + + self.mvpriorDecoder = nn.Sequential( + nn.ConvTranspose2d(out_channel_N, out_channel_N, 5, + stride=2, padding=2, output_padding=1), + nn.LeakyReLU(inplace=True), + nn.ConvTranspose2d(out_channel_N, out_channel_N * 3 // 2, 5, + stride=2, padding=2, output_padding=1), + nn.LeakyReLU(inplace=True), + nn.ConvTranspose2d(out_channel_N * 3 // 2, out_channel_mv*2, 3, stride=1, padding=1) + ) + + self.entropy_parameters = nn.Sequential( + nn.Conv2d(out_channel_M * 12 // 3, out_channel_M * 10 // 3, 1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(out_channel_M * 10 // 3, out_channel_M * 8 // 3, 1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(out_channel_M * 8 // 3, out_channel_M * 6 // 3, 1), + ) + + self.auto_regressive = MaskedConv2d( + out_channel_M, 2 * out_channel_M, kernel_size=5, padding=2, stride=1 + ) + + self.auto_regressive_mv = MaskedConv2d( 
+ out_channel_mv, 2 * out_channel_mv, kernel_size=5, padding=2, stride=1 + ) + + self.entropy_parameters_mv = nn.Sequential( + nn.Conv2d(out_channel_mv * 12 // 3, out_channel_mv * 10 // 3, 1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(out_channel_mv * 10 // 3, out_channel_mv * 8 // 3, 1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(out_channel_mv * 8 // 3, out_channel_mv * 6 // 3, 1), + ) + + self.temporalPriorEncoder = nn.Sequential( + nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), + GDN(out_channel_N), + nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), + GDN(out_channel_N), + nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), + GDN(out_channel_N), + nn.Conv2d(out_channel_N, out_channel_M, 5, stride=2, padding=2), + ) + + self.opticFlow = ME_Spynet() + + + def motioncompensation(self, ref, mv): + ref_feature = self.feature_extract(ref) + prediction_init = flow_warp(ref_feature, mv) + context = self.context_refine(prediction_init) + + return context + + def mv_refine(self, ref, mv): + return self.mvDecoder_part2(torch.cat((mv, ref), 1)) + mv + + def quantize(self, inputs, mode, means=None): + assert(mode == "dequantize") + outputs = inputs.clone() + outputs -= means + outputs = torch.round(outputs) + outputs += means + return outputs + + def feature_probs_based_sigma(self, feature, mean, sigma): + outputs = self.quantize( + feature, "dequantize", mean + ) + values = outputs - mean + mu = torch.zeros_like(sigma) + sigma = sigma.clamp(1e-5, 1e10) + gaussian = torch.distributions.laplace.Laplace(mu, sigma) + probs = gaussian.cdf(values + 0.5) - gaussian.cdf(values - 0.5) + total_bits = torch.sum(torch.clamp(-1.0 * torch.log(probs + 1e-5) / math.log(2.0), 0, 50)) + return total_bits, probs + + def iclr18_estrate_bits_z(self, z): + prob = self.bitEstimator_z(z + 0.5) - self.bitEstimator_z(z - 0.5) + total_bits = torch.sum(torch.clamp(-1.0 * torch.log(prob + 1e-5) / math.log(2.0), 0, 50)) + return total_bits, prob + + def iclr18_estrate_bits_z_mv(self, z_mv): + prob = self.bitEstimator_z_mv(z_mv + 0.5) - self.bitEstimator_z_mv(z_mv - 0.5) + total_bits = torch.sum(torch.clamp(-1.0 * torch.log(prob + 1e-5) / math.log(2.0), 0, 50)) + return total_bits, prob + + def update(self, force=False): + self.bitEstimator_z_mv.update(force=force) + self.bitEstimator_z.update(force=force) + self.gaussian_encoder.update(force=force) + + def encode_decode(self, ref_frame, input_image, output_path): + encoded = self.encode(ref_frame, input_image, output_path) + decoded = self.decode(ref_frame, output_path) + encoded['recon_image'] = decoded + return encoded + + def encode(self, ref_frame, input_image, output_path): + from ..utils.stream_helper import encode_p + N, C, H, W = ref_frame.size() + compressed = self.compress(ref_frame, input_image) + mv_y_string = compressed['mv_y_string'] + mv_z_string = compressed['mv_z_string'] + y_string = compressed['y_string'] + z_string = compressed['z_string'] + encode_p(H, W, mv_y_string, mv_z_string, y_string, z_string, output_path) + return { + 'bpp_mv_y': compressed['bpp_mv_y'], + 'bpp_mv_z': compressed['bpp_mv_z'], + 'bpp_y': compressed['bpp_y'], + 'bpp_z': compressed['bpp_z'], + 'bpp': compressed['bpp'], + } + + def decode(self, ref_frame, input_path): + from ..utils.stream_helper import decode_p + height, width, mv_y_string, mv_z_string, y_string, z_string = decode_p(input_path) + return self.decompress(ref_frame, mv_y_string, mv_z_string, + y_string, z_string, height, width) + + def compress_ar(self, y, kernel_size, 
context_prediction, params, entropy_parameters): + kernel_size = 5 + padding = (kernel_size - 1) // 2 + + height = y.size(2) + width = y.size(3) + + y_hat = F.pad(y, (padding, padding, padding, padding)) + y_q = torch.zeros_like(y) + y_scales = torch.zeros_like(y) + + for h in range(height): + for w in range(width): + y_crop = y_hat[0:1, :, h:h + kernel_size, w:w + kernel_size] + ctx_p = F.conv2d( + y_crop, + context_prediction.weight, + bias=context_prediction.bias, + ) + + p = params[0:1, :, h:h + 1, w:w + 1] + gaussian_params = entropy_parameters(torch.cat((p, ctx_p), dim=1)) + means_hat, scales_hat = gaussian_params.chunk(2, 1) + + y_crop = y_crop[0:1, :, padding:padding+1, padding:padding+1] + y_crop_q = torch.round(y_crop - means_hat) + y_hat[0, :, h + padding, w + padding] = (y_crop_q + means_hat)[0, :, 0, 0] + y_q[0, :, h, w] = y_crop_q[0, :, 0, 0] + y_scales[0, :, h, w] = scales_hat[0, :, 0, 0] + # change to channel last + y_q = y_q.permute(0, 2, 3, 1) + y_scales = y_scales.permute(0, 2, 3, 1) + y_string = self.gaussian_encoder.compress(y_q, y_scales) + y_hat = y_hat[:, :, padding:-padding, padding:-padding] + return y_string, y_hat + + def decompress_ar(self, y_string, channel, height, width, downsample, kernel_size, + context_prediction, params, entropy_parameters): + device = next(self.parameters()).device + padding = (kernel_size - 1) // 2 + + y_size = get_downsampled_shape(height, width, downsample) + y_height = y_size[0] + y_width = y_size[1] + + y_hat = torch.zeros( + (1, channel, y_height + 2 * padding, y_width + 2 * padding), + device=params.device, + ) + + self.gaussian_encoder.set_stream(y_string) + + for h in range(y_height): + for w in range(y_width): + # only perform the 5x5 convolution on a cropped tensor + # centered in (h, w) + y_crop = y_hat[0:1, :, h:h + kernel_size, w:w + kernel_size] + ctx_p = F.conv2d( + y_crop, + context_prediction.weight, + bias=context_prediction.bias, + ) + p = params[0:1, :, h:h + 1, w:w + 1] + gaussian_params = entropy_parameters(torch.cat((p, ctx_p), dim=1)) + means_hat, scales_hat = gaussian_params.chunk(2, 1) + rv = self.gaussian_encoder.decode_stream(scales_hat) + rv = rv.to(device) + rv = rv + means_hat + y_hat[0, :, h + padding: h + padding + 1, w + padding: w + padding + 1] = rv + + y_hat = y_hat[:, :, padding:-padding, padding:-padding] + return y_hat + + def compress(self, referframe, input_image): + device = input_image.device + estmv = self.opticFlow(input_image, referframe) + mvfeature = self.mvEncoder(estmv) + z_mv = self.mvpriorEncoder(mvfeature) + compressed_z_mv = torch.round(z_mv) + mv_z_string = self.bitEstimator_z_mv.compress(compressed_z_mv) + mv_z_size = [compressed_z_mv.size(2), compressed_z_mv.size(3)] + mv_z_hat = self.bitEstimator_z_mv.decompress(mv_z_string, mv_z_size) + mv_z_hat = mv_z_hat.to(device) + + params_mv = self.mvpriorDecoder(mv_z_hat) + mv_y_string, mv_y_hat = self.compress_ar(mvfeature, 5, self.auto_regressive_mv, + params_mv, self.entropy_parameters_mv) + + quant_mv_upsample = self.mvDecoder_part1(mv_y_hat) + quant_mv_upsample_refine = self.mv_refine(referframe, quant_mv_upsample) + context = self.motioncompensation(referframe, quant_mv_upsample_refine) + + temporal_prior_params = self.temporalPriorEncoder(context) + feature = self.contextualEncoder(torch.cat((input_image, context), dim=1)) + z = self.priorEncoder(feature) + compressed_z = torch.round(z) + z_string = self.bitEstimator_z.compress(compressed_z) + z_size = [compressed_z.size(2), compressed_z.size(3)] + z_hat = 
self.bitEstimator_z.decompress(z_string, z_size) + z_hat = z_hat.to(device) + + params = self.priorDecoder(z_hat) + y_string, y_hat = self.compress_ar(feature, 5, self.auto_regressive, + torch.cat((temporal_prior_params, params), dim=1), self.entropy_parameters) + + recon_image_feature = self.contextualDecoder_part1(y_hat) + recon_image = self.contextualDecoder_part2(torch.cat((recon_image_feature, context), dim=1)) + + im_shape = input_image.size() + pixel_num = im_shape[0] * im_shape[2] * im_shape[3] + bpp_y = len(y_string) * 8 / pixel_num + bpp_z = len(z_string) * 8 / pixel_num + bpp_mv_y = len(mv_y_string) * 8 / pixel_num + bpp_mv_z = len(mv_z_string) * 8 / pixel_num + + bpp = bpp_y + bpp_z + bpp_mv_y + bpp_mv_z + + return {"bpp_mv_y": bpp_mv_y, + "bpp_mv_z": bpp_mv_z, + "bpp_y": bpp_y, + "bpp_z": bpp_z, + "bpp": bpp, + "recon_image": recon_image, + "mv_y_string": mv_y_string, + "mv_z_string": mv_z_string, + "y_string": y_string, + "z_string": z_string, + } + + def decompress(self, referframe, mv_y_string, mv_z_string, y_string, z_string, height, width): + device = next(self.parameters()).device + mv_z_size = get_downsampled_shape(height, width, 64) + mv_z_hat = self.bitEstimator_z_mv.decompress(mv_z_string, mv_z_size) + mv_z_hat = mv_z_hat.to(device) + params_mv = self.mvpriorDecoder(mv_z_hat) + mv_y_hat = self.decompress_ar(mv_y_string, self.out_channel_mv, height, width, 16, 5, + self.auto_regressive_mv, params_mv, + self.entropy_parameters_mv) + + quant_mv_upsample = self.mvDecoder_part1(mv_y_hat) + quant_mv_upsample_refine = self.mv_refine(referframe, quant_mv_upsample) + context = self.motioncompensation(referframe, quant_mv_upsample_refine) + temporal_prior_params = self.temporalPriorEncoder(context) + + z_size = get_downsampled_shape(height, width, 64) + z_hat = self.bitEstimator_z.decompress(z_string, z_size) + z_hat = z_hat.to(device) + params = self.priorDecoder(z_hat) + y_hat = self.decompress_ar(y_string, self.out_channel_M, height, width, 16, 5, + self.auto_regressive, torch.cat((temporal_prior_params, params), dim=1), + self.entropy_parameters) + recon_image_feature = self.contextualDecoder_part1(y_hat) + recon_image = self.contextualDecoder_part2(torch.cat((recon_image_feature, context) , dim=1)) + recon_image = recon_image.clamp(0, 1) + + return recon_image + + def forward(self, referframe, input_image): + estmv = self.opticFlow(input_image, referframe) + mvfeature = self.mvEncoder(estmv) + z_mv = self.mvpriorEncoder(mvfeature) + compressed_z_mv = torch.round(z_mv) + params_mv = self.mvpriorDecoder(compressed_z_mv) + + quant_mv = torch.round(mvfeature) + + ctx_params_mv = self.auto_regressive_mv(quant_mv) + gaussian_params_mv = self.entropy_parameters_mv( + torch.cat((params_mv, ctx_params_mv), dim=1) + ) + means_hat_mv, scales_hat_mv = gaussian_params_mv.chunk(2, 1) + + quant_mv_upsample = self.mvDecoder_part1(quant_mv) + + quant_mv_upsample_refine = self.mv_refine(referframe, quant_mv_upsample) + + context = self.motioncompensation(referframe, quant_mv_upsample_refine) + + temporal_prior_params = self.temporalPriorEncoder(context) + + feature = self.contextualEncoder(torch.cat((input_image, context), dim=1)) + z = self.priorEncoder(feature) + compressed_z = torch.round(z) + params = self.priorDecoder(compressed_z) + + feature_renorm = feature + + compressed_y_renorm = torch.round(feature_renorm) + + ctx_params = self.auto_regressive(compressed_y_renorm) + gaussian_params = self.entropy_parameters( + torch.cat((temporal_prior_params, params, ctx_params), dim=1) + ) + 
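+        # entropy_parameters outputs 2 * out_channel_M channels; split them
+        # into per-element means and scales for the rate estimate below.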
means_hat, scales_hat = gaussian_params.chunk(2, 1) + + recon_image_feature = self.contextualDecoder_part1(compressed_y_renorm) + recon_image = self.contextualDecoder_part2(torch.cat((recon_image_feature, context) , dim=1)) + + total_bits_y, _ = self.feature_probs_based_sigma( + feature_renorm, means_hat, scales_hat) + total_bits_mv, _ = self.feature_probs_based_sigma(mvfeature, means_hat_mv, scales_hat_mv) + total_bits_z, _ = self.iclr18_estrate_bits_z(compressed_z) + total_bits_z_mv, _ = self.iclr18_estrate_bits_z_mv(compressed_z_mv) + + im_shape = input_image.size() + pixel_num = im_shape[0] * im_shape[2] * im_shape[3] + bpp_y = total_bits_y / pixel_num + bpp_z = total_bits_z / pixel_num + bpp_mv_y = total_bits_mv / pixel_num + bpp_mv_z = total_bits_z_mv / pixel_num + + bpp = bpp_y + bpp_z + bpp_mv_y + bpp_mv_z + + return {"bpp_mv_y": bpp_mv_y, + "bpp_mv_z": bpp_mv_z, + "bpp_y": bpp_y, + "bpp_z": bpp_z, + "bpp": bpp, + "recon_image": recon_image, + "context": context, + } + + def load_dict(self, pretrained_dict): + result_dict = {} + for key, weight in pretrained_dict.items(): + result_key = key + if key[:7] == "module.": + result_key = key[7:] + result_dict[result_key] = weight + + self.load_state_dict(result_dict) \ No newline at end of file diff --git a/src/models/priors.py b/src/models/priors.py new file mode 100644 index 0000000..cf837b5 --- /dev/null +++ b/src/models/priors.py @@ -0,0 +1,718 @@ +# Copyright 2020 InterDigital Communications, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# pylint: disable=E0611,E0401 +from ..entropy_models.entropy_models import EntropyBottleneck, GaussianConditional +from ..layers.layers import GDN, MaskedConv2d + +from .utils import conv, deconv, update_registered_buffers + +# pylint: enable=E0611,E0401 + + +__all__ = [ + "CompressionModel", + "FactorizedPrior", + "ScaleHyperprior", + "MeanScaleHyperprior", + "JointAutoregressiveHierarchicalPriors", +] + + +class CompressionModel(nn.Module): + """Base class for constructing an auto-encoder with at least one entropy + bottleneck module. + + Args: + entropy_bottleneck_channels (int): Number of channels of the entropy + bottleneck + """ + + def __init__(self, entropy_bottleneck_channels, init_weights=True): + super().__init__() + self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels) + + if init_weights: + self._initialize_weights() + + def aux_loss(self): + """Return the aggregated loss over the auxiliary entropy bottleneck + module(s). 
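+
+        This loss is intended to be minimized by an optimizer separate from
+        the one driving the main rate-distortion objective; ``parameters()``
+        and ``aux_parameters()`` below expose the two parameter groups. A
+        minimal sketch (``net`` is a ``CompressionModel`` instance; optimizer
+        choices and learning rates are illustrative):
+
+            optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
+            aux_optimizer = torch.optim.Adam(net.aux_parameters(), lr=1e-3)
+            # after stepping the main loss:
+            net.aux_loss().backward()
+            aux_optimizer.step()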
+ """ + aux_loss = sum( + m.loss() for m in self.modules() if isinstance(m, EntropyBottleneck) + ) + return aux_loss + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_normal_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, *args): + raise NotImplementedError() + + def parameters(self): + """Returns an iterator over the model parameters.""" + for m in self.children(): + if isinstance(m, EntropyBottleneck): + continue + for p in m.parameters(): + yield p + + def aux_parameters(self): + """ + Returns an iterator over the entropy bottleneck(s) parameters for + the auxiliary loss. + """ + for m in self.children(): + if not isinstance(m, EntropyBottleneck): + continue + for p in m.parameters(): + yield p + + def update(self, force=False): + """Updates the entropy bottleneck(s) CDF values. + + Needs to be called once after training to be able to later perform the + evaluation with an actual entropy coder. + + Args: + force (bool): overwrite previous values (default: False) + + """ + for m in self.children(): + if not isinstance(m, EntropyBottleneck): + continue + m.update(force=force) + + +class FactorizedPrior(CompressionModel): + r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang, + N. Johnston: `"Variational Image Compression with a Scale Hyperprior" + `_, Int Conf. on Learning Representations + (ICLR), 2018. + + Args: + N (int): Number of channels + M (int): Number of channels in the expansion layers (last layer of the + encoder and last layer of the hyperprior decoder) + """ + + def __init__(self, N, M, **kwargs): + super().__init__(entropy_bottleneck_channels=M, **kwargs) + + self.g_a = nn.Sequential( + conv(3, N), + GDN(N), + conv(N, N), + GDN(N), + conv(N, N), + GDN(N), + conv(N, M), + ) + + self.g_s = nn.Sequential( + deconv(M, N), + GDN(N, inverse=True), + deconv(N, N), + GDN(N, inverse=True), + deconv(N, N), + GDN(N, inverse=True), + deconv(N, 3), + ) + + def forward(self, x): + y = self.g_a(x) + y_hat, y_likelihoods = self.entropy_bottleneck(y) + x_hat = self.g_s(y_hat) + + return { + "x_hat": x_hat, + "likelihoods": { + "y": y_likelihoods, + }, + } + + def load_state_dict(self, state_dict): + # Dynamically update the entropy bottleneck buffers related to the CDFs + update_registered_buffers( + self.entropy_bottleneck, + "entropy_bottleneck", + ["_quantized_cdf", "_offset", "_cdf_length"], + state_dict, + ) + super().load_state_dict(state_dict) + + @classmethod + def from_state_dict(cls, state_dict): + """Return a new model instance from `state_dict`.""" + N = state_dict["g_a.0.weight"].size(0) + M = state_dict["g_a.6.weight"].size(0) + net = cls(N, M) + net.load_state_dict(state_dict) + return net + + def compress(self, x): + y = self.g_a(x) + y_strings = self.entropy_bottleneck.compress(y) + return {"strings": [y_strings], "shape": y.size()[-2:]} + + def decompress(self, strings, shape): + assert isinstance(strings, list) and len(strings) == 1 + y_hat = self.entropy_bottleneck.decompress(strings[0], shape) + x_hat = self.g_s(y_hat) + return {"x_hat": x_hat} + + +# From Balle's tensorflow compression examples +SCALES_MIN = 0.11 +SCALES_MAX = 256 +SCALES_LEVELS = 64 + + +def get_scale_table( + min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS +): # pylint: disable=W0622 + return torch.exp(torch.linspace(math.log(min), math.log(max), levels)) + + +class ScaleHyperprior(CompressionModel): + r"""Scale Hyperprior model from J. Balle, D. Minnen, S. 
Singh, S.J. Hwang, + N. Johnston: `"Variational Image Compression with a Scale Hyperprior" + `_ Int. Conf. on Learning Representations + (ICLR), 2018. + + Args: + N (int): Number of channels + M (int): Number of channels in the expansion layers (last layer of the + encoder and last layer of the hyperprior decoder) + """ + + def __init__(self, N, M, **kwargs): + super().__init__(entropy_bottleneck_channels=N, **kwargs) + + self.g_a = nn.Sequential( + conv(3, N), + GDN(N), + conv(N, N), + GDN(N), + conv(N, N), + GDN(N), + conv(N, M), + ) + + self.g_s = nn.Sequential( + deconv(M, N), + GDN(N, inverse=True), + deconv(N, N), + GDN(N, inverse=True), + deconv(N, N), + GDN(N, inverse=True), + deconv(N, 3), + ) + + self.h_a = nn.Sequential( + conv(M, N, stride=1, kernel_size=3), + nn.ReLU(inplace=True), + conv(N, N), + nn.ReLU(inplace=True), + conv(N, N), + ) + + self.h_s = nn.Sequential( + deconv(N, N), + nn.ReLU(inplace=True), + deconv(N, N), + nn.ReLU(inplace=True), + conv(N, M, stride=1, kernel_size=3), + nn.ReLU(inplace=True), + ) + + self.gaussian_conditional = GaussianConditional(None) + self.N = int(N) + self.M = int(M) + + def forward(self, x): + y = self.g_a(x) + z = self.h_a(torch.abs(y)) + z_hat, z_likelihoods = self.entropy_bottleneck(z) + scales_hat = self.h_s(z_hat) + y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat) + x_hat = self.g_s(y_hat) + + return { + "x_hat": x_hat, + "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, + } + + def load_state_dict(self, state_dict): + # Dynamically update the entropy bottleneck buffers related to the CDFs + update_registered_buffers( + self.entropy_bottleneck, + "entropy_bottleneck", + ["_quantized_cdf", "_offset", "_cdf_length"], + state_dict, + ) + update_registered_buffers( + self.gaussian_conditional, + "gaussian_conditional", + ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"], + state_dict, + ) + super().load_state_dict(state_dict) + + @classmethod + def from_state_dict(cls, state_dict): + """Return a new model instance from `state_dict`.""" + N = state_dict["g_a.0.weight"].size(0) + M = state_dict["g_a.6.weight"].size(0) + net = cls(N, M) + net.load_state_dict(state_dict) + return net + + def update(self, scale_table=None, force=False): + if scale_table is None: + scale_table = get_scale_table() + self.gaussian_conditional.update_scale_table(scale_table, force=force) + super().update(force=force) + + def encode_decode(self, x, output_path): + N, C, H, W = x.size() + bits = self.encode(x, output_path) * 8 + bpp = bits / (H * W) + x_hat = self.decode(output_path) + result = { + 'bpp': bpp, + 'x_hat': x_hat, + } + return result + + def encode(self, x, output_path): + from ..utils.stream_helper import encode_i + N, C, H, W = x.size() + compressed = self.compress(x) + y_string = compressed['strings'][0][0] + z_string = compressed['strings'][1][0] + encode_i(H, W, y_string, z_string, output_path) + return len(y_string) + len(z_string) + + def decode(self, input_path): + from ..utils.stream_helper import decode_i, get_downsampled_shape + height, width, y_string, z_string = decode_i(input_path) + shape = get_downsampled_shape(height, width, 64) + decompressed = self.decompress([[y_string], [z_string]], shape) + return decompressed['x_hat'] + + def compress(self, x): + y = self.g_a(x) + z = self.h_a(torch.abs(y)) + + z_strings = self.entropy_bottleneck.compress(z) + z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) + + scales_hat = self.h_s(z_hat) + indexes = 
self.gaussian_conditional.build_indexes(scales_hat) + y_strings = self.gaussian_conditional.compress(y, indexes) + return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]} + + def decompress(self, strings, shape): + assert isinstance(strings, list) and len(strings) == 2 + z_hat = self.entropy_bottleneck.decompress(strings[1], shape) + scales_hat = self.h_s(z_hat) + indexes = self.gaussian_conditional.build_indexes(scales_hat) + y_hat = self.gaussian_conditional.decompress(strings[0], indexes) + y_hat = y_hat.to(z_hat.device) + x_hat = self.g_s(y_hat).clamp_(0, 1) + return {"x_hat": x_hat} + + +class MeanScaleHyperprior(ScaleHyperprior): + r"""Scale Hyperprior with non zero-mean Gaussian conditionals from D. + Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical + Priors for Learned Image Compression" `_, + Adv. in Neural Information Processing Systems 31 (NeurIPS 2018). + + Args: + N (int): Number of channels + M (int): Number of channels in the expansion layers (last layer of the + encoder and last layer of the hyperprior decoder) + """ + + def __init__(self, N, M, **kwargs): + super().__init__(N, M, **kwargs) + + self.h_a = nn.Sequential( + conv(M, N, stride=1, kernel_size=3), + nn.LeakyReLU(inplace=True), + conv(N, N), + nn.LeakyReLU(inplace=True), + conv(N, N), + ) + + self.h_s = nn.Sequential( + deconv(N, M), + nn.LeakyReLU(inplace=True), + deconv(M, M * 3 // 2), + nn.LeakyReLU(inplace=True), + conv(M * 3 // 2, M * 2, stride=1, kernel_size=3), + ) + + def forward(self, x): + y = self.g_a(x) + z = self.h_a(y) + z_hat, z_likelihoods = self.entropy_bottleneck(z) + gaussian_params = self.h_s(z_hat) + scales_hat, means_hat = gaussian_params.chunk(2, 1) + y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat) + x_hat = self.g_s(y_hat) + + return { + "x_hat": x_hat, + "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, + } + + def compress(self, x): + y = self.g_a(x) + z = self.h_a(y) + + z_strings = self.entropy_bottleneck.compress(z) + z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) + + gaussian_params = self.h_s(z_hat) + scales_hat, means_hat = gaussian_params.chunk(2, 1) + indexes = self.gaussian_conditional.build_indexes(scales_hat) + y_strings = self.gaussian_conditional.compress(y, indexes, means=means_hat) + return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]} + + def decompress(self, strings, shape): + assert isinstance(strings, list) and len(strings) == 2 + z_hat = self.entropy_bottleneck.decompress(strings[1], shape) + gaussian_params = self.h_s(z_hat) + scales_hat, means_hat = gaussian_params.chunk(2, 1) + indexes = self.gaussian_conditional.build_indexes(scales_hat) + y_hat = self.gaussian_conditional.decompress( + strings[0], indexes, means=means_hat + ) + x_hat = self.g_s(y_hat).clamp_(0, 1) + return {"x_hat": x_hat} + + +class JointAutoregressiveHierarchicalPriors(CompressionModel): + r"""Joint Autoregressive Hierarchical Priors model from D. + Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical + Priors for Learned Image Compression" `_, + Adv. in Neural Information Processing Systems 31 (NeurIPS 2018). 
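+
+    Decoding is inherently sequential: each latent element is decoded
+    conditioned on a masked-convolution context over previously decoded
+    elements (see ``decompress``), so it cannot be parallelized across
+    spatial positions.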
+ + Args: + N (int): Number of channels + M (int): Number of channels in the expansion layers (last layer of the + encoder and last layer of the hyperprior decoder) + """ + + def __init__(self, N=192, M=192, **kwargs): + super().__init__(entropy_bottleneck_channels=N, **kwargs) + + self.g_a = nn.Sequential( + conv(3, N, kernel_size=5, stride=2), + GDN(N), + conv(N, N, kernel_size=5, stride=2), + GDN(N), + conv(N, N, kernel_size=5, stride=2), + GDN(N), + conv(N, M, kernel_size=5, stride=2), + ) + + self.g_s = nn.Sequential( + deconv(M, N, kernel_size=5, stride=2), + GDN(N, inverse=True), + deconv(N, N, kernel_size=5, stride=2), + GDN(N, inverse=True), + deconv(N, N, kernel_size=5, stride=2), + GDN(N, inverse=True), + deconv(N, 3, kernel_size=5, stride=2), + ) + + self.h_a = nn.Sequential( + conv(M, N, stride=1, kernel_size=3), + nn.LeakyReLU(inplace=True), + conv(N, N, stride=2, kernel_size=5), + nn.LeakyReLU(inplace=True), + conv(N, N, stride=2, kernel_size=5), + ) + + self.h_s = nn.Sequential( + deconv(N, M, stride=2, kernel_size=5), + nn.LeakyReLU(inplace=True), + deconv(M, M * 3 // 2, stride=2, kernel_size=5), + nn.LeakyReLU(inplace=True), + conv(M * 3 // 2, M * 2, stride=1, kernel_size=3), + ) + + self.entropy_parameters = nn.Sequential( + nn.Conv2d(M * 12 // 3, M * 10 // 3, 1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(M * 10 // 3, M * 8 // 3, 1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(M * 8 // 3, M * 6 // 3, 1), + ) + + self.context_prediction = MaskedConv2d( + M, 2 * M, kernel_size=5, padding=2, stride=1 + ) + + self.gaussian_conditional = GaussianConditional(None) + self.N = int(N) + self.M = int(M) + + def forward(self, x): + y = self.g_a(x) + z = self.h_a(y) + z_hat, z_likelihoods = self.entropy_bottleneck(z) + params = self.h_s(z_hat) + + y_hat = self.gaussian_conditional._quantize( # pylint: disable=protected-access + y, "noise" if self.training else "dequantize" + ) + ctx_params = self.context_prediction(y_hat) + gaussian_params = self.entropy_parameters( + torch.cat((params, ctx_params), dim=1) + ) + scales_hat, means_hat = gaussian_params.chunk(2, 1) + _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat) + x_hat = self.g_s(y_hat) + + return { + "x_hat": x_hat, + "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, + } + + @classmethod + def from_state_dict(cls, state_dict): + """Return a new model instance from `state_dict`.""" + N = state_dict["g_a.0.weight"].size(0) + M = state_dict["g_a.6.weight"].size(0) + net = cls(N, M) + net.load_state_dict(state_dict) + return net + + def encode_decode(self, x, output_path): + N, C, H, W = x.size() + bits = self.encode(x, output_path) * 8 + bpp = bits / (H * W) + x_hat = self.decode(output_path) + result = { + 'bpp': bpp, + 'x_hat': x_hat, + } + return result + + def encode(self, x, output_path): + from ..utils.stream_helper import encode_i + N, C, H, W = x.size() + compressed = self.compress(x) + y_string = compressed['strings'][0][0] + z_string = compressed['strings'][1][0] + encode_i(H, W, y_string, z_string, output_path) + return len(y_string) + len(z_string) + + def decode(self, input_path): + from ..utils.stream_helper import decode_i, get_downsampled_shape + height, width, y_string, z_string = decode_i(input_path) + shape = get_downsampled_shape(height, width, 64) + decompressed = self.decompress([[y_string], [z_string]], shape) + return decompressed['x_hat'] + + def compress(self, x): + from ..entropy_models.MLCodec_rans import BufferedRansEncoder + if next(self.parameters()).device != 
torch.device("cpu"):
+            warnings.warn(
+                "Inference on GPU is not recommended for the autoregressive "
+                "models (the entropy coder is run sequentially on CPU)."
+            )
+
+        y = self.g_a(x)
+        z = self.h_a(y)
+
+        z_strings = self.entropy_bottleneck.compress(z)
+        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
+
+        params = self.h_s(z_hat)
+
+        s = 4  # scaling factor between z and y
+        kernel_size = 5  # context prediction kernel size
+        padding = (kernel_size - 1) // 2
+
+        y_height = z_hat.size(2) * s
+        y_width = z_hat.size(3) * s
+
+        y_hat = F.pad(y, (padding, padding, padding, padding))
+
+        # pylint: disable=protected-access
+        cdf = self.gaussian_conditional._quantized_cdf.tolist()
+        cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
+        offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
+        # pylint: enable=protected-access
+
+        y_strings = []
+        for i in range(y.size(0)):
+            encoder = BufferedRansEncoder()
+            # Warning, this is slow...
+            # TODO: profile the calls to the bindings...
+            symbols_list = []
+            indexes_list = []
+            for h in range(y_height):
+                for w in range(y_width):
+                    y_crop = y_hat[i:i + 1, :, h:h + kernel_size, w:w + kernel_size]
+                    ctx_p = F.conv2d(
+                        y_crop,
+                        self.context_prediction.weight,
+                        bias=self.context_prediction.bias,
+                    )
+
+                    # 1x1 conv for the entropy parameters prediction network, so
+                    # we only keep the elements in the "center"
+                    p = params[i:i + 1, :, h:h + 1, w:w + 1]
+                    gaussian_params = self.entropy_parameters(
+                        torch.cat((p, ctx_p), dim=1)
+                    )
+                    scales_hat, means_hat = gaussian_params.chunk(2, 1)
+
+                    indexes = self.gaussian_conditional.build_indexes(scales_hat)
+                    y_q = torch.round(y_crop - means_hat)
+                    # y_crop and everything derived from it have a batch
+                    # dimension of 1, so index 0 selects the current sample.
+                    y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[
+                        0, :, padding, padding
+                    ]
+
+                    symbols_list.extend(y_q[0, :, padding, padding].int().tolist())
+                    indexes_list.extend(indexes[0, :].squeeze().int().tolist())
+
+            encoder.encode_with_indexes(
+                symbols_list, indexes_list, cdf, cdf_lengths, offsets
+            )
+
+            string = encoder.flush()
+            y_strings.append(string)
+
+        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}
+
+    def decompress(self, strings, shape):
+        from ..entropy_models.MLCodec_rans import RansDecoder
+        assert isinstance(strings, list) and len(strings) == 2
+
+        if next(self.parameters()).device != torch.device("cpu"):
+            warnings.warn(
+                "Inference on GPU is not recommended for the autoregressive "
+                "models (the entropy coder is run sequentially on CPU)."
+            )
+
+        # FIXME: we don't respect the default entropy coder and directly call the
+        # range ANS decoder
+
+        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
+        params = self.h_s(z_hat)
+
+        s = 4  # scaling factor between z and y
+        kernel_size = 5  # context prediction kernel size
+        padding = (kernel_size - 1) // 2
+
+        y_height = z_hat.size(2) * s
+        y_width = z_hat.size(3) * s
+
+        # initialize y_hat to zeros, and pad it so we can directly work with
+        # sub-tensors of size (N, C, kernel size, kernel_size)
+        y_hat = torch.zeros(
+            (z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding),
+            device=z_hat.device,
+        )
+        decoder = RansDecoder()
+
+        # pylint: disable=protected-access
+        cdf = self.gaussian_conditional._quantized_cdf.tolist()
+        cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
+        offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
+
+        # Warning: this is slow due to the auto-regressive nature of the
+        # decoding... See more recent publications, where an auto-regressive
+        # module is run over chunks of channels for faster decoding.
+        for i, y_string in enumerate(strings[0]):
+            decoder.set_stream(y_string)
+
+            for h in range(y_height):
+                for w in range(y_width):
+                    # only perform the 5x5 convolution on a cropped tensor
+                    # centered in (h, w)
+                    y_crop = y_hat[i:i + 1, :, h:h + kernel_size, w:w + kernel_size]
+                    ctx_p = F.conv2d(
+                        y_crop,
+                        self.context_prediction.weight,
+                        bias=self.context_prediction.bias,
+                    )
+                    # 1x1 conv for the entropy parameters prediction network, so
+                    # we only keep the elements in the "center"
+                    p = params[i:i + 1, :, h:h + 1, w:w + 1]
+                    gaussian_params = self.entropy_parameters(
+                        torch.cat((p, ctx_p), dim=1)
+                    )
+                    scales_hat, means_hat = gaussian_params.chunk(2, 1)
+
+                    indexes = self.gaussian_conditional.build_indexes(scales_hat)
+
+                    # indexes is derived from the i:i+1 slice and has a batch
+                    # dimension of 1, so index 0 selects the current sample.
+                    rv = decoder.decode_stream(
+                        indexes[0, :].squeeze().int().tolist(),
+                        cdf,
+                        cdf_lengths,
+                        offsets,
+                    )
+                    rv = torch.Tensor(rv).reshape(1, -1, 1, 1)
+
+                    rv = self.gaussian_conditional._dequantize(rv, means_hat)
+
+                    y_hat[i, :, h + padding: h + padding + 1, w + padding: w + padding + 1] = rv
+        y_hat = y_hat[:, :, padding:-padding, padding:-padding]
+        # pylint: enable=protected-access
+
+        x_hat = self.g_s(y_hat).clamp_(0, 1)
+        return {"x_hat": x_hat}
+
+    def update(self, scale_table=None, force=False):
+        if scale_table is None:
+            scale_table = get_scale_table()
+        self.gaussian_conditional.update_scale_table(scale_table, force=force)
+        super().update(force=force)
+
+    def load_state_dict(self, state_dict):
+        # Dynamically update the entropy bottleneck buffers related to the CDFs
+        update_registered_buffers(
+            self.entropy_bottleneck,
+            "entropy_bottleneck",
+            ["_quantized_cdf", "_offset", "_cdf_length"],
+            state_dict,
+        )
+        update_registered_buffers(
+            self.gaussian_conditional,
+            "gaussian_conditional",
+            ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
+            state_dict,
+        )
+        super().load_state_dict(state_dict)
diff --git a/src/models/utils.py b/src/models/utils.py
new file mode 100644
index 0000000..eedaffb
--- /dev/null
+++ b/src/models/utils.py
@@ -0,0 +1,130 @@
+# Copyright 2020 InterDigital Communications, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+
+
+def find_named_module(module, query):
+    """Helper function to find a named module. Returns a `nn.Module` or `None`
+
+    Args:
+        module (nn.Module): the root module
+        query (str): the module name to find
+
+    Returns:
+        nn.Module or None
+    """
+
+    return next((m for n, m in module.named_modules() if n == query), None)
+
+
+def find_named_buffer(module, query):
+    """Helper function to find a named buffer. Returns a `torch.Tensor` or `None`
+
+    Args:
+        module (nn.Module): the root module
+        query (str): the buffer name to find
+
+    Returns:
+        torch.Tensor or None
+    """
+    return next((b for n, b in module.named_buffers() if n == query), None)
+
+
+def _update_registered_buffer(
+    module,
+    buffer_name,
+    state_dict_key,
+    state_dict,
+    policy="resize_if_empty",
+    dtype=torch.int,
+):
+    new_size = state_dict[state_dict_key].size()
+    registered_buf = find_named_buffer(module, buffer_name)
+
+    if policy in ("resize_if_empty", "resize"):
+        if registered_buf is None:
+            raise RuntimeError(f'buffer "{buffer_name}" was not registered')
+
+        if policy == "resize" or registered_buf.numel() == 0:
+            registered_buf.resize_(new_size)
+
+    elif policy == "register":
+        if registered_buf is not None:
+            raise RuntimeError(f'buffer "{buffer_name}" was already registered')
+
+        module.register_buffer(buffer_name, torch.empty(new_size, dtype=dtype).fill_(0))
+
+    else:
+        raise ValueError(f'Invalid policy "{policy}"')
+
+
+def update_registered_buffers(
+    module,
+    module_name,
+    buffer_names,
+    state_dict,
+    policy="resize_if_empty",
+    dtype=torch.int,
+):
+    """Update the registered buffers in a module according to the tensor sizes
+    in a state_dict.
+
+    (There's no way in torch to directly load a buffer with a dynamic size)
+
+    Args:
+        module (nn.Module): the module
+        module_name (str): module name in the state dict
+        buffer_names (list(str)): list of the buffer names to resize in the module
+        state_dict (dict): the state dict
+        policy (str): Update policy, choose from
+            ('resize_if_empty', 'resize', 'register')
+        dtype (dtype): Type of buffer to be registered (when policy is 'register')
+    """
+    valid_buffer_names = [n for n, _ in module.named_buffers()]
+    for buffer_name in buffer_names:
+        if buffer_name not in valid_buffer_names:
+            raise ValueError(f'Invalid buffer name "{buffer_name}"')
+
+    for buffer_name in buffer_names:
+        _update_registered_buffer(
+            module,
+            buffer_name,
+            f"{module_name}.{buffer_name}",
+            state_dict,
+            policy,
+            dtype,
+        )
+
+
+def conv(in_channels, out_channels, kernel_size=5, stride=2):
+    return nn.Conv2d(
+        in_channels,
+        out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=kernel_size // 2,
+    )
+
+
+def deconv(in_channels, out_channels, kernel_size=5, stride=2):
+    return nn.ConvTranspose2d(
+        in_channels,
+        out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        output_padding=stride - 1,
+        padding=kernel_size // 2,
+    )
diff --git a/src/models/video_net.py b/src/models/video_net.py
new file mode 100644
index 0000000..4392d8e
--- /dev/null
+++ b/src/models/video_net.py
@@ -0,0 +1,260 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+
+
+# One flow-grid cache per CUDA device (up to 8) plus a separate CPU cache;
+# the grids are keyed by the flow tensor size (see torch_warp below).
+Backward_tensorGrid = [{} for i in range(8)]
+Backward_tensorGrid_cpu = {}
+
+
+class LowerBound(Function):
+    @staticmethod
+    def forward(ctx, inputs, bound):
+        b = torch.ones_like(inputs) * bound
+        ctx.save_for_backward(inputs, b)
+        return torch.max(inputs, b)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        inputs, b = ctx.saved_tensors
+        pass_through_1 = inputs >= b
+        pass_through_2 = grad_output < 0
+
+        pass_through = pass_through_1 | pass_through_2
+        return pass_through.type(grad_output.dtype) * grad_output, None
+
+
+class GDN(nn.Module):
+    def __init__(self,
+                 ch,
+                 inverse=False,
+                 beta_min=1e-6,
+                 gamma_init=0.1,
+                 reparam_offset=2**-18):
+        super(GDN, self).__init__()
+        self.inverse = inverse
+        self.beta_min = beta_min
+        self.gamma_init = gamma_init
+        self.reparam_offset = reparam_offset
+
+        self.build(ch)
+
+    def build(self, ch):
+        self.pedestal = self.reparam_offset**2
+        self.beta_bound = ((self.beta_min + self.reparam_offset**2)**0.5)
+        self.gamma_bound = self.reparam_offset
+
+        beta = torch.sqrt(torch.ones(ch)+self.pedestal)
+        self.beta = nn.Parameter(beta)
+
+        eye = torch.eye(ch)
+        g = self.gamma_init*eye
+        g = g + self.pedestal
+        gamma = torch.sqrt(g)
+
+        self.gamma = nn.Parameter(gamma)
+
+    def forward(self, inputs):
+        unfold = False
+        if inputs.dim() == 5:
+            unfold = True
+            bs, ch, d, w, h = inputs.size()
+            inputs = inputs.view(bs, ch, d*w, h)
+
+        _, ch, _, _ = inputs.size()
+
+        # Beta bound and reparam
+        beta = LowerBound.apply(self.beta, self.beta_bound)
+        beta = beta**2 - self.pedestal
+
+        # Gamma bound and reparam
+        gamma = LowerBound.apply(self.gamma, self.gamma_bound)
+        gamma = gamma**2 - self.pedestal
+        gamma = gamma.view(ch, ch, 1, 1)
+
+        # Norm pool calc
+        norm_ = nn.functional.conv2d(inputs**2, gamma, beta)
+        norm_ = torch.sqrt(norm_)
+
+        # Apply norm
+        if self.inverse:
+            outputs = inputs * norm_
+        else:
+            outputs = inputs / norm_
+
+        if unfold:
+            outputs = outputs.view(bs, ch, d, w, h)
+        return outputs
+
+
+def torch_warp(tensorInput, tensorFlow):
+    if tensorInput.device == torch.device('cpu'):
+        if str(tensorFlow.size()) not in Backward_tensorGrid_cpu:
+            tensorHorizontal = torch.linspace(-1.0, 1.0, tensorFlow.size(3)).view(
+                1, 1, 1, tensorFlow.size(3)).expand(tensorFlow.size(0), -1, tensorFlow.size(2), -1)
+            tensorVertical = torch.linspace(-1.0, 1.0, tensorFlow.size(2)).view(
+                1, 1, tensorFlow.size(2), 1).expand(tensorFlow.size(0), -1, -1, tensorFlow.size(3))
+            Backward_tensorGrid_cpu[str(tensorFlow.size())] = torch.cat(
+                [tensorHorizontal, tensorVertical], 1).cpu()
+
+        tensorFlow = torch.cat([tensorFlow[:, 0:1, :, :] / ((tensorInput.size(3) - 1.0) / 2.0),
+                                tensorFlow[:, 1:2, :, :] / ((tensorInput.size(2) - 1.0) / 2.0)], 1)
+
+        grid = (Backward_tensorGrid_cpu[str(tensorFlow.size())] + tensorFlow)
+        return torch.nn.functional.grid_sample(input=tensorInput,
+                                               grid=grid.permute(0, 2, 3, 1),
+                                               mode='bilinear',
+                                               padding_mode='border',
+                                               align_corners=True)
+    else:
+        device_id = tensorInput.device.index
+        if str(tensorFlow.size()) not in Backward_tensorGrid[device_id]:
+            tensorHorizontal = torch.linspace(-1.0, 1.0, tensorFlow.size(3)).view(
+                1, 1, 1, tensorFlow.size(3)).expand(tensorFlow.size(0), -1, tensorFlow.size(2), -1)
+            tensorVertical = torch.linspace(-1.0, 1.0, tensorFlow.size(2)).view(
+                1, 1, tensorFlow.size(2), 1).expand(tensorFlow.size(0), -1, -1, tensorFlow.size(3))
+            Backward_tensorGrid[device_id][str(tensorFlow.size())] = torch.cat(
+                [tensorHorizontal, tensorVertical], 1).cuda().to(device_id)
+
+        tensorFlow = torch.cat([tensorFlow[:, 0:1, :, :] / ((tensorInput.size(3) - 1.0) / 2.0),
+                                tensorFlow[:, 1:2, :, :] / ((tensorInput.size(2) - 1.0) / 2.0)], 1)
+
+        grid = (Backward_tensorGrid[device_id][str(tensorFlow.size())] + tensorFlow)
+        return torch.nn.functional.grid_sample(input=tensorInput,
+                                               grid=grid.permute(0, 2, 3, 1),
+                                               mode='bilinear',
+                                               padding_mode='border',
+                                               align_corners=True)
+
+
+def flow_warp(im, flow):
+    warp = torch_warp(im, flow)
+    return warp
+
+
+def load_weight_form_np(me_model_dir, layername):
+    index = layername.find('modelL')
+    if index == -1:
+        raise ValueError(f'cannot find "modelL" in layer name: {layername}')
+    else:
+        name = layername[index:index + 11]
+        modelweight = me_model_dir + name + '-weight.npy'
+        modelbias = me_model_dir + name + '-bias.npy'
+        weightnp = 
np.load(modelweight) + biasnp = np.load(modelbias) + return torch.from_numpy(weightnp), torch.from_numpy(biasnp) + + +def bilinearupsacling(inputfeature): + inputheight = inputfeature.size()[2] + inputwidth = inputfeature.size()[3] + outfeature = F.interpolate( + inputfeature, (inputheight * 2, inputwidth * 2), mode='bilinear', align_corners=False) + return outfeature + + +class ResBlock(nn.Module): + def __init__(self, inputchannel, outputchannel, kernel_size, stride=1): + super(ResBlock, self).__init__() + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(inputchannel, outputchannel, + kernel_size, stride, padding=kernel_size//2) + torch.nn.init.xavier_uniform_(self.conv1.weight.data) + torch.nn.init.constant_(self.conv1.bias.data, 0.0) + self.relu2 = nn.ReLU() + self.conv2 = nn.Conv2d(outputchannel, outputchannel, + kernel_size, stride, padding=kernel_size//2) + torch.nn.init.xavier_uniform_(self.conv2.weight.data) + torch.nn.init.constant_(self.conv2.bias.data, 0.0) + if inputchannel != outputchannel: + self.adapt_conv = nn.Conv2d(inputchannel, outputchannel, 1) + torch.nn.init.xavier_uniform_(self.adapt_conv.weight.data) + torch.nn.init.constant_(self.adapt_conv.bias.data, 0.0) + else: + self.adapt_conv = None + + def forward(self, x): + x_1 = self.relu1(x) + firstlayer = self.conv1(x_1) + firstlayer = self.relu2(firstlayer) + seclayer = self.conv2(firstlayer) + if self.adapt_conv is None: + return x + seclayer + else: + return self.adapt_conv(x) + seclayer + + +class ResBlock_LeakyReLU_0_Point_1(nn.Module): + def __init__(self, d_model): + super(ResBlock_LeakyReLU_0_Point_1, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(d_model, d_model, 3, stride=1, padding=1), + nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(d_model, d_model, 3, stride=1, padding=1), + nn.LeakyReLU(0.1, inplace=True)) + + def forward(self, x): + x = x+self.conv(x) + return x + + +class MEBasic(nn.Module): + def __init__(self): + super(MEBasic, self).__init__() + self.conv1 = nn.Conv2d(8, 32, 7, 1, padding=3) + self.relu1 = nn.ReLU() + self.conv2 = nn.Conv2d(32, 64, 7, 1, padding=3) + self.relu2 = nn.ReLU() + self.conv3 = nn.Conv2d(64, 32, 7, 1, padding=3) + self.relu3 = nn.ReLU() + self.conv4 = nn.Conv2d(32, 16, 7, 1, padding=3) + self.relu4 = nn.ReLU() + self.conv5 = nn.Conv2d(16, 2, 7, 1, padding=3) + + + def forward(self, x): + x = self.relu1(self.conv1(x)) + x = self.relu2(self.conv2(x)) + x = self.relu3(self.conv3(x)) + x = self.relu4(self.conv4(x)) + x = self.conv5(x) + return x + + +class ME_Spynet(nn.Module): + def __init__(self): + super(ME_Spynet, self).__init__() + self.L = 4 + self.moduleBasic = torch.nn.ModuleList( + [MEBasic() for intLevel in range(4)]) + + def forward(self, im1, im2): + batchsize = im1.size()[0] + im1_pre = im1 + im2_pre = im2 + + im1list = [im1_pre] + im2list = [im2_pre] + for intLevel in range(self.L - 1): + im1list.append(F.avg_pool2d( + im1list[intLevel], kernel_size=2, stride=2)) + im2list.append(F.avg_pool2d( + im2list[intLevel], kernel_size=2, stride=2)) + + shape_fine = im2list[self.L - 1].size() + zeroshape = [batchsize, 2, shape_fine[2] // 2, shape_fine[3] // 2] + device = im1.device + flowfileds = torch.zeros( + zeroshape, dtype=torch.float32, device=device) + for intLevel in range(self.L): + flowfiledsUpsample = bilinearupsacling(flowfileds) * 2.0 + flowfileds = flowfiledsUpsample + \ + self.moduleBasic[intLevel](torch.cat([im1list[self.L - 1 - intLevel], + flow_warp(im2list[self.L - 1 - intLevel], + flowfiledsUpsample), + flowfiledsUpsample], 1)) + + return 
flowfileds
diff --git a/src/models/waseda.py b/src/models/waseda.py
new file mode 100644
index 0000000..bfadee6
--- /dev/null
+++ b/src/models/waseda.py
@@ -0,0 +1,95 @@
+# Copyright 2020 InterDigital Communications, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch.nn as nn
+
+from ..layers.layers import (
+    ResidualBlock,
+    ResidualBlockUpsample,
+    ResidualBlockWithStride,
+    conv3x3,
+    subpel_conv3x3,
+)
+
+from .priors import JointAutoregressiveHierarchicalPriors
+
+
+class Cheng2020Anchor(JointAutoregressiveHierarchicalPriors):
+    """Anchor model variant from `"Learned Image Compression with
+    Discretized Gaussian Mixture Likelihoods and Attention Modules"
+    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
+    Takeuchi, Jiro Katto.
+
+    Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel
+    convolutions for up-sampling.
+
+    Args:
+        N (int): Number of channels
+    """
+
+    def __init__(self, N=192, **kwargs):
+        super().__init__(N=N, M=N, **kwargs)
+
+        self.g_a = nn.Sequential(
+            ResidualBlockWithStride(3, N, stride=2),
+            ResidualBlock(N, N),
+            ResidualBlockWithStride(N, N, stride=2),
+            ResidualBlock(N, N),
+            ResidualBlockWithStride(N, N, stride=2),
+            ResidualBlock(N, N),
+            conv3x3(N, N, stride=2),
+        )
+
+        self.h_a = nn.Sequential(
+            conv3x3(N, N),
+            nn.LeakyReLU(inplace=True),
+            conv3x3(N, N),
+            nn.LeakyReLU(inplace=True),
+            conv3x3(N, N, stride=2),
+            nn.LeakyReLU(inplace=True),
+            conv3x3(N, N),
+            nn.LeakyReLU(inplace=True),
+            conv3x3(N, N, stride=2),
+        )
+
+        self.h_s = nn.Sequential(
+            conv3x3(N, N),
+            nn.LeakyReLU(inplace=True),
+            subpel_conv3x3(N, N, 2),
+            nn.LeakyReLU(inplace=True),
+            conv3x3(N, N * 3 // 2),
+            nn.LeakyReLU(inplace=True),
+            subpel_conv3x3(N * 3 // 2, N * 3 // 2, 2),
+            nn.LeakyReLU(inplace=True),
+            conv3x3(N * 3 // 2, N * 2),
+        )
+
+        self.g_s = nn.Sequential(
+            ResidualBlock(N, N),
+            ResidualBlockUpsample(N, N, 2),
+            ResidualBlock(N, N),
+            ResidualBlockUpsample(N, N, 2),
+            ResidualBlock(N, N),
+            ResidualBlockUpsample(N, N, 2),
+            ResidualBlock(N, N),
+            subpel_conv3x3(N, 3, 2),
+        )
+
+    @classmethod
+    def from_state_dict(cls, state_dict):
+        """Return a new model instance from `state_dict`."""
+        N = state_dict["g_a.0.conv1.weight"].size(0)
+        net = cls(N)
+        net.load_state_dict(state_dict)
+        return net
\ No newline at end of file
diff --git a/src/ops/bound_ops.py b/src/ops/bound_ops.py
new file mode 100644
index 0000000..9d1f6fd
--- /dev/null
+++ b/src/ops/bound_ops.py
@@ -0,0 +1,53 @@
+# Copyright 2020 InterDigital Communications, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +class LowerBoundFunction(torch.autograd.Function): + """Autograd function for the `LowerBound` operator.""" + + @staticmethod + def forward(ctx, input_, bound): + ctx.save_for_backward(input_, bound) + return torch.max(input_, bound) + + @staticmethod + def backward(ctx, grad_output): + input_, bound = ctx.saved_tensors + pass_through_if = (input_ >= bound) | (grad_output < 0) + return pass_through_if.type(grad_output.dtype) * grad_output, None + + +class LowerBound(nn.Module): + """Lower bound operator, computes `torch.max(x, bound)` with a custom + gradient. + + The derivative is replaced by the identity function when `x` is moved + towards the `bound`, otherwise the gradient is kept to zero. + """ + + def __init__(self, bound): + super().__init__() + self.register_buffer("bound", torch.Tensor([float(bound)])) + + @torch.jit.unused + def lower_bound(self, x): + return LowerBoundFunction.apply(x, self.bound) + + def forward(self, x): + if torch.jit.is_scripting(): + return torch.max(x, self.bound) + return self.lower_bound(x) diff --git a/src/ops/parametrizers.py b/src/ops/parametrizers.py new file mode 100644 index 0000000..6b2f10f --- /dev/null +++ b/src/ops/parametrizers.py @@ -0,0 +1,45 @@ +# Copyright 2020 InterDigital Communications, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +from .bound_ops import LowerBound + + +class NonNegativeParametrizer(nn.Module): + """ + Non negative reparametrization. + + Used for stability during training. + """ + + def __init__(self, minimum=0, reparam_offset=2 ** -18): + super().__init__() + + self.minimum = float(minimum) + self.reparam_offset = float(reparam_offset) + + pedestal = self.reparam_offset ** 2 + self.register_buffer("pedestal", torch.Tensor([pedestal])) + bound = (self.minimum + self.reparam_offset ** 2) ** 0.5 + self.lower_bound = LowerBound(bound) + + def init(self, x): + return torch.sqrt(torch.max(x + self.pedestal, self.pedestal)) + + def forward(self, x): + out = self.lower_bound(x) + out = out ** 2 - self.pedestal + return out diff --git a/src/utils/stream_helper.py b/src/utils/stream_helper.py new file mode 100644 index 0000000..3521064 --- /dev/null +++ b/src/utils/stream_helper.py @@ -0,0 +1,163 @@ +# Copyright 2020 InterDigital Communications, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
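+
+# Bitstream container format used below: an I-frame stream stores a header of
+# four big-endian uint32 values (height, width, len(y_string), len(z_string))
+# followed by the raw y and z strings; a P-frame stream stores six uint32
+# values (height, width, and the lengths of the mv_y, mv_z, y and z strings)
+# followed by those four strings in order.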
+ +import struct +from pathlib import Path +import torch +import torch.nn.functional as F +from PIL import Image +from torchvision.transforms import ToPILImage, ToTensor + + +def get_downsampled_shape(height, width, p): + + new_h = (height + p - 1) // p * p + new_w = (width + p - 1) // p * p + return int(new_h / p + 0.5), int(new_w / p + 0.5) + + +def filesize(filepath: str) -> int: + if not Path(filepath).is_file(): + raise ValueError(f'Invalid file "{filepath}".') + return Path(filepath).stat().st_size + + +def load_image(filepath: str) -> Image.Image: + return Image.open(filepath).convert("RGB") + + +def img2torch(img: Image.Image) -> torch.Tensor: + return ToTensor()(img).unsqueeze(0) + + +def torch2img(x: torch.Tensor) -> Image.Image: + return ToPILImage()(x.clamp_(0, 1).squeeze()) + + +def write_uints(fd, values, fmt=">{:d}I"): + fd.write(struct.pack(fmt.format(len(values)), *values)) + + +def write_uchars(fd, values, fmt=">{:d}B"): + fd.write(struct.pack(fmt.format(len(values)), *values)) + + +def read_uints(fd, n, fmt=">{:d}I"): + sz = struct.calcsize("I") + return struct.unpack(fmt.format(n), fd.read(n * sz)) + + +def read_uchars(fd, n, fmt=">{:d}B"): + sz = struct.calcsize("B") + return struct.unpack(fmt.format(n), fd.read(n * sz)) + + +def write_bytes(fd, values, fmt=">{:d}s"): + if len(values) == 0: + return + fd.write(struct.pack(fmt.format(len(values)), values)) + + +def read_bytes(fd, n, fmt=">{:d}s"): + sz = struct.calcsize("s") + return struct.unpack(fmt.format(n), fd.read(n * sz))[0] + + +def pad(x, p=2 ** 6): + h, w = x.size(2), x.size(3) + H = (h + p - 1) // p * p + W = (w + p - 1) // p * p + padding_left = (W - w) // 2 + padding_right = W - w - padding_left + padding_top = (H - h) // 2 + padding_bottom = H - h - padding_top + return F.pad( + x, + (padding_left, padding_right, padding_top, padding_bottom), + mode="constant", + value=0, + ) + + +def crop(x, size): + H, W = x.size(2), x.size(3) + h, w = size + padding_left = (W - w) // 2 + padding_right = W - w - padding_left + padding_top = (H - h) // 2 + padding_bottom = H - h - padding_top + return F.pad( + x, + (-padding_left, -padding_right, -padding_top, -padding_bottom), + mode="constant", + value=0, + ) + + +def encode_i(height, width, y_string, z_string, output): + with Path(output).open("wb") as f: + y_string_length = len(y_string) + z_string_length = len(z_string) + + write_uints(f, (height, width, y_string_length, z_string_length)) + write_bytes(f, y_string) + write_bytes(f, z_string) + + +def decode_i(inputpath): + with Path(inputpath).open("rb") as f: + header = read_uints(f, 4) + height = header[0] + width = header[1] + y_string_length = header[2] + z_string_length = header[3] + + y_string = read_bytes(f, y_string_length) + z_string = read_bytes(f, z_string_length) + + return height, width, y_string, z_string + + +def encode_p(height, width, mv_y_string, mv_z_string, y_string, z_string, output): + with Path(output).open("wb") as f: + mv_y_string_length = len(mv_y_string) + mv_z_string_length = len(mv_z_string) + y_string_length = len(y_string) + z_string_length = len(z_string) + + write_uints(f, (height, width, + mv_y_string_length, mv_z_string_length, + y_string_length, z_string_length)) + write_bytes(f, mv_y_string) + write_bytes(f, mv_z_string) + write_bytes(f, y_string) + write_bytes(f, z_string) + + +def decode_p(inputpath): + with Path(inputpath).open("rb") as f: + header = read_uints(f, 6) + height = header[0] + width = header[1] + mv_y_string_length = header[2] + mv_z_string_length = header[3] + 
y_string_length = header[4] + z_string_length = header[5] + + mv_y_string = read_bytes(f, mv_y_string_length) + mv_z_string = read_bytes(f, mv_z_string_length) + y_string = read_bytes(f, y_string_length) + z_string = read_bytes(f, z_string_length) + + return height, width, mv_y_string, mv_z_string, y_string, z_string diff --git a/src/zoo/image.py b/src/zoo/image.py new file mode 100644 index 0000000..8f26697 --- /dev/null +++ b/src/zoo/image.py @@ -0,0 +1,32 @@ +# Copyright 2020 InterDigital Communications, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..models.waseda import ( + Cheng2020Anchor +) + +from ..models.priors import ( + FactorizedPrior, + ScaleHyperprior, + MeanScaleHyperprior, + JointAutoregressiveHierarchicalPriors +) + +model_architectures = { + "bmshj2018-factorized": FactorizedPrior, + "bmshj2018-hyperprior": ScaleHyperprior, + "mbt2018-mean": MeanScaleHyperprior, + "mbt2018": JointAutoregressiveHierarchicalPriors, + "cheng2020-anchor": Cheng2020Anchor, +} diff --git a/test_video.py b/test_video.py new file mode 100644 index 0000000..e592691 --- /dev/null +++ b/test_video.py @@ -0,0 +1,350 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +import math +import os +import concurrent.futures +import multiprocessing +import torch +import json +import numpy as np +from PIL import Image +from src.models.DCVC_net import DCVC_net +from src.zoo.image import model_architectures as architectures +import time +from tqdm import tqdm +import warnings +from pytorch_msssim import ms_ssim + + +warnings.filterwarnings("ignore", message="Setting attributes on ParameterList is not supported.") + + +def str2bool(v): + return str(v).lower() in ("yes", "y", "true", "t", "1") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Example testing script") + + parser.add_argument('--i_frame_model_name', type=str, default="cheng2020-anchor") + parser.add_argument('--i_frame_model_path', type=str, nargs="+") + parser.add_argument('--model_path', type=str, nargs="+") + parser.add_argument('--test_config', type=str, required=True) + parser.add_argument("--worker", "-w", type=int, default=1, help="worker number") + parser.add_argument("--cuda", type=str2bool, nargs='?', const=True, default=False) + parser.add_argument("--cuda_device", default=None, + help="the cuda device used, e.g., 0; 0,1; 1,2,3; etc.") + parser.add_argument('--write_stream', type=str2bool, nargs='?', + const=True, default=False) + parser.add_argument("--write_recon_frame", type=str2bool, + nargs='?', const=True, default=False) + parser.add_argument('--recon_bin_path', type=str, default="recon_bin_path") + parser.add_argument('--output_json_result_path', type=str, required=True) + parser.add_argument("--model_type", type=str, default="psnr", help="psnr, msssim") + + + args = parser.parse_args() + return args + +def PSNR(input1, input2): + mse = torch.mean((input1 - input2) ** 2) + psnr = 20 * torch.log10(1 / torch.sqrt(mse)) + return psnr.item() + +def 
read_frame_to_torch(path): + input_image = Image.open(path).convert('RGB') + input_image = np.asarray(input_image).astype('float64').transpose(2, 0, 1) + input_image = torch.from_numpy(input_image).type(torch.FloatTensor) + input_image = input_image.unsqueeze(0)/255 + return input_image + +def write_torch_frame(frame, path): + frame_result = frame.clone() + frame_result = frame_result.cpu().detach().numpy().transpose(1, 2, 0)*255 + frame_result = np.clip(np.rint(frame_result), 0, 255) + frame_result = Image.fromarray(frame_result.astype('uint8'), 'RGB') + frame_result.save(path) + +def encode_one(args_dict, device): + i_frame_load_checkpoint = torch.load(args_dict['i_frame_model_path'], + map_location=torch.device('cpu')) + i_frame_net = architectures[args_dict['i_frame_model_name']].from_state_dict( + i_frame_load_checkpoint).eval() + + video_net = DCVC_net() + load_checkpoint = torch.load(args_dict['model_path'], map_location=torch.device('cpu')) + video_net.load_dict(load_checkpoint) + + video_net = video_net.to(device) + video_net.eval() + i_frame_net = i_frame_net.to(device) + i_frame_net.eval() + if args_dict['write_stream']: + video_net.update(force=True) + i_frame_net.update(force=True) + + sub_dir_name = args_dict['video_path'] + ref_frame = None + frame_types = [] + qualitys = [] + bits = [] + bits_mv_y = [] + bits_mv_z = [] + bits_y = [] + bits_z = [] + + gop_size = args_dict['gop'] + frame_pixel_num = 0 + frame_num = args_dict['frame_num'] + + recon_bin_folder = os.path.join(args_dict['recon_bin_path'], sub_dir_name, os.path.basename(args_dict['model_path'])[:-4]) + if not os.path.exists(recon_bin_folder): + os.makedirs(recon_bin_folder) + + # Figure out the naming convention + pngs = os.listdir(os.path.join(args_dict['dataset_path'], sub_dir_name)) + if 'im1.png' in pngs: + padding = 1 + elif 'im00001.png' in pngs: + padding = 5 + else: + raise ValueError('unknown image naming convention; please specify') + + with torch.no_grad(): + for frame_idx in range(frame_num): + ori_frame = read_frame_to_torch( + os.path.join(args_dict['dataset_path'], + sub_dir_name, + f"im{str(frame_idx+1).zfill(padding)}.png")) + ori_frame = ori_frame.to(device) + + if frame_pixel_num == 0: + frame_pixel_num = ori_frame.shape[2]*ori_frame.shape[3] + else: + assert(frame_pixel_num == ori_frame.shape[2]*ori_frame.shape[3]) + + if args_dict['write_stream']: + bin_path = os.path.join(recon_bin_folder, f"{frame_idx}.bin") + if frame_idx % gop_size == 0: + result = i_frame_net.encode_decode(ori_frame, bin_path) + ref_frame = result["x_hat"] + bpp = result["bpp"] + frame_types.append(0) + bits.append(bpp*frame_pixel_num) + bits_mv_y.append(0) + bits_mv_z.append(0) + bits_y.append(0) + bits_z.append(0) + else: + result = video_net.encode_decode(ref_frame, ori_frame, bin_path) + ref_frame = result['recon_image'] + bpp = result['bpp'] + frame_types.append(1) + bits.append(bpp*frame_pixel_num) + bits_mv_y.append(result['bpp_mv_y']*frame_pixel_num) + bits_mv_z.append(result['bpp_mv_z']*frame_pixel_num) + bits_y.append(result['bpp_y']*frame_pixel_num) + bits_z.append(result['bpp_z']*frame_pixel_num) + else: + if frame_idx % gop_size == 0: + result = i_frame_net(ori_frame) + bit = sum((torch.log(likelihoods).sum() / (-math.log(2))) + for likelihoods in result["likelihoods"].values()) + ref_frame = result["x_hat"] + frame_types.append(0) + bits.append(bit.item()) + bits_mv_y.append(0) + bits_mv_z.append(0) + bits_y.append(0) + bits_z.append(0) + else: + result = video_net(ref_frame, ori_frame) + ref_frame = 
result['recon_image'] + bpp = result['bpp'] + frame_types.append(1) + bits.append(bpp.item()*frame_pixel_num) + bits_mv_y.append(result['bpp_mv_y'].item()*frame_pixel_num) + bits_mv_z.append(result['bpp_mv_z'].item()*frame_pixel_num) + bits_y.append(result['bpp_y'].item()*frame_pixel_num) + bits_z.append(result['bpp_z'].item()*frame_pixel_num) + + ref_frame = ref_frame.clamp_(0, 1) + if args_dict['write_recon_frame']: + write_torch_frame(ref_frame.squeeze(),os.path.join(recon_bin_folder, f"recon_frame_{frame_idx}.png")) + if args_dict['model_type'] == 'psnr': + qualitys.append(PSNR(ref_frame, ori_frame)) + else: + qualitys.append( + ms_ssim(ref_frame, ori_frame, data_range=1.0).item()) + + cur_all_i_frame_bit = 0 + cur_all_i_frame_quality = 0 + cur_all_p_frame_bit = 0 + cur_all_p_frame_bit_mv_y = 0 + cur_all_p_frame_bit_mv_z = 0 + cur_all_p_frame_bit_y = 0 + cur_all_p_frame_bit_z = 0 + cur_all_p_frame_quality = 0 + cur_i_frame_num = 0 + cur_p_frame_num = 0 + for idx in range(frame_num): + if frame_types[idx] == 0: + cur_all_i_frame_bit += bits[idx] + cur_all_i_frame_quality += qualitys[idx] + cur_i_frame_num += 1 + else: + cur_all_p_frame_bit += bits[idx] + cur_all_p_frame_bit_mv_y += bits_mv_y[idx] + cur_all_p_frame_bit_mv_z += bits_mv_z[idx] + cur_all_p_frame_bit_y += bits_y[idx] + cur_all_p_frame_bit_z += bits_z[idx] + cur_all_p_frame_quality += qualitys[idx] + cur_p_frame_num += 1 + + log_result = {} + log_result['name'] = f"{os.path.basename(args_dict['model_path'])}_{sub_dir_name}" + log_result['ds_name'] = args_dict['ds_name'] + log_result['video_path'] = args_dict['video_path'] + log_result['frame_pixel_num'] = frame_pixel_num + log_result['i_frame_num'] = cur_i_frame_num + log_result['p_frame_num'] = cur_p_frame_num + log_result['ave_i_frame_bpp'] = cur_all_i_frame_bit / cur_i_frame_num / frame_pixel_num + log_result['ave_i_frame_quality'] = cur_all_i_frame_quality / cur_i_frame_num + if cur_p_frame_num > 0: + total_p_pixel_num = cur_p_frame_num * frame_pixel_num + log_result['ave_p_frame_bpp'] = cur_all_p_frame_bit / total_p_pixel_num + log_result['ave_p_frame_bpp_mv_y'] = cur_all_p_frame_bit_mv_y / total_p_pixel_num + log_result['ave_p_frame_bpp_mv_z'] = cur_all_p_frame_bit_mv_z / total_p_pixel_num + log_result['ave_p_frame_bpp_y'] = cur_all_p_frame_bit_y / total_p_pixel_num + log_result['ave_p_frame_bpp_z'] = cur_all_p_frame_bit_z / total_p_pixel_num + log_result['ave_p_frame_quality'] = cur_all_p_frame_quality / cur_p_frame_num + else: + log_result['ave_p_frame_bpp'] = 0 + log_result['ave_p_frame_quality'] = 0 + log_result['ave_p_frame_bpp_mv_y'] = 0 + log_result['ave_p_frame_bpp_mv_z'] = 0 + log_result['ave_p_frame_bpp_y'] = 0 + log_result['ave_p_frame_bpp_z'] = 0 + log_result['ave_all_frame_bpp'] = (cur_all_i_frame_bit + cur_all_p_frame_bit) / \ + (frame_num * frame_pixel_num) + log_result['ave_all_frame_quality'] = (cur_all_i_frame_quality + cur_all_p_frame_quality) / frame_num + return log_result + + +def worker(use_cuda, args): + if args['write_stream']: + torch.backends.cudnn.benchmark = False + if 'use_deterministic_algorithms' in dir(torch): + torch.use_deterministic_algorithms(True) + else: + torch.set_deterministic(True) + torch.manual_seed(0) + torch.set_num_threads(1) + np.random.seed(seed=0) + gpu_num = 0 + if use_cuda: + gpu_num = torch.cuda.device_count() + + process_name = multiprocessing.current_process().name + process_idx = int(process_name[process_name.rfind('-') + 1:]) + gpu_id = -1 + if gpu_num > 0: + gpu_id = process_idx % gpu_num + if gpu_id >= 0: + 
device = f"cuda:{gpu_id}"
+    else:
+        device = "cpu"
+
+    result = encode_one(args, device)
+    result['model_idx'] = args['model_idx']
+    return result
+
+
+def filter_dict(result):
+    keys = ['i_frame_num', 'p_frame_num', 'ave_i_frame_bpp', 'ave_i_frame_quality', 'ave_p_frame_bpp',
+            'ave_p_frame_bpp_mv_y', 'ave_p_frame_bpp_mv_z', 'ave_p_frame_bpp_y',
+            'ave_p_frame_bpp_z', 'ave_p_frame_quality', 'ave_all_frame_bpp', 'ave_all_frame_quality']
+    res = {k: v for k, v in result.items() if k in keys}
+    return res
+
+
+def main():
+    torch.backends.cudnn.enabled = True
+    args = parse_args()
+    if args.cuda_device is not None and args.cuda_device != '':
+        os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_device
+    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
+    worker_num = args.worker
+    assert worker_num >= 1
+
+    with open(args.test_config) as f:
+        config = json.load(f)
+
+    multiprocessing.set_start_method("spawn")
+    executor = concurrent.futures.ProcessPoolExecutor(max_workers=worker_num)
+    objs = []
+
+    count_frames = 0
+    count_sequences = 0
+    begin_time = time.time()
+    for ds_name in config:
+        for seq_name in config[ds_name]['sequences']:
+            count_sequences += 1
+            for model_idx in range(len(args.model_path)):
+                cur_dict = {}
+                cur_dict['model_idx'] = model_idx
+                cur_dict['i_frame_model_path'] = args.i_frame_model_path[model_idx]
+                cur_dict['i_frame_model_name'] = args.i_frame_model_name
+                cur_dict['model_path'] = args.model_path[model_idx]
+                cur_dict['video_path'] = seq_name
+                cur_dict['gop'] = config[ds_name]['sequences'][seq_name]['gop']
+                cur_dict['frame_num'] = config[ds_name]['sequences'][seq_name]['frames']
+                cur_dict['dataset_path'] = config[ds_name]['base_path']
+                cur_dict['write_stream'] = args.write_stream
+                cur_dict['write_recon_frame'] = args.write_recon_frame
+                cur_dict['recon_bin_path'] = args.recon_bin_path
+                cur_dict['model_type'] = args.model_type
+                cur_dict['ds_name'] = ds_name
+
+                count_frames += cur_dict['frame_num']
+
+                obj = executor.submit(
+                    worker,
+                    args.cuda,
+                    cur_dict)
+                objs.append(obj)
+
+    results = []
+    for obj in tqdm(objs):
+        result = obj.result()
+        results.append(result)
+
+    log_result = {}
+
+    for ds_name in config:
+        log_result[ds_name] = {}
+        for seq in config[ds_name]['sequences']:
+            log_result[ds_name][seq] = {}
+            for model_idx in range(len(args.model_path)):
+                ckpt = os.path.basename(args.model_path[model_idx])
+                for res in results:
+                    if res['name'].startswith(ckpt) and ds_name == res['ds_name'] \
+                            and seq == res['video_path']:
+                        log_result[ds_name][seq][ckpt] = filter_dict(res)
+
+    with open(args.output_json_result_path, 'w') as fp:
+        json.dump(log_result, fp, indent=2)
+
+    total_minutes = (time.time() - begin_time) / 60
+
+    count_models = len(args.model_path)
+    count_frames = count_frames // count_models
+    print('Test finished')
+    print(f'Tested {count_models} models on {count_frames} frames from {count_sequences} sequences')
+    print(f'Total elapsed time: {total_minutes:.1f} min')
+
+
+if __name__ == "__main__":
+    main()
diff --git a/write_stream_readme.md b/write_stream_readme.md
new file mode 100644
index 0000000..826fd02
--- /dev/null
+++ b/write_stream_readme.md
@@ -0,0 +1,32 @@
+Writing the bitstream is currently very slow because of the auto-regressive entropy model. If you want to write an actual bitstream, you need to build the arithmetic coder first.
+
+# Build
+* Build on Windows
+
+  CMake and Visual Studio 2019 are needed.
+  ```bash
+  cd src
+  mkdir build
+  cd build
+  conda activate $YOUR_PY38_ENV_NAME
+  cmake ../cpp -G "Visual Studio 16 2019" -A x64
+  cmake --build . --config Release
+  ```
+
+* Build on Linux (recommended)
+
+  CMake and g++ are needed.
+  ```bash
+  sudo apt-get install cmake g++
+  cd src
+  mkdir build
+  cd build
+  conda activate $YOUR_PY38_ENV_NAME
+  cmake ../cpp -DCMAKE_BUILD_TYPE=Release
+  make -j
+  ```
+
+# Test
+Please append this to your test command:
+```
+--write_stream True
+```
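+
+For example, a complete invocation might look like the following (the checkpoint and config paths are placeholders; substitute your own):
+```bash
+python test_video.py \
+    --i_frame_model_name cheng2020-anchor \
+    --i_frame_model_path ./checkpoints/i_frame_model.pth \
+    --model_path ./checkpoints/video_model.pth \
+    --test_config ./dataset_config.json \
+    --cuda True \
+    --worker 1 \
+    --model_type psnr \
+    --output_json_result_path ./result.json \
+    --write_stream True
+```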