diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index 1d4998e28ea9..724c4c324221 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -41,6 +41,9 @@ jobs: - os: ubuntu-18.04 build_package: py-runtime-pkg experimental: false + - os: ubuntu-18.04 + build_package: instrumented-py-runtime-pkg + experimental: true - os: ubuntu-18.04 build_package: py-xla-compiler-tools-pkg experimental: false @@ -122,7 +125,8 @@ jobs: run: | python -m pip install cibuildwheel==1.7.2 - - name: Write version info + - name: Write version info (release) + if: "!startsWith(matrix.build_package, 'instrumented-')" shell: bash run: | cat << EOF > ./main_checkout/version_info.json @@ -134,6 +138,19 @@ jobs: EOF cat ./main_checkout/version_info.json + - name: Write version info (instrumented) + if: "startsWith(matrix.build_package, 'instrumented-')" + shell: bash + run: | + cat << EOF > ./main_checkout/version_info.json + { + "package-suffix": "-instrumented${{ github.event.inputs.package_suffix }}", + "package-version": "${{ github.event.inputs.package_version }}", + "iree-revision": "$(cd ./main_checkout && git rev-parse HEAD)" + } + EOF + cat ./main_checkout/version_info.json + # The main distribution consists of the project being built, installed # and archived. 
We have to split it per operating system, and Linux # is special because we build under a manylinux container which gives @@ -167,6 +184,18 @@ jobs: mkdir -p $package_dir && touch $package_dir/setup.py python -m cibuildwheel --output-dir bindist $package_dir + - name: Build runtime wheels (instrumented) + if: "matrix.build_package == 'instrumented-py-runtime-pkg'" + shell: bash + run: | + package_dir="./iree-install/python_packages/iree_runtime" + export CIBW_BEFORE_ALL_LINUX="./main_checkout/build_tools/github_actions/install_tracy_cli_deps_manylinux2014.sh" + export CIBW_BEFORE_BUILD="python ./main_checkout/build_tools/github_actions/build_dist.py instrumented-py-runtime-pkg" + # TODO: cibuildwheel sanity checks this, but our setup.py is the + # *output* of the build :( Make source packages. + mkdir -p $package_dir && touch $package_dir/setup.py + python -m cibuildwheel --output-dir bindist $package_dir + # Experimental iree.compiler package. - name: Build compiler wheels if: "matrix.build_package == 'py-compiler-pkg'" diff --git a/CMakeLists.txt b/CMakeLists.txt index 6205283f1de2..231833c809bc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,6 +43,7 @@ option(IREE_BUILD_TESTS "Builds IREE unit tests." ON) option(IREE_BUILD_BENCHMARKS "Builds IREE benchmark suites." OFF) option(IREE_BUILD_DOCS "Builds IREE docs." OFF) option(IREE_BUILD_SAMPLES "Builds IREE sample projects." ON) +option(IREE_BUILD_TRACY "Builds tracy server tools." OFF) option(IREE_BUILD_TENSORFLOW_ALL "Builds all TensorFlow compiler frontends." OFF) option(IREE_BUILD_TENSORFLOW_COMPILER "Builds TensorFlow compiler frontend." 
OFF) @@ -446,10 +447,6 @@ if(IREE_BUILD_COMPILER) add_subdirectory(build_tools/third_party/mlir-hlo EXCLUDE_FROM_ALL) endif() -if(IREE_ENABLE_EMITC) - add_subdirectory(build_tools/third_party/mlir-emitc EXCLUDE_FROM_ALL) -endif() - if(IREE_BUILD_TESTS) enable_testing(iree) endif() @@ -516,8 +513,18 @@ endif() add_subdirectory(iree/tools) +if(IREE_BUILD_TRACY) + if(NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "Linux") + message(WARNING "Building Tracy (IREE_BUILD_TRACY) on non-Linux is unsupported and may fail below.") + endif() + add_subdirectory(build_tools/third_party/tracy ${CMAKE_CURRENT_BINARY_DIR}/tracy) + if(NOT TARGET IREETracyCaptureServer) + message(SEND_ERROR "Could not build Tracy. Either unset IREE_BUILD_TRACY or look for missing dependencies above and install them.") + endif() +endif() + # Order constraint: The python bindings install tools targets from iree/tools -# and must come after it. +# and tracy, and must come after both. if(${IREE_BUILD_PYTHON_BINDINGS}) add_subdirectory(bindings/python) endif() diff --git a/benchmarks/TensorFlow/CMakeLists.txt b/benchmarks/TensorFlow/CMakeLists.txt index 6f95f86c061a..48c79b3ebdb6 100644 --- a/benchmarks/TensorFlow/CMakeLists.txt +++ b/benchmarks/TensorFlow/CMakeLists.txt @@ -270,6 +270,27 @@ iree_mlir_benchmark_suite( "--batch_size=32" ) +# GPU, Vulkan, Mali, full-inference +iree_mlir_benchmark_suite( + MODULES + ${MOBILEBERT_FP16_MODULE} + + BENCHMARK_MODES + "full-inference" + TARGET_BACKEND + "vulkan-spirv" + TARGET_ARCHITECTURE + "GPU-Mali-Valhall" + TRANSLATION_FLAGS + "--iree-input-type=mhlo" + "--iree-flow-demote-f32-to-f16" + "--iree-vulkan-target-triple=valhall-unknown-android11" + "--iree-flow-inline-constants-max-byte-length=16" + "--iree-enable-fusion-with-reduction-ops" + DRIVER + "vulkan" +) + ################################################################################ # # # Speical benchmark configurations # diff --git a/bindings/python/iree/runtime/CMakeLists.txt b/bindings/python/iree/runtime/CMakeLists.txt index 
dccdf00fe256..1e1a11e99f46 100644 --- a/bindings/python/iree/runtime/CMakeLists.txt +++ b/bindings/python/iree/runtime/CMakeLists.txt @@ -4,6 +4,17 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +set(_python_extra_srcs) +set(_extra_install_tool_targets) +set(_tracy_enabled OFF) + +if(TARGET IREETracyCaptureServer) + message(STATUS "Bundling Tracy CLI tools with Python API") + set(_tracy_enabled ON) + list(APPEND _python_extra_srcs "scripts/iree_tracy_capture/__main__.py") + list(APPEND _extra_install_tool_targets "IREETracyCaptureServer") +endif() + ################################################################################ # Package ################################################################################ @@ -43,6 +54,8 @@ iree_py_library( "tracing.py" "scripts/iree_benchmark_trace/__main__.py" "scripts/iree_run_trace/__main__.py" + "scripts/iree_run_module/__main__.py" + ${_python_extra_srcs} PYEXT_DEPS ::PyExtRt ) @@ -59,6 +72,20 @@ iree_symlink_tool( TO_EXE_NAME iree-run-trace ) +iree_symlink_tool( + TARGET runtime + FROM_TOOL_TARGET iree_tools_iree-run-module + TO_EXE_NAME iree-run-module +) + +if(_tracy_enabled) + iree_symlink_tool( + TARGET runtime + FROM_TOOL_TARGET IREETracyCaptureServer + TO_EXE_NAME iree-tracy-capture + ) +endif() + ################################################################################ # Tests ################################################################################ @@ -102,7 +129,9 @@ iree_py_install_package( DEPS bindings_python_iree_runtime_PyExtRt iree_tools_iree-benchmark-trace + iree_tools_iree-run-module iree_tools_iree-run-trace + ${_extra_install_tool_targets} ADDL_PACKAGE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/README.md ) @@ -116,7 +145,9 @@ install( install( TARGETS iree_tools_iree-benchmark-trace + iree_tools_iree-run-module iree_tools_iree-run-trace + ${_extra_install_tool_targets} DESTINATION "${PY_INSTALL_MODULE_DIR}" COMPONENT 
"${PY_INSTALL_COMPONENT}" ) diff --git a/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py b/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py new file mode 100644 index 000000000000..a5509a3d013a --- /dev/null +++ b/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py @@ -0,0 +1,20 @@ +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import subprocess +import sys + + +def main(args=None): + if args is None: + args = sys.argv[1:] + exe = os.path.join(os.path.dirname(__file__), "..", "..", "iree-run-module") + return subprocess.call(args=[exe] + args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py b/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py new file mode 100644 index 000000000000..58f2118d7c3b --- /dev/null +++ b/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py @@ -0,0 +1,21 @@ +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import subprocess +import sys + + +def main(args=None): + if args is None: + args = sys.argv[1:] + exe = os.path.join(os.path.dirname(__file__), "..", "..", + "iree-tracy-capture") + return subprocess.call(args=[exe] + args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/bindings/python/iree/runtime/setup.py.in b/bindings/python/iree/runtime/setup.py.in index dc25a35ab59d..2d0cfac95ade 100644 --- a/bindings/python/iree/runtime/setup.py.in +++ b/bindings/python/iree/runtime/setup.py.in @@ -41,14 +41,18 @@ setup( package_data={ "": [ f"*{sysconfig.get_config_var('EXT_SUFFIX')}", + "iree-run-module*", "iree-run-trace*", "iree-benchmark-trace*", + "iree-tracy-capture*", ], }, entry_points={ "console_scripts": [ + "iree-run-module = iree.runtime.scripts.iree_run_module.__main__:main", "iree-run-trace = iree.runtime.scripts.iree_run_trace.__main__:main", "iree-benchmark-trace = iree.runtime.scripts.iree_benchmark_trace.__main__:main", + "iree-tracy-capture = iree.runtime.scripts.iree_tracy_capture.__main__:main", ], }, zip_safe=False, diff --git a/build_tools/github_actions/build_dist.py b/build_tools/github_actions/build_dist.py index 00fb32649bc0..f9e9545520e4 100644 --- a/build_tools/github_actions/build_dist.py +++ b/build_tools/github_actions/build_dist.py @@ -210,17 +210,30 @@ def build_py_pure_pkgs(): check=True) -def build_py_runtime_pkg(): +def build_py_runtime_pkg(instrumented: bool = False): """Builds the iree-install/python_packages/iree_runtime package. This includes native, python-version dependent code and is designed to be built multiple times. + + Note that an instrumented build may require additional dependencies. + See: install_tracy_cli_deps_manylinux2014.sh for how to set up on that + container. """ install_python_requirements() # Clean up install and build trees. 
shutil.rmtree(INSTALL_DIR, ignore_errors=True) remove_cmake_cache() + extra_cmake_flags = [] + + # Extra options for instrumentation. + if instrumented: + print("*** Enabling options for instrumented build ***") + extra_cmake_flags.extend([ + f"-DIREE_ENABLE_RUNTIME_TRACING=ON", + f"-DIREE_BUILD_TRACY=ON", + ]) # CMake configure. print("*** Configuring ***") @@ -234,7 +247,7 @@ def build_py_runtime_pkg(): f"-DIREE_BUILD_PYTHON_BINDINGS=ON", f"-DIREE_BUILD_SAMPLES=OFF", f"-DIREE_BUILD_TESTS=OFF", - ], + ] + extra_cmake_flags, check=True) print("*** Building ***") @@ -405,6 +418,8 @@ def build_py_tf_compiler_tools_pkg(): build_main_dist() elif command == "py-runtime-pkg": build_py_runtime_pkg() +elif command == "instrumented-py-runtime-pkg": + build_py_runtime_pkg(instrumented=True) elif command == "py-pure-pkgs": build_py_pure_pkgs() elif command == "py-xla-compiler-tools-pkg": diff --git a/build_tools/github_actions/install_tbb_manylinux2014.sh b/build_tools/github_actions/install_tbb_manylinux2014.sh new file mode 100755 index 000000000000..7c8e9b10d3ad --- /dev/null +++ b/build_tools/github_actions/install_tbb_manylinux2014.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# The version of tbb installed on manylinux2014 is too old to support the +# parallel STL libraries on the installed GCC9-based toolchain. Further, +# Intel *broke* compatibility starting in 2021 for GCC<=10. +# To make matters worse, the prior 2020 versions did not have cmake or +# install support. +# Shame on you Intel. 
+# See: https://community.intel.com/t5/Intel-oneAPI-Threading-Building/tbb-task-has-not-been-declared/m-p/1254418 +# Since this is unlikely to be helpful outside of the old centos systems +# that manylinux2014 is based on (newer ones are based on Debian), +# we just tailor this specifically for docker images of that distro. + +# You can test this with either an official manylinux2014 docker image or +# our special one (which is really only special in that it includes bazel): +# docker run --rm -it -v $(pwd):/work stellaraccident/manylinux2014_x86_64-bazel-3.7.2:latest /bin/bash + +set -e + +mkdir -p /tmp/libtbb_build +cd /tmp/libtbb_build +curl -o tbbsrc.tgz -L https://github.com/oneapi-src/oneTBB/archive/refs/tags/v2020.3.tar.gz +tar xzf tbbsrc.tgz +cd oneTBB-*/ + +echo "****** BUILDING TBB ******" +make -j$(nproc) +cp -R include/* /usr/include +cp build/*_release/* /usr/lib64 +echo "prefix=/usr +exec_prefix=\${prefix} +libdir=\${exec_prefix}/lib64 +includedir=\${prefix}/include + +Name: Threading Building Blocks +Description: Intel's parallelism library for C++ +URL: http://www.threadingbuildingblocks.org/ +Version: 2020.3 +Libs: -ltbb +Cflags: +" > /usr/lib64/pkgconfig/tbb.pc + +echo "****** DONE BUILDING TBB ******" + +cd / +rm -Rf /tmp/libtbb_build diff --git a/build_tools/github_actions/install_tracy_cli_deps_manylinux2014.sh b/build_tools/github_actions/install_tracy_cli_deps_manylinux2014.sh new file mode 100755 index 000000000000..2eda73690555 --- /dev/null +++ b/build_tools/github_actions/install_tracy_cli_deps_manylinux2014.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Installs deps on a manylinux2014 CentOS docker container needed for +# building Tracy CLI capture tool. 
+ +set -e + +td="$(cd "$(dirname "$0")" && pwd)" +yum -y install capstone-devel libzstd-devel +"$td/install_tbb_manylinux2014.sh" diff --git a/build_tools/third_party/mlir-emitc/CMakeLists.txt b/build_tools/third_party/mlir-emitc/CMakeLists.txt deleted file mode 100644 index 8725ba19c49e..000000000000 --- a/build_tools/third_party/mlir-emitc/CMakeLists.txt +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -set(MLIR_EMITC_SOURCE_DIR - "${IREE_SOURCE_DIR}/third_party/mlir-emitc/" -) - -external_cc_library( - PACKAGE - emitc - NAME - TranslateToCpp - ROOT - ${MLIR_EMITC_SOURCE_DIR} - HDRS - "include/emitc/Target/Cpp/CppEmitter.h" - SRCS - "lib/Target/Cpp/TranslateToCpp.cpp" - DEPS - MLIREmitC - MLIRIR - MLIRSCF - MLIRStandard - INCLUDES - "${MLIR_EMITC_SOURCE_DIR}/include/" - PUBLIC -) diff --git a/build_tools/third_party/tracy/CMakeLists.txt b/build_tools/third_party/tracy/CMakeLists.txt new file mode 100644 index 000000000000..7c1e43df4789 --- /dev/null +++ b/build_tools/third_party/tracy/CMakeLists.txt @@ -0,0 +1,183 @@ +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +cmake_minimum_required(VERSION 3.16.3) + +project(IREETracyServer C CXX) + +set(TRACY_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/tracy") + +find_package(Threads REQUIRED) + +#------------------------------------------------------------------------------- +# Detect package manager +#------------------------------------------------------------------------------- + +message(STATUS "Checking for Tracy dependencies...") +find_program(PKGCONFIG pkg-config) +if(NOT PKGCONFIG) + message(STATUS "Could not find pkg-config (on Ubuntu/Debian, 'apt install pkg-config')") +else() + include(FindPkgConfig) + pkg_check_modules(TRACY_DEPS + capstone + tbb + libzstd + ) + pkg_check_modules(TRACY_GTK_DEPS + freetype2 + glfw3 + gtk+-3.0 + ) +endif() + +if(NOT TRACY_DEPS_FOUND) + message(STATUS "Could not find Tracy dependencies (Tracy server will not be built).") + message(STATUS "To build Tracy, install packages capstone, tbb, and libzstd (on Ubuntu/Debian, 'apt install libcapstone-dev libtbb-dev libzstd-dev')") + return() +endif() + +if(NOT TRACY_GTK_DEPS_FOUND) + message(STATUS + "Could not find deps required to build graphical programs: " + "Tracy graphical profiler will not be built (on Ubuntu/Debian, 'apt install libglfw3-dev libfreetype-dev libgtk-3-dev')") +endif() + +#------------------------------------------------------------------------------- +# Configuration +#------------------------------------------------------------------------------- + +function(setup_cxx_options name) + set_target_properties(${name} + PROPERTIES + CXX_STANDARD 17 + ) + target_compile_options(${name} + PRIVATE + $<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Wno-unused-result> + ) + target_include_directories(${name} + PUBLIC + ${TRACY_SOURCE_DIR}/imgui + ${TRACY_DEPS_INCLUDE_DIRS} + ) + target_link_libraries(${name} + PRIVATE + ${TRACY_DEPS_LIBRARIES} + ${CMAKE_DL_LIBS} + ${CMAKE_THREAD_LIBS_INIT} + ) +endfunction() + +function(setup_graphics_deps name) + 
target_compile_definitions(${name} + PRIVATE + DISPLAY_SERVER_X11 + ) + target_include_directories(${name} + PUBLIC + ${TRACY_GTK_DEPS_INCLUDE_DIRS} + ) + target_link_libraries(${name} + PRIVATE + ${TRACY_GTK_DEPS_LIBRARIES} + ) +endfunction() + +#------------------------------------------------------------------------------- +# Common library +#------------------------------------------------------------------------------- + +file(GLOB COMMON_SRCS ${TRACY_SOURCE_DIR}/common/*.cpp) +add_library(IREETracyCommon + ${COMMON_SRCS} +) +setup_cxx_options(IREETracyCommon) + +#------------------------------------------------------------------------------- +# Server library +#------------------------------------------------------------------------------- + +file(GLOB SERVER_SRCS ${TRACY_SOURCE_DIR}/server/*.cpp) +add_library(IREETracyServer + ${SERVER_SRCS} +) +setup_cxx_options(IREETracyServer) +target_link_libraries(IREETracyServer + PRIVATE + IREETracyCommon +) + +#------------------------------------------------------------------------------- +# IMGUI library +#------------------------------------------------------------------------------- + +file(GLOB IMGUI_SOURCES ${TRACY_SOURCE_DIR}/imgui/*.cpp) +add_library(IREETracyIMGUI + ${IMGUI_SOURCES} +) +setup_cxx_options(IREETracyIMGUI) + +#------------------------------------------------------------------------------- +# Standalone capture server +#------------------------------------------------------------------------------- + +add_executable(IREETracyCaptureServer + ${TRACY_SOURCE_DIR}/capture/src/capture.cpp +) +set_target_properties(IREETracyCaptureServer + PROPERTIES + OUTPUT_NAME "iree-tracy-capture" +) +setup_cxx_options(IREETracyCaptureServer) +target_link_libraries(IREETracyCaptureServer + PRIVATE + IREETracyCommon + IREETracyServer +) + +#------------------------------------------------------------------------------- +# Graphical frontends 
+#------------------------------------------------------------------------------- + +if(TRACY_GTK_DEPS_FOUND) + #----------------------------------------------------------------------------- + # NFD library + #----------------------------------------------------------------------------- + + set(NFD_SOURCES + ${TRACY_SOURCE_DIR}/nfd/nfd_common.c + ${TRACY_SOURCE_DIR}/nfd/nfd_gtk.c + ) + add_library(IREETracyNFD + ${NFD_SOURCES} + ) + setup_cxx_options(IREETracyNFD) + setup_graphics_deps(IREETracyNFD) + + #----------------------------------------------------------------------------- + # Profiler + #----------------------------------------------------------------------------- + + file(GLOB PROFILER_SRCS ${TRACY_SOURCE_DIR}/profiler/src/*.cpp) + add_executable(IREETracyProfiler + ${PROFILER_SRCS} + ) + set_target_properties(IREETracyProfiler + PROPERTIES + OUTPUT_NAME "iree-tracy-profiler" + ) + setup_cxx_options(IREETracyProfiler) + setup_graphics_deps(IREETracyProfiler) + target_link_libraries(IREETracyProfiler + PRIVATE + IREETracyIMGUI + IREETracyCommon + IREETracyNFD + IREETracyServer + ${CMAKE_THREAD_LIBS_INIT} + ) +endif() diff --git a/docs/README.md b/docs/README.md index d6625e32041a..43bf803aca7b 100644 --- a/docs/README.md +++ b/docs/README.md @@ -17,5 +17,8 @@ A high bar should be set for pages published to the website: overspecialize on a specific Linux distribution or a particular version of Visual Studio on Windows) +When in doubt, the guide at https://developers.google.com/style offers good +instructions. + Developer documentation _can_ compromise on each of these points. Pages may also be promoted to website/ after some refinement. 
diff --git a/docs/website/docs/bindings/index.md b/docs/website/docs/bindings/index.md new file mode 100644 index 000000000000..b4c32e38e80f --- /dev/null +++ b/docs/website/docs/bindings/index.md @@ -0,0 +1,8 @@ +# Bindings + +IREE offers specialized sets of bindings for running compiled programs from +various languages or with specific APIs: + +* [Runtime C API](./c-api.md) +* [Compiler and runtime Python bindings](./python.md) +* [Runtime TensorFlow Lite bindings](./tensorflow-lite.md) diff --git a/docs/website/docs/blog/index.md b/docs/website/docs/blog/index.md new file mode 100644 index 000000000000..6d789dc459ed --- /dev/null +++ b/docs/website/docs/blog/index.md @@ -0,0 +1,7 @@ +# Blog + +## Latest posts from the IREE team + +* 2021-10-15: [CUDA backend](./2021-10-15-cuda-backend.md) +* 2021-10-13: [Work in progress on Matrix Multiplication on CPU](./2021-10-13-mmt4d.md) +* 2021-07-19: [TFLite Support via TOSA](./2021-07-19-tflite-tosa.md) diff --git a/docs/website/docs/building-from-source/index.md b/docs/website/docs/building-from-source/index.md index 92a3f1b5c970..b0c3cadd3782 100644 --- a/docs/website/docs/building-from-source/index.md +++ b/docs/website/docs/building-from-source/index.md @@ -1,5 +1,14 @@ # Building IREE from source -Under construction. +While IREE does offer +[binary distributions](https://github.com/google/iree/releases) for its +compiler tools and [Python bindings](../bindings/python.md), building from +source is still useful when using IREE's runtime or when making changes to the +compiler itself. 
- +## Reference pages + +* [Getting started](./getting-started.md) +* [Optional features](./optional-features.md) like building the Python bindings +* [Android cross-compilation](./android.md) +* [RISC-V cross-compilation](./riscv.md) diff --git a/docs/website/docs/building-from-source/optional-features.md b/docs/website/docs/building-from-source/optional-features.md index a22d19697449..09a27f5ca8b8 100644 --- a/docs/website/docs/building-from-source/optional-features.md +++ b/docs/website/docs/building-from-source/optional-features.md @@ -1,4 +1,4 @@ -# Optional Features +# Optional features This page details the optional features and build modes for the project. Most of these are controlled by various CMake options, sometimes requiring diff --git a/docs/website/docs/deployment-configurations/index.md b/docs/website/docs/deployment-configurations/index.md new file mode 100644 index 000000000000..3f3ba5ef9081 --- /dev/null +++ b/docs/website/docs/deployment-configurations/index.md @@ -0,0 +1,18 @@ +# Deployment configurations + +IREE provides a flexible set of tools for various deployment scenarios. +Fully featured environments can use IREE for dynamic model deployments taking +advantage of multi-threaded hardware, while embedded systems can bypass IREE's +runtime entirely or interface with custom accelerators. + +## Stable configurations + +* [CPU - Dylib](./cpu-dylib.md) +* [CPU - Bare-Metal](./bare-metal.md) with minimal platform dependencies +* [GPU - Vulkan](./gpu-vulkan.md) + +These are just the most stable configurations IREE supports. Feel free to reach +out on any of IREE's +[communication channels](../index.md#communication-channels) if you have +questions about a specific platform, hardware accelerator, or set of system +features. 
diff --git a/docs/website/docs/index.md b/docs/website/docs/index.md index 578632479b42..60e9497f9471 100644 --- a/docs/website/docs/index.md +++ b/docs/website/docs/index.md @@ -68,10 +68,10 @@ Using IREE involves these general steps: 1. **Import your model** - Work in your framework of choice, then run your model through one of IREE's - import tools. + Work in your [framework of choice](./ml-frameworks), then run your model + through one of IREE's import tools. -2. **Select your deployment configuration** +2. **Select your [deployment configuration](./deployment-configurations)** Identify your target platform, accelerator(s), and other constraints. @@ -124,7 +124,7 @@ static or dynamic linkage and the associated function calls are generated. ### Running models IREE offers a low level C API, as well as several specialized sets of -_bindings_ for running IREE models using other languages: +[bindings](./bindings) for running IREE models using other languages: * [C API](bindings/c-api.md) * [Python](bindings/python.md) diff --git a/docs/website/docs/ml-frameworks/index.md b/docs/website/docs/ml-frameworks/index.md new file mode 100644 index 000000000000..ebdae587354a --- /dev/null +++ b/docs/website/docs/ml-frameworks/index.md @@ -0,0 +1,18 @@ +# ML frameworks + +## Supported frameworks + +IREE supports importing models from + +* [TensorFlow](./tensorflow.md) +* [TensorFlow Lite](./tensorflow-lite.md) +* [JAX](./jax.md) + +Importing from PyTorch and other frameworks is planned - stay tuned! + +## Samples + +Check out the samples in IREE's +[colab/ directory](https://github.com/google/iree/tree/main/colab) and the +[iree-samples repository](https://github.com/google/iree-samples) for examples +and workflow comparisons across frameworks. 
diff --git a/docs/website/docs/ml-frameworks/tensorflow-lite.md b/docs/website/docs/ml-frameworks/tensorflow-lite.md index 8db8a33e0f82..6ab1accda6c3 100644 --- a/docs/website/docs/ml-frameworks/tensorflow-lite.md +++ b/docs/website/docs/ml-frameworks/tensorflow-lite.md @@ -74,15 +74,21 @@ The flatbuffer can then be loaded to a VM module and run through IREE's runtime. ## Samples -| Colab notebooks | | -| -- | -- | -Text classification with TFLite and IREE | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/iree/blob/main/colab/tflite_text_classification.ipynb) +* The +[tflitehub folder](https://github.com/google/iree-samples/tree/main/tflitehub) +in the [iree-samples repository](https://github.com/google/iree-samples) +contains test scripts to compile, run, and compare various TensorFlow Lite +models sourced from [TensorFlow Hub](https://tfhub.dev/). -An example smoke test of the +* An example smoke test of the [TensorFlow Lite C API](https://github.com/google/iree/tree/main/bindings/tflite) is available [here](https://github.com/google/iree/blob/main/bindings/tflite/smoke_test.cc). +| Colab notebooks | | +| -- | -- | +Text classification with TFLite and IREE | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/iree/blob/main/colab/tflite_text_classification.ipynb) + !!! 
todo [Issue#3954](https://github.com/google/iree/issues/3954): Add documentation diff --git a/docs/website/mkdocs.yml b/docs/website/mkdocs.yml index 287524ec7cf9..bb838a9c8366 100644 --- a/docs/website/mkdocs.yml +++ b/docs/website/mkdocs.yml @@ -25,10 +25,7 @@ theme: - navigation.top # Back to top button - # Insiders only: navigation indices - # 1) supply index.md page under each section - # 2) list each index.md page in the `nav:` section below - # - navigation.indexes + - navigation.indexes # section names can link to index.md pages palette: # Light mode @@ -94,29 +91,34 @@ markdown_extensions: # Navigation with explicit ordering and nesting. # https://www.mkdocs.org/user-guide/configuration/#nav +# Note: may include external links and titles are optional for internal links nav: - Home: 'index.md' - 'ML frameworks': + - 'ml-frameworks/index.md' - TensorFlow: 'ml-frameworks/tensorflow.md' - TensorFlow Lite: 'ml-frameworks/tensorflow-lite.md' - JAX: 'ml-frameworks/jax.md' - 'Deployment configurations': + - 'deployment-configurations/index.md' - CPU - Dylib: 'deployment-configurations/cpu-dylib.md' - CPU - Bare-Metal: 'deployment-configurations/bare-metal.md' - GPU - Vulkan: 'deployment-configurations/gpu-vulkan.md' - 'Building from source': - # - 'building-from-source/index.md' # TODO(scotttodd): insiders + navigation.indexes + - 'building-from-source/index.md' - 'building-from-source/getting-started.md' - - Optional features: 'building-from-source/optional-features.md' - - Android cross-compilation: 'building-from-source/android.md' - - RISC-V cross-compilation: 'building-from-source/riscv.md' + - 'building-from-source/optional-features.md' + - 'building-from-source/android.md' + - 'building-from-source/riscv.md' - 'Bindings': + - 'bindings/index.md' - C API: 'bindings/c-api.md' - Python: 'bindings/python.md' - TensorFlow Lite: 'bindings/tensorflow-lite.md' - 'Community': - Projects: 'community/projects.md' - 'Blog': + - 'blog/index.md' - CUDA backend: 
'blog/2021-10-15-cuda-backend.md' - Work in progress on Matrix Multiplication on CPU: 'blog/2021-10-13-mmt4d.md' - TFLite Support via TOSA: 'blog/2021-07-19-tflite-tosa.md' diff --git a/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt b/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt index cafdc19a2dc1..267121d6db09 100644 --- a/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt @@ -8,6 +8,20 @@ if(${IREE_ENABLE_EMITC}) iree_add_all_subdirs() + iree_cc_library( + NAME + TranslateToCpp + HDRS + CppEmitter.h + SRCS + TranslateToCpp.cpp + DEPS + MLIREmitC + MLIRIR + MLIRSCF + MLIRStandard + ) + iree_cc_library( NAME C @@ -19,11 +33,11 @@ if(${IREE_ENABLE_EMITC}) "TranslationFlags.cpp" "TranslationRegistration.cpp" DEPS + ::TranslateToCpp LLVMSupport MLIRIR MLIRPass MLIRSupport - emitc::TranslateToCpp iree::compiler::Dialect::VM::Analysis iree::compiler::Dialect::VM::IR iree::compiler::Dialect::VM::Conversion::VMToEmitC diff --git a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp index c058bb85d48a..90c8bc800c67 100644 --- a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp +++ b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp @@ -6,13 +6,13 @@ #include "iree/compiler/Dialect/VM/Target/C/CModuleTarget.h" -#include "emitc/Target/Cpp/CppEmitter.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h" #include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h" #include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/DropExcludedExports.h" +#include "iree/compiler/Dialect/VM/Target/C/CppEmitter.h" #include "iree/compiler/Dialect/VM/Transforms/Passes.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Pass/PassManager.h" diff --git 
a/iree/compiler/Dialect/VM/Target/C/CppEmitter.h b/iree/compiler/Dialect/VM/Target/C/CppEmitter.h new file mode 100644 index 000000000000..5efb1b273ccc --- /dev/null +++ b/iree/compiler/Dialect/VM/Target/C/CppEmitter.h @@ -0,0 +1,182 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Formated in LLVM style. Avoid reformatting for upcoming upstreaming. +// clang-format off + +#ifndef EMITC_TARGET_CPP_CPPEMITTER_H +#define EMITC_TARGET_CPP_CPPEMITTER_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/IndentedOstream.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace mlir { +namespace emitc { + +/// Convenience functions to produce interleaved output with functions returning +/// a LogicalResult. This is different than those in STL as functions used on +/// each element doesn't return a string. +template +inline LogicalResult +interleaveWithError(ForwardIterator begin, ForwardIterator end, + UnaryFunctor eachFn, NullaryFunctor betweenFn) { + if (begin == end) + return success(); + if (failed(eachFn(*begin))) + return failure(); + ++begin; + for (; begin != end; ++begin) { + betweenFn(); + if (failed(eachFn(*begin))) + return failure(); + } + return success(); +} + +template +inline LogicalResult interleaveWithError(const Container &c, + UnaryFunctor eachFn, + NullaryFunctor betweenFn) { + return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn); +} + +template +inline LogicalResult interleaveCommaWithError(const Container &c, + raw_ostream &os, + UnaryFunctor eachFn) { + return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; }); +} + +/// Emitter that uses dialect specific emitters to emit C++ code. 
+struct CppEmitter { + explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop); + + /// Emits attribute or returns failure. + LogicalResult emitAttribute(Location loc, Attribute attr); + + /// Emits operation 'op' with/without training semicolon or returns failure. + LogicalResult emitOperation(Operation &op, bool trailingSemicolon); + + /// Emits type 'type' or returns failure. + LogicalResult emitType(Location loc, Type type); + + /// Emits array of types as a std::tuple of the emitted types. + /// - emits void for an empty array; + /// - emits the type of the only element for arrays of size one; + /// - emits a std::tuple otherwise; + LogicalResult emitTypes(Location loc, ArrayRef types); + + /// Emits array of types as a std::tuple of the emitted types independently of + /// the array size. + LogicalResult emitTupleType(Location loc, ArrayRef types); + + /// Emits an assignment for a variable which has been declared previously. + LogicalResult emitVariableAssignment(OpResult result); + + /// Emits a variable declaration for a result of an operation. + LogicalResult emitVariableDeclaration(OpResult result, + bool trailingSemicolon); + + /// Emits the variable declaration and assignment prefix for 'op'. + /// - emits separate variable followed by std::tie for multi-valued operation; + /// - emits single type followed by variable for single result; + /// - emits nothing if no value produced by op; + /// Emits final '=' operator where a type is produced. Returns failure if + /// any result type could not be converted. + LogicalResult emitAssignPrefix(Operation &op); + + /// Emits a label for the block. + LogicalResult emitLabel(Block &block); + + /// Emits the operands and atttributes of the operation. All operands are + /// emitted first and then all attributes in alphabetical order. + LogicalResult emitOperandsAndAttributes(Operation &op, + ArrayRef exclude = {}); + + /// Emits the operands of the operation. All operands are emitted in order. 
+ LogicalResult emitOperands(Operation &op); + + /// Return the existing or a new name for a Value. + StringRef getOrCreateName(Value val); + + /// Return the existing or a new label of a Block. + StringRef getOrCreateName(Block &block); + + /// Whether to map an mlir integer to a unsigned integer in C++. + bool shouldMapToUnsigned(IntegerType::SignednessSemantics val); + + /// RAII helper function to manage entering/exiting C++ scopes. + struct Scope { + Scope(CppEmitter &emitter) + : valueMapperScope(emitter.valueMapper), + blockMapperScope(emitter.blockMapper), emitter(emitter) { + emitter.valueInScopeCount.push(emitter.valueInScopeCount.top()); + emitter.labelInScopeCount.push(emitter.labelInScopeCount.top()); + } + ~Scope() { + emitter.valueInScopeCount.pop(); + emitter.labelInScopeCount.pop(); + } + + private: + llvm::ScopedHashTableScope valueMapperScope; + llvm::ScopedHashTableScope blockMapperScope; + CppEmitter &emitter; + }; + + /// Returns wether the Value is assigned to a C++ variable in the scope. + bool hasValueInScope(Value val); + + // Returns whether a label is assigned to the block. + bool hasBlockLabel(Block &block); + + /// Returns the output stream. + raw_indented_ostream &ostream() { return os; }; + + /// Returns if all variables for op results and basic block arguments need to + /// be declared at the beginning of a function. + bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; + +private: + using ValueMapper = llvm::ScopedHashTable; + using BlockMapper = llvm::ScopedHashTable; + + /// Output stream to emit to. + raw_indented_ostream os; + + /// Boolean to enforce that all variables for op results and block + /// arguments are declared at the beginning of the function. This also + /// includes results from ops located in nested regions. + bool declareVariablesAtTop; + + /// Map from value to name of C++ variable that contain the name. + ValueMapper valueMapper; + + /// Map from block to name of C++ label. 
+ BlockMapper blockMapper; + + /// The number of values in the current scope. This is used to declare the + /// names of values in a scope. + std::stack valueInScopeCount; + std::stack labelInScopeCount; +}; + +/// Translates the given operation to C++ code. The operation or operations in +/// the region of 'op' need almost all be in EmitC dialect. The parameter +/// 'declareVariablesAtTop' enforces that all variables for op results and block +/// arguments are declared at the beginning of the function. +LogicalResult translateToCpp(Operation *op, raw_ostream &os, + bool declareVariablesAtTop = false); +} // namespace emitc +} // namespace mlir + +#endif // EMITC_TARGET_CPP_CPPEMITTER_H +// clang-format on diff --git a/iree/compiler/Dialect/VM/Target/C/README.md b/iree/compiler/Dialect/VM/Target/C/README.md new file mode 100644 index 000000000000..af8f5ef61bbc --- /dev/null +++ b/iree/compiler/Dialect/VM/Target/C/README.md @@ -0,0 +1,8 @@ +# MLIR EmitC + +The Cpp emitter is a partial copy of MLIR EmitC, forked from https://github.com/iml130/mlir-emitc. + +The initial import contains the C/C++ emitter (namely the files `CppEmitter.h` and `TranslateToCpp.cpp`) +and reflects the state of iml130/mlir-emitc@f9968f65 for those files. + +It is intended to switch to the C/C++ emitter in the MLIR core repository as soon as possible. \ No newline at end of file diff --git a/iree/compiler/Dialect/VM/Target/C/TranslateToCpp.cpp b/iree/compiler/Dialect/VM/Target/C/TranslateToCpp.cpp new file mode 100644 index 000000000000..b39b0325adef --- /dev/null +++ b/iree/compiler/Dialect/VM/Target/C/TranslateToCpp.cpp @@ -0,0 +1,839 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Formated in LLVM style. Avoid reformatting for upcoming upstreaming. 
+// clang-format off + +#include "iree/compiler/Dialect/VM/Target/C/CppEmitter.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/IndentedOstream.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" + +#define DEBUG_TYPE "translate-to-cpp" + +using namespace mlir; +using namespace mlir::emitc; +using llvm::formatv; + +static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, + Attribute value) { + OpResult result = operation->getResult(0); + + // Only emit an assignment as the variable was already declared when printing + // the FuncOp. + if (emitter.shouldDeclareVariablesAtTop()) { + // Skip the assignment if the emitc.constant has no value. + if (auto oAttr = value.dyn_cast()) { + if (oAttr.getValue().empty()) + return success(); + } + + if (failed(emitter.emitVariableAssignment(result))) + return failure(); + return emitter.emitAttribute(operation->getLoc(), value); + } + + // Emit a variable declaration for an emitc.constant op without value. + if (auto oAttr = value.dyn_cast()) { + if (oAttr.getValue().empty()) + // The semicolon gets printed by the emitOperation function. + return emitter.emitVariableDeclaration(result, + /*trailingSemicolon=*/false); + } + + // Emit a variable declaration. 
+ if (failed(emitter.emitAssignPrefix(*operation))) + return failure(); + return emitter.emitAttribute(operation->getLoc(), value); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::ConstantOp constantOp) { + Operation *operation = constantOp.getOperation(); + Attribute value = constantOp.value(); + + return printConstantOp(emitter, operation, value); +} + +static LogicalResult printOperation(CppEmitter &emitter, + mlir::ConstantOp constantOp) { + Operation *operation = constantOp.getOperation(); + Attribute value = constantOp.value(); + + return printConstantOp(emitter, operation, value); +} + +static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) { + raw_ostream &os = emitter.ostream(); + Block &successor = *branchOp.getSuccessor(); + + for (auto pair : + llvm::zip(branchOp.getOperands(), successor.getArguments())) { + Value &operand = std::get<0>(pair); + BlockArgument &argument = std::get<1>(pair); + os << emitter.getOrCreateName(argument) << " = " + << emitter.getOrCreateName(operand) << ";\n"; + } + + os << "goto "; + if (!(emitter.hasBlockLabel(successor))) + return branchOp.emitOpError("unable to find label for successor block"); + os << emitter.getOrCreateName(successor); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + CondBranchOp condBranchOp) { + raw_ostream &os = emitter.ostream(); + Block &trueSuccessor = *condBranchOp.getTrueDest(); + Block &falseSuccessor = *condBranchOp.getFalseDest(); + + os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition()) + << ") {\n"; + + // If condition is true. 
+ for (auto pair : llvm::zip(condBranchOp.getTrueOperands(), + trueSuccessor.getArguments())) { + Value &operand = std::get<0>(pair); + BlockArgument &argument = std::get<1>(pair); + os << emitter.getOrCreateName(argument) << " = " + << emitter.getOrCreateName(operand) << ";\n"; + } + + os << "goto "; + if (!(emitter.hasBlockLabel(trueSuccessor))) { + return condBranchOp.emitOpError("unable to find label for successor block"); + } + os << emitter.getOrCreateName(trueSuccessor) << ";\n"; + os << "} else {\n"; + // If condition is false. + for (auto pair : llvm::zip(condBranchOp.getFalseOperands(), + falseSuccessor.getArguments())) { + Value &operand = std::get<0>(pair); + BlockArgument &argument = std::get<1>(pair); + os << emitter.getOrCreateName(argument) << " = " + << emitter.getOrCreateName(operand) << ";\n"; + } + + os << "goto "; + if (!(emitter.hasBlockLabel(falseSuccessor))) { + return condBranchOp.emitOpError() + << "unable to find label for successor block"; + } + os << emitter.getOrCreateName(falseSuccessor) << ";\n"; + os << "}"; + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, mlir::CallOp callOp) { + if (failed(emitter.emitAssignPrefix(*callOp.getOperation()))) + return failure(); + + raw_ostream &os = emitter.ostream(); + os << callOp.getCallee() << "("; + if (failed(emitter.emitOperands(*callOp.getOperation()))) + return failure(); + os << ")"; + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { + raw_ostream &os = emitter.ostream(); + Operation &op = *callOp.getOperation(); + + if (failed(emitter.emitAssignPrefix(op))) + return failure(); + os << callOp.callee(); + + auto emitArgs = [&](Attribute attr) -> LogicalResult { + if (auto t = attr.dyn_cast()) { + // Index attributes are treated specially as operand index. 
+ if (t.getType().isIndex()) { + int64_t idx = t.getInt(); + if ((idx < 0) || (idx >= op.getNumOperands())) + return op.emitOpError("invalid operand index"); + if (!emitter.hasValueInScope(op.getOperand(idx))) + return op.emitOpError("operand ") + << idx << "'s value not defined in scope"; + os << emitter.getOrCreateName(op.getOperand(idx)); + return success(); + } + } + if (failed(emitter.emitAttribute(op.getLoc(), attr))) + return failure(); + + return success(); + }; + + if (callOp.template_args()) { + os << "<"; + if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs))) + return failure(); + os << ">"; + } + + os << "("; + + LogicalResult emittedArgs = + callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs) + : emitter.emitOperands(op); + if (failed(emittedArgs)) + return failure(); + os << ")"; + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::ApplyOp applyOp) { + raw_ostream &os = emitter.ostream(); + Operation &op = *applyOp.getOperation(); + + if (failed(emitter.emitAssignPrefix(op))) + return failure(); + os << applyOp.applicableOperator(); + os << emitter.getOrCreateName(applyOp.getOperand()); + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::IncludeOp includeOp) { + raw_ostream &os = emitter.ostream(); + + os << "#include "; + if (includeOp.is_standard_include()) + os << "<" << includeOp.include() << ">"; + else + os << "\"" << includeOp.include() << "\""; + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { + + raw_indented_ostream &os = emitter.ostream(); + + OperandRange operands = forOp.getIterOperands(); + Block::BlockArgListType iterArgs = forOp.getRegionIterArgs(); + Operation::result_range results = forOp.getResults(); + + if (!emitter.shouldDeclareVariablesAtTop()) { + for (OpResult result : results) { + if (failed(emitter.emitVariableDeclaration(result, + 
/*trailingSemicolon=*/true))) + return failure(); + } + } + + for (auto pair : llvm::zip(iterArgs, operands)) { + if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType()))) + return failure(); + os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = "; + os << emitter.getOrCreateName(std::get<1>(pair)) << ";"; + os << "\n"; + } + + os << "for ("; + if (failed( + emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) + return failure(); + os << " "; + os << emitter.getOrCreateName(forOp.getInductionVar()); + os << " = "; + os << emitter.getOrCreateName(forOp.lowerBound()); + os << "; "; + os << emitter.getOrCreateName(forOp.getInductionVar()); + os << " < "; + os << emitter.getOrCreateName(forOp.upperBound()); + os << "; "; + os << emitter.getOrCreateName(forOp.getInductionVar()); + os << " += "; + os << emitter.getOrCreateName(forOp.step()); + os << ") {\n"; + os.indent(); + + Region &forRegion = forOp.region(); + auto regionOps = forRegion.getOps(); + + // We skip the trailing yield op because this updates the result variables + // of the for op in the generated code. Instead we update the iterArgs at + // the end of a loop iteration and set the result variables after the for + // loop. + for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) { + if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true))) + return failure(); + } + + Operation *yieldOp = forRegion.getBlocks().front().getTerminator(); + // Copy yield operands into iterArgs at the end of a loop iteration. + for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) { + BlockArgument iterArg = std::get<0>(pair); + Value operand = std::get<1>(pair); + os << emitter.getOrCreateName(iterArg) << " = " + << emitter.getOrCreateName(operand) << ";\n"; + } + + os.unindent() << "}"; + + // Copy iterArgs into results after the for loop. 
+ for (auto pair : llvm::zip(results, iterArgs)) { + OpResult result = std::get<0>(pair); + BlockArgument iterArg = std::get<1>(pair); + os << "\n" + << emitter.getOrCreateName(result) << " = " + << emitter.getOrCreateName(iterArg) << ";"; + } + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) { + raw_indented_ostream &os = emitter.ostream(); + + if (!emitter.shouldDeclareVariablesAtTop()) { + for (OpResult result : ifOp.getResults()) { + if (failed(emitter.emitVariableDeclaration(result, + /*trailingSemicolon=*/true))) + return failure(); + } + } + + os << "if ("; + if (failed(emitter.emitOperands(*ifOp.getOperation()))) + return failure(); + os << ") {\n"; + os.indent(); + + Region &thenRegion = ifOp.thenRegion(); + for (Operation &op : thenRegion.getOps()) { + // Note: This prints a superfluous semicolon if the terminating yield op has + // zero results. + if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) + return failure(); + } + + os.unindent() << "}"; + + Region &elseRegion = ifOp.elseRegion(); + if (!elseRegion.empty()) { + os << " else {\n"; + os.indent(); + + for (Operation &op : elseRegion.getOps()) { + // Note: This prints a superfluous semicolon if the terminating yield op + // has zero results. 
+ if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) + return failure(); + } + + os.unindent() << "}"; + } + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) { + raw_ostream &os = emitter.ostream(); + Operation &parentOp = *yieldOp.getOperation()->getParentOp(); + + if (yieldOp.getNumOperands() != parentOp.getNumResults()) { + return yieldOp.emitError("number of operands does not to match the number " + "of the parent op's results"); + } + + if (failed(interleaveWithError( + llvm::zip(parentOp.getResults(), yieldOp.getOperands()), + [&](auto pair) -> LogicalResult { + auto result = std::get<0>(pair); + auto operand = std::get<1>(pair); + os << emitter.getOrCreateName(result) << " = "; + + if (!emitter.hasValueInScope(operand)) + return yieldOp.emitError("operand value not in scope"); + os << emitter.getOrCreateName(operand); + return success(); + }, + [&]() { os << ";\n"; }))) + return failure(); + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, ReturnOp returnOp) { + raw_ostream &os = emitter.ostream(); + os << "return"; + switch (returnOp.getNumOperands()) { + case 0: + return success(); + case 1: + os << " " << emitter.getOrCreateName(returnOp.getOperand(0)); + return success(emitter.hasValueInScope(returnOp.getOperand(0))); + default: + os << " std::make_tuple("; + if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) + return failure(); + os << ")"; + return success(); + } +} + +static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { + CppEmitter::Scope scope(emitter); + + for (Operation &op : moduleOp) { + if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) + return failure(); + } + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) { + // We need to declare variables at top if the function has multiple blocks. 
+ if (!emitter.shouldDeclareVariablesAtTop() && + functionOp.getBlocks().size() > 1) { + return functionOp.emitOpError( + "with multiple blocks needs variables declared at top"); + } + + CppEmitter::Scope scope(emitter); + raw_indented_ostream &os = emitter.ostream(); + if (failed(emitter.emitTypes(functionOp.getLoc(), + functionOp.getType().getResults()))) + return failure(); + os << " " << functionOp.getName(); + + os << "("; + if (failed(interleaveCommaWithError( + functionOp.getArguments(), os, + [&](BlockArgument arg) -> LogicalResult { + if (failed(emitter.emitType(functionOp.getLoc(), arg.getType()))) + return failure(); + os << " " << emitter.getOrCreateName(arg); + return success(); + }))) + return failure(); + os << ") {\n"; + os.indent(); + if (emitter.shouldDeclareVariablesAtTop()) { + // Declare all variables that hold op results including those from nested + // regions. + WalkResult result = + functionOp.walk([&](Operation *op) -> WalkResult { + for (OpResult result : op->getResults()) { + if (failed(emitter.emitVariableDeclaration( + result, /*trailingSemicolon=*/true))) { + return WalkResult( + op->emitError("unable to declare result variable for op")); + } + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) + return failure(); + } + + Region::BlockListType &blocks = functionOp.getBlocks(); + // Create label names for basic blocks. + for (Block &block : blocks) { + emitter.getOrCreateName(block); + } + + // Declare variables for basic block arguments. 
+ for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) { + Block &block = *it; + for (BlockArgument &arg : block.getArguments()) { + if (emitter.hasValueInScope(arg)) + return functionOp.emitOpError(" block argument #") + << arg.getArgNumber() << " is out of scope"; + if (failed( + emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) { + return failure(); + } + os << " " << emitter.getOrCreateName(arg) << ";\n"; + } + } + + for (Block &block : blocks) { + // Only print a label if there is more than one block. + if (blocks.size() > 1) { + if (failed(emitter.emitLabel(block))) + return failure(); + } + for (Operation &op : block.getOperations()) { + // When generating code for an scf.if or std.cond_br op no semicolon needs + // to be printed after the closing brace. + // When generating code for an scf.for op, printing a trailing semicolon + // is handled within the printOperation function. + bool trailingSemicolon = !isa(op); + + if (failed(emitter.emitOperation( + op, /*trailingSemicolon=*/trailingSemicolon))) + return failure(); + } + } + os.unindent() << "}\n"; + return success(); +} + +CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop) + : os(os), declareVariablesAtTop(declareVariablesAtTop) { + valueInScopeCount.push(0); + labelInScopeCount.push(0); +} + +/// Return the existing or a new name for a Value. +StringRef CppEmitter::getOrCreateName(Value val) { + if (!valueMapper.count(val)) + valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); + return *valueMapper.begin(val); +} + +/// Return the existing or a new label for a Block. 
+StringRef CppEmitter::getOrCreateName(Block &block) { + if (!blockMapper.count(&block)) + blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top())); + return *blockMapper.begin(&block); +} + +bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) { + switch (val) { + case IntegerType::Signless: + return false; + case IntegerType::Signed: + return false; + case IntegerType::Unsigned: + return true; + default: + llvm_unreachable("unsupported IntegerType"); + } +} + +bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); } + +bool CppEmitter::hasBlockLabel(Block &block) { + return blockMapper.count(&block); +} + +LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { + auto printInt = [&](APInt val, bool isUnsigned) { + if (val.getBitWidth() == 1) { + if (val.getBoolValue()) + os << "true"; + else + os << "false"; + } else { + SmallString<128> strValue; + val.toString(strValue, 10, !isUnsigned, false); + os << strValue; + } + }; + + auto printFloat = [&](APFloat val) { + if (val.isFinite()) { + SmallString<128> strValue; + // Use default values of toString except don't truncate zeros. + val.toString(strValue, 0, 0, false); + switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) { + case llvm::APFloatBase::S_IEEEsingle: + os << "(float)"; + break; + case llvm::APFloatBase::S_IEEEdouble: + os << "(double)"; + break; + default: + break; + }; + os << strValue; + } else if (val.isNaN()) { + os << "NAN"; + } else if (val.isInfinity()) { + if (val.isNegative()) + os << "-"; + os << "INFINITY"; + } + }; + + // Print floating point attributes. + if (auto fAttr = attr.dyn_cast()) { + printFloat(fAttr.getValue()); + return success(); + } + if (auto dense = attr.dyn_cast()) { + os << '{'; + interleaveComma(dense, os, [&](APFloat val) { printFloat(val); }); + os << '}'; + return success(); + } + + // Print integer attributes. 
+ if (auto iAttr = attr.dyn_cast()) { + if (auto iType = iAttr.getType().dyn_cast()) { + printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness())); + return success(); + } + if (auto iType = iAttr.getType().dyn_cast()) { + printInt(iAttr.getValue(), false); + return success(); + } + } + if (auto dense = attr.dyn_cast()) { + if (auto iType = dense.getType() + .cast() + .getElementType() + .dyn_cast()) { + os << '{'; + interleaveComma(dense, os, [&](APInt val) { + printInt(val, shouldMapToUnsigned(iType.getSignedness())); + }); + os << '}'; + return success(); + } + if (auto iType = dense.getType() + .cast() + .getElementType() + .dyn_cast()) { + os << '{'; + interleaveComma(dense, os, [&](APInt val) { printInt(val, false); }); + os << '}'; + return success(); + } + } + + // Print opaque attributes. + if (auto oAttr = attr.dyn_cast()) { + os << oAttr.getValue(); + return success(); + } + + // Print symbolic reference attributes. + if (auto sAttr = attr.dyn_cast()) { + if (sAttr.getNestedReferences().size() > 1) + return emitError(loc, "attribute has more than 1 nested reference"); + os << sAttr.getRootReference().getValue(); + return success(); + } + + // Print type attributes. + if (auto type = attr.dyn_cast()) + return emitType(loc, type.getValue()); + + return emitError(loc, "cannot emit attribute of type ") << attr.getType(); +} + +LogicalResult CppEmitter::emitOperands(Operation &op) { + auto emitOperandName = [&](Value result) -> LogicalResult { + if (!hasValueInScope(result)) + return op.emitOpError() << "operand value not in scope"; + os << getOrCreateName(result); + return success(); + }; + return interleaveCommaWithError(op.getOperands(), os, emitOperandName); +} + +LogicalResult +CppEmitter::emitOperandsAndAttributes(Operation &op, + ArrayRef exclude) { + if (failed(emitOperands(op))) + return failure(); + // Insert comma in between operands and non-filtered attributes if needed. 
+ if (op.getNumOperands() > 0) { + for (NamedAttribute attr : op.getAttrs()) { + if (!llvm::is_contained(exclude, attr.first.strref())) { + os << ", "; + break; + } + } + } + // Emit attributes. + auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { + if (llvm::is_contained(exclude, attr.first.strref())) + return success(); + os << "/* " << attr.first << " */"; + if (failed(emitAttribute(op.getLoc(), attr.second))) + return failure(); + return success(); + }; + return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute); +} + +LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { + if (!hasValueInScope(result)) { + return result.getDefiningOp()->emitOpError( + "result variable for the operation has not been declared"); + } + os << getOrCreateName(result) << " = "; + return success(); +} + +LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, + bool trailingSemicolon) { + if (hasValueInScope(result)) { + return result.getDefiningOp()->emitError( + "result variable for the operation already declared"); + } + if (failed(emitType(result.getOwner()->getLoc(), result.getType()))) + return failure(); + os << " " << getOrCreateName(result); + if (trailingSemicolon) + os << ";\n"; + return success(); +} + +LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { + switch (op.getNumResults()) { + case 0: + break; + case 1: { + OpResult result = op.getResult(0); + if (shouldDeclareVariablesAtTop()) { + if (failed(emitVariableAssignment(result))) + return failure(); + } else { + if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false))) + return failure(); + os << " = "; + } + break; + } + default: + if (!shouldDeclareVariablesAtTop()) { + for (OpResult result : op.getResults()) { + if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true))) + return failure(); + } + } + os << "std::tie("; + interleaveComma(op.getResults(), os, + [&](Value result) { os << getOrCreateName(result); }); + os << 
") = "; + } + return success(); +} + +LogicalResult CppEmitter::emitLabel(Block &block) { + if (!hasBlockLabel(block)) + return block.getParentOp()->emitError("label for block not found"); + os << getOrCreateName(block) << ":\n"; + return success(); +} + +LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { + LogicalResult status = + llvm::TypeSwitch(&op) + // EmitC ops. + .Case( + [&](auto op) { return printOperation(*this, op); }) + // SCF ops. + .Case( + [&](auto op) { return printOperation(*this, op); }) + // Standard ops. + .Case( + [&](auto op) { return printOperation(*this, op); }) + .Default([&](Operation *) { + return op.emitOpError("unable to find printer for op"); + }); + + if (failed(status)) + return failure(); + os << (trailingSemicolon ? ";\n" : "\n"); + return success(); +} + +LogicalResult CppEmitter::emitType(Location loc, Type type) { + if (auto iType = type.dyn_cast()) { + switch (iType.getWidth()) { + case 1: + return (os << "bool"), success(); + case 8: + case 16: + case 32: + case 64: + if (shouldMapToUnsigned(iType.getSignedness())) + return (os << "uint" << iType.getWidth() << "_t"), success(); + else + return (os << "int" << iType.getWidth() << "_t"), success(); + default: + return emitError(loc, "cannot emit integer type ") << type; + } + } + if (auto fType = type.dyn_cast()) { + switch (fType.getWidth()) { + case 32: + return (os << "float"), success(); + case 64: + return (os << "double"), success(); + default: + return emitError(loc, "cannot emit float type ") << type; + } + } + if (auto iType = type.dyn_cast()) + return (os << "size_t"), success(); + if (auto tType = type.dyn_cast()) { + if (!tType.hasRank()) + return emitError(loc, "cannot emit unranked tensor type"); + if (!tType.hasStaticShape()) + return emitError(loc, "cannot emit tensor type with non static shape"); + os << "Tensor<"; + if (failed(emitType(loc, tType.getElementType()))) + return failure(); + auto shape = tType.getShape(); + for (auto 
dimSize : shape) { + os << ", "; + os << dimSize; + } + os << ">"; + return success(); + } + if (auto tType = type.dyn_cast()) + return emitTupleType(loc, tType.getTypes()); + if (auto oType = type.dyn_cast()) { + os << oType.getValue(); + return success(); + } + return emitError(loc, "cannot emit type ") << type; +} + +LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef types) { + switch (types.size()) { + case 0: + os << "void"; + return success(); + case 1: + return emitType(loc, types.front()); + default: + return emitTupleType(loc, types); + } +} + +LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef types) { + os << "std::tuple<"; + if (failed(interleaveCommaWithError( + types, os, [&](Type type) { return emitType(loc, type); }))) + return failure(); + os << ">"; + return success(); +} + +LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, + bool declareVariablesAtTop) { + CppEmitter emitter(os, declareVariablesAtTop); + return emitter.emitOperation(*op, /*trailingSemicolon=*/false); +} +// clang-format on diff --git a/iree/hal/cts/command_buffer_test.cc b/iree/hal/cts/command_buffer_test.cc index 234bc8e9aad3..bd9a1cd5eb37 100644 --- a/iree/hal/cts/command_buffer_test.cc +++ b/iree/hal/cts/command_buffer_test.cc @@ -37,6 +37,64 @@ class CommandBufferTest : public CtsTestBase { } protected: + std::vector RunFillBufferTest(iree_device_size_t buffer_size, + iree_device_size_t target_offset, + iree_device_size_t fill_length, + const void* pattern, + iree_host_size_t pattern_length) { + iree_hal_command_buffer_t* command_buffer; + IREE_CHECK_OK(iree_hal_command_buffer_create( + device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_COMMAND_CATEGORY_ANY, IREE_HAL_QUEUE_AFFINITY_ANY, + &command_buffer)); + iree_hal_buffer_t* device_buffer; + IREE_CHECK_OK(iree_hal_allocator_allocate_buffer( + iree_hal_device_allocator(device_), + IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE, + IREE_HAL_BUFFER_USAGE_ALL, 
buffer_size, &device_buffer)); + + IREE_CHECK_OK(iree_hal_command_buffer_begin(command_buffer)); + // Start with a zero fill on the entire buffer... + uint8_t zero_val = 0x0; + IREE_CHECK_OK(iree_hal_command_buffer_fill_buffer( + command_buffer, device_buffer, /*target_offset=*/0, + /*length=*/buffer_size, &zero_val, + /*pattern_length=*/sizeof(zero_val))); + // (buffer barrier between the fill operations) + iree_hal_buffer_barrier_t buffer_barrier; + buffer_barrier.source_scope = IREE_HAL_ACCESS_SCOPE_TRANSFER_WRITE; + buffer_barrier.target_scope = IREE_HAL_ACCESS_SCOPE_TRANSFER_WRITE | + IREE_HAL_ACCESS_SCOPE_DISPATCH_WRITE; + buffer_barrier.buffer = device_buffer; + buffer_barrier.offset = 0; + buffer_barrier.length = buffer_size; + IREE_CHECK_OK(iree_hal_command_buffer_execution_barrier( + command_buffer, IREE_HAL_EXECUTION_STAGE_TRANSFER, + IREE_HAL_EXECUTION_STAGE_TRANSFER | IREE_HAL_EXECUTION_STAGE_DISPATCH, + IREE_HAL_EXECUTION_BARRIER_FLAG_NONE, /*memory_barrier_count=*/0, NULL, + /*buffer_barrier_count=*/1, &buffer_barrier)); + // ... then fill the pattern on top. 
+ IREE_CHECK_OK(iree_hal_command_buffer_fill_buffer( + command_buffer, device_buffer, + /*target_offset=*/target_offset, /*length=*/fill_length, + /*pattern=*/pattern, + /*pattern_length=*/pattern_length)); + IREE_CHECK_OK(iree_hal_command_buffer_end(command_buffer)); + IREE_CHECK_OK(SubmitCommandBufferAndWait(IREE_HAL_COMMAND_CATEGORY_ANY, + command_buffer)); + + std::vector actual_data(buffer_size); + IREE_CHECK_OK( + iree_hal_buffer_read_data(device_buffer, /*source_offset=*/0, + /*target_buffer=*/actual_data.data(), + /*data_length=*/buffer_size)); + + iree_hal_command_buffer_release(command_buffer); + iree_hal_buffer_release(device_buffer); + + return actual_data; + } + static constexpr iree_device_size_t kBufferSize = 4096; }; @@ -83,67 +141,6 @@ TEST_P(CommandBufferTest, SubmitEmpty) { iree_hal_command_buffer_release(command_buffer); } -TEST_P(CommandBufferTest, FillBufferWithRepeatedBytes) { - iree_hal_command_buffer_t* command_buffer; - IREE_ASSERT_OK(iree_hal_command_buffer_create( - device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, - IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_HAL_QUEUE_AFFINITY_ANY, - &command_buffer)); - - iree_hal_buffer_t* device_buffer; - IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer( - device_allocator_, - IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE, - IREE_HAL_BUFFER_USAGE_ALL, kBufferSize, &device_buffer)); - - std::vector reference_buffer(kBufferSize); - - IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer)); - - // Fill the device buffer with segments of different values so that we can - // test both fill and offset/size. 
- uint8_t val1 = 0x07; - IREE_ASSERT_OK(iree_hal_command_buffer_fill_buffer( - command_buffer, device_buffer, - /*target_offset=*/0, /*length=*/kBufferSize / 4, /*pattern=*/&val1, - /*pattern_length=*/sizeof(val1))); - std::memset(reference_buffer.data(), val1, kBufferSize / 4); - - uint8_t val2 = 0xbe; - IREE_ASSERT_OK( - iree_hal_command_buffer_fill_buffer(command_buffer, device_buffer, - /*target_offset=*/kBufferSize / 4, - /*length=*/kBufferSize / 4, - /*pattern=*/&val2, - /*pattern_length=*/sizeof(val2))); - std::memset(reference_buffer.data() + kBufferSize / 4, val2, kBufferSize / 4); - - uint8_t val3 = 0x54; - IREE_ASSERT_OK( - iree_hal_command_buffer_fill_buffer(command_buffer, device_buffer, - /*target_offset=*/kBufferSize / 2, - /*length=*/kBufferSize / 2, - /*pattern=*/&val3, - /*pattern_length=*/sizeof(val3))); - std::memset(reference_buffer.data() + kBufferSize / 2, val3, kBufferSize / 2); - - IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer)); - - IREE_ASSERT_OK(SubmitCommandBufferAndWait(IREE_HAL_COMMAND_CATEGORY_TRANSFER, - command_buffer)); - - // Read the device buffer and compare. - std::vector actual_data(kBufferSize); - IREE_ASSERT_OK(iree_hal_buffer_read_data(device_buffer, /*source_offset=*/0, - /*target_buffer=*/actual_data.data(), - /*data_length=*/kBufferSize)); - EXPECT_THAT(actual_data, ContainerEq(reference_buffer)); - - // Must release the command buffer before resources used by it. 
- iree_hal_command_buffer_release(command_buffer); - iree_hal_buffer_release(device_buffer); -} - TEST_P(CommandBufferTest, CopyWholeBuffer) { iree_hal_command_buffer_t* command_buffer; IREE_ASSERT_OK(iree_hal_command_buffer_create( @@ -257,6 +254,126 @@ TEST_P(CommandBufferTest, CopySubBuffer) { iree_hal_buffer_release(host_buffer); } +TEST_P(CommandBufferTest, FillBuffer_pattern1_offset0_length1) { + iree_device_size_t buffer_size = 16; + iree_device_size_t target_offset = 0; + iree_device_size_t fill_length = 1; + uint8_t pattern = 0x07; + std::vector reference_buffer{0x07, 0x00, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x00}; + std::vector actual_buffer = + RunFillBufferTest(buffer_size, target_offset, fill_length, + (void*)&pattern, sizeof(pattern)); + EXPECT_THAT(actual_buffer, ContainerEq(reference_buffer)); +} + +TEST_P(CommandBufferTest, FillBuffer_pattern1_offset0_length3) { + iree_device_size_t buffer_size = 16; + iree_device_size_t target_offset = 0; + iree_device_size_t fill_length = 3; + uint8_t pattern = 0x07; + std::vector reference_buffer{0x07, 0x07, 0x07, 0x00, // + 0x00, 0x00, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x00}; + std::vector actual_buffer = + RunFillBufferTest(buffer_size, target_offset, fill_length, + (void*)&pattern, sizeof(pattern)); + EXPECT_THAT(actual_buffer, ContainerEq(reference_buffer)); +} + +TEST_P(CommandBufferTest, FillBuffer_pattern1_offset0_length8) { + iree_device_size_t buffer_size = 16; + iree_device_size_t target_offset = 0; + iree_device_size_t fill_length = 8; + uint8_t pattern = 0x07; + std::vector reference_buffer{0x07, 0x07, 0x07, 0x07, // + 0x07, 0x07, 0x07, 0x07, // + 0x00, 0x00, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x00}; + std::vector actual_buffer = + RunFillBufferTest(buffer_size, target_offset, fill_length, + (void*)&pattern, sizeof(pattern)); + EXPECT_THAT(actual_buffer, ContainerEq(reference_buffer)); +} + 
+TEST_P(CommandBufferTest, FillBuffer_pattern1_offset2_length8) { + iree_device_size_t buffer_size = 16; + iree_device_size_t target_offset = 2; + iree_device_size_t fill_length = 8; + uint8_t pattern = 0x07; + std::vector reference_buffer{0x00, 0x00, 0x07, 0x07, // + 0x07, 0x07, 0x07, 0x07, // + 0x07, 0x07, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x00}; + std::vector actual_buffer = + RunFillBufferTest(buffer_size, target_offset, fill_length, + (void*)&pattern, sizeof(pattern)); + EXPECT_THAT(actual_buffer, ContainerEq(reference_buffer)); +} + +TEST_P(CommandBufferTest, FillBuffer_pattern2_offset0_length8) { + iree_device_size_t buffer_size = 16; + iree_device_size_t target_offset = 0; + iree_device_size_t fill_length = 8; + uint16_t pattern = 0xAB23; + std::vector reference_buffer{0x23, 0xAB, 0x23, 0xAB, // + 0x23, 0xAB, 0x23, 0xAB, // + 0x00, 0x00, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x00}; + std::vector actual_buffer = + RunFillBufferTest(buffer_size, target_offset, fill_length, + (void*)&pattern, sizeof(pattern)); + EXPECT_THAT(actual_buffer, ContainerEq(reference_buffer)); +} + +TEST_P(CommandBufferTest, FillBuffer_pattern2_offset0_length10) { + iree_device_size_t buffer_size = 16; + iree_device_size_t target_offset = 0; + iree_device_size_t fill_length = 10; + uint16_t pattern = 0xAB23; + std::vector reference_buffer{0x23, 0xAB, 0x23, 0xAB, // + 0x23, 0xAB, 0x23, 0xAB, // + 0x23, 0xAB, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x00}; + std::vector actual_buffer = + RunFillBufferTest(buffer_size, target_offset, fill_length, + (void*)&pattern, sizeof(pattern)); + EXPECT_THAT(actual_buffer, ContainerEq(reference_buffer)); +} + +TEST_P(CommandBufferTest, FillBuffer_pattern2_offset2_length8) { + iree_device_size_t buffer_size = 16; + iree_device_size_t target_offset = 2; + iree_device_size_t fill_length = 8; + uint16_t pattern = 0xAB23; + std::vector reference_buffer{0x00, 0x00, 0x23, 0xAB, // + 0x23, 0xAB, 0x23, 0xAB, // + 0x23, 0xAB, 0x00, 0x00, // + 0x00, 0x00, 0x00, 
0x00}; + std::vector actual_buffer = + RunFillBufferTest(buffer_size, target_offset, fill_length, + (void*)&pattern, sizeof(pattern)); + EXPECT_THAT(actual_buffer, ContainerEq(reference_buffer)); +} + +TEST_P(CommandBufferTest, FillBuffer_pattern4_offset0_length8) { + iree_device_size_t buffer_size = 16; + iree_device_size_t target_offset = 0; + iree_device_size_t fill_length = 8; + uint32_t pattern = 0xAB23CD45; + std::vector reference_buffer{0x45, 0xCD, 0x23, 0xAB, // + 0x45, 0xCD, 0x23, 0xAB, // + 0x00, 0x00, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x00}; + std::vector actual_buffer = + RunFillBufferTest(buffer_size, target_offset, fill_length, + (void*)&pattern, sizeof(pattern)); + EXPECT_THAT(actual_buffer, ContainerEq(reference_buffer)); +} + INSTANTIATE_TEST_SUITE_P( AllDrivers, CommandBufferTest, ::testing::ValuesIn(testing::EnumerateAvailableDrivers()), diff --git a/iree/hal/vulkan/BUILD b/iree/hal/vulkan/BUILD index 97e243af17bd..d531e68baf25 100644 --- a/iree/hal/vulkan/BUILD +++ b/iree/hal/vulkan/BUILD @@ -26,6 +26,8 @@ cc_library( name = "vulkan", srcs = [ "api.cc", + "builtin_executables.cc", + "builtin_executables.h", "command_queue.h", "debug_reporter.cc", "debug_reporter.h", @@ -92,6 +94,7 @@ cc_library( "//iree/base/internal:synchronization", "//iree/base/internal/flatcc:parsing", "//iree/hal", + "//iree/hal/vulkan/builtin", "//iree/hal/vulkan/util:arena", "//iree/hal/vulkan/util:intrusive_list", "//iree/hal/vulkan/util:ref_ptr", diff --git a/iree/hal/vulkan/CMakeLists.txt b/iree/hal/vulkan/CMakeLists.txt index 8126033e4f52..caee9456689e 100644 --- a/iree/hal/vulkan/CMakeLists.txt +++ b/iree/hal/vulkan/CMakeLists.txt @@ -23,6 +23,8 @@ iree_cc_library( "vulkan_driver.h" SRCS "api.cc" + "builtin_executables.cc" + "builtin_executables.h" "command_queue.h" "debug_reporter.cc" "debug_reporter.h" @@ -82,6 +84,7 @@ iree_cc_library( iree::base::logging iree::base::tracing iree::hal + iree::hal::vulkan::builtin iree::hal::vulkan::util::arena 
iree::hal::vulkan::util::intrusive_list iree::hal::vulkan::util::ref_ptr diff --git a/iree/hal/vulkan/builtin/BUILD b/iree/hal/vulkan/builtin/BUILD new file mode 100644 index 000000000000..083e92f188db --- /dev/null +++ b/iree/hal/vulkan/builtin/BUILD @@ -0,0 +1,24 @@ +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/embed_data:build_defs.bzl", "c_embed_data") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +c_embed_data( + name = "builtin", + srcs = [ + "fill_unaligned.spv", + ], + c_file_output = "builtin_shaders_spv.c", + flatten = True, + h_file_output = "builtin_shaders_spv.h", + identifier = "builtin_shaders_spv", +) diff --git a/iree/hal/vulkan/builtin/CMakeLists.txt b/iree/hal/vulkan/builtin/CMakeLists.txt new file mode 100644 index 000000000000..251419690c46 --- /dev/null +++ b/iree/hal/vulkan/builtin/CMakeLists.txt @@ -0,0 +1,28 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# iree/hal/vulkan/builtin/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. 
#
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_c_embed_data(
+  NAME
+    builtin
+  SRCS
+    "fill_unaligned.spv"
+  C_FILE_OUTPUT
+    "builtin_shaders_spv.c"
+  H_FILE_OUTPUT
+    "builtin_shaders_spv.h"
+  IDENTIFIER
+    "builtin_shaders_spv"
+  FLATTEN
+  PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/hal/vulkan/builtin/compile_shaders.sh b/iree/hal/vulkan/builtin/compile_shaders.sh
new file mode 100644
index 000000000000..fd5f571d8803
--- /dev/null
+++ b/iree/hal/vulkan/builtin/compile_shaders.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Compiles input .glsl files into output .spv binary files. As these files are
+# updated infrequently and their binary sizes are small, we check in both files
+# and don't take a hard dependency on the shader compiler tool.
+#
+# To use, ensure `glslc` is on your PATH (such as by installing the Vulkan SDK
+# or building it from its source at https://github.com/google/shaderc) and run
+# the script.
+
+set -e
+set -x
+
+BUILTIN_DIR="$(dirname "$0")"
+
+glslc \
+  -Os -fshader-stage=compute -mfmt=bin \
+  "${BUILTIN_DIR}/fill_unaligned.glsl" \
+  -o "${BUILTIN_DIR}/fill_unaligned.spv"
diff --git a/iree/hal/vulkan/builtin/fill_unaligned.glsl b/iree/hal/vulkan/builtin/fill_unaligned.glsl
new file mode 100644
index 000000000000..9ba434e66810
--- /dev/null
+++ b/iree/hal/vulkan/builtin/fill_unaligned.glsl
@@ -0,0 +1,64 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#version 450 + +// Polyfill for buffer fills that are not aligned to 4 byte offsets or lengths. +// This only implements the unaligned edges of fill operations. vkCmdFillBuffer +// should be used for the aligned interior (if any). +// +// Repeats the 4 byte value |fill_pattern| into |output_elements|, between +// |fill_offset_bytes| and |fill_offset_bytes| + |fill_length_bytes|. + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +layout(set = 3, binding = 0) buffer OutputBuffer { uint output_elements[]; }; + +layout(push_constant) uniform Constants { + // TODO(scotttodd): low and high for 8 byte pattern + uint fill_pattern; + uint fill_pattern_width; // should be 1 or 2 (or 8 later on) + uint fill_offset_bytes; // must be aligned to pattern width + uint fill_length_bytes; +} input_constants; + +void FillBufferUnalignedHelper(uint fill_offset_bytes, uint fill_length_bytes) { + uint fill_aligned_offset = fill_offset_bytes % 4; + uint fill_aligned_start_bytes = fill_offset_bytes - fill_aligned_offset; + uint fill_aligned_start_index = fill_aligned_start_bytes / 4; + + uint shifted_pattern = 0x00000000; + if (input_constants.fill_pattern_width == 1) { + // Shift the pattern into each segment that is within the fill range. + uint fill_start = fill_aligned_offset; + uint fill_end = min(4, fill_start + fill_length_bytes); + for (uint i = fill_start; i < fill_end; ++i) { + shifted_pattern |= input_constants.fill_pattern << (8 * i); + } + } else if (input_constants.fill_pattern_width == 2) { + // Shift the pattern into the only supported segment in the fill range. 
+ shifted_pattern = input_constants.fill_pattern << (8 * fill_aligned_offset); + } + output_elements[fill_aligned_start_index] = shifted_pattern; +} + +void main() { + uint start_byte = input_constants.fill_offset_bytes; + uint end_byte = + input_constants.fill_offset_bytes + input_constants.fill_length_bytes; + + // Unaligned start fill, if needed. + if (start_byte % 4 != 0 || input_constants.fill_length_bytes < 4) { + FillBufferUnalignedHelper(start_byte, input_constants.fill_length_bytes); + } + // Unaligned end fill, if needed. + if ((end_byte % 4 != 0) && + (start_byte % 4 + input_constants.fill_length_bytes > 4)) { + uint end_rounded_down = (end_byte / 4) * 4; + uint length_end = end_byte - end_rounded_down; + FillBufferUnalignedHelper(end_rounded_down, length_end); + } +} diff --git a/iree/hal/vulkan/builtin/fill_unaligned.spv b/iree/hal/vulkan/builtin/fill_unaligned.spv new file mode 100644 index 000000000000..d457e5d2887a Binary files /dev/null and b/iree/hal/vulkan/builtin/fill_unaligned.spv differ diff --git a/iree/hal/vulkan/builtin_executables.cc b/iree/hal/vulkan/builtin_executables.cc new file mode 100644 index 000000000000..3d6b918af532 --- /dev/null +++ b/iree/hal/vulkan/builtin_executables.cc @@ -0,0 +1,204 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/hal/vulkan/builtin_executables.h" + +#include + +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/builtin/builtin_shaders_spv.h" +#include "iree/hal/vulkan/native_descriptor_set.h" +#include "iree/hal/vulkan/native_descriptor_set_layout.h" +#include "iree/hal/vulkan/native_executable_layout.h" +#include "iree/hal/vulkan/status_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +namespace { + +typedef struct iree_hal_vulkan_builtin_fill_unaligned_constants_t { + uint32_t fill_pattern; + uint32_t fill_pattern_width; + uint32_t fill_offset_bytes; + uint32_t fill_length_bytes; +} iree_hal_vulkan_builtin_fill_unaligned_constants_t; + +static_assert(sizeof(iree_hal_vulkan_builtin_fill_unaligned_constants_t) == + IREE_HAL_VULKAN_BUILTIN_PUSH_CONSTANT_COUNT, + "push constant count must match struct size"); + +} // namespace + +BuiltinExecutables::BuiltinExecutables(VkDeviceHandle* logical_device) + : logical_device_(logical_device) {} + +BuiltinExecutables::~BuiltinExecutables() { + if (pipeline_ != VK_NULL_HANDLE) { + logical_device_->syms()->vkDestroyPipeline(*logical_device_, pipeline_, + logical_device_->allocator()); + } + + if (executable_layout_) { + iree_hal_executable_layout_destroy(executable_layout_); + } + + for (size_t i = 0; i < IREE_HAL_VULKAN_BUILTIN_DESCRIPTOR_SET_COUNT; ++i) { + iree_hal_descriptor_set_layout_release(descriptor_set_layouts_[i]); + } +} + +iree_status_t BuiltinExecutables::InitializeExecutables() { + IREE_TRACE_SCOPE(); + + // Create descriptor set layouts for our compute pipeline. + // Even though we're just using one set, we still need to create layout + // bindings for those preceding it. 
+ for (size_t i = 0; i < IREE_HAL_VULKAN_BUILTIN_DESCRIPTOR_SET_COUNT; ++i) { + iree_hal_descriptor_set_layout_t* layout = NULL; + iree_hal_descriptor_set_layout_binding_t layout_binding; + layout_binding.binding = 0; + layout_binding.type = IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER; + layout_binding.access = i < IREE_HAL_VULKAN_BUILTIN_DESCRIPTOR_SET + ? IREE_HAL_MEMORY_ACCESS_NONE + : IREE_HAL_MEMORY_ACCESS_WRITE; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_native_descriptor_set_layout_create( + logical_device_, + i < IREE_HAL_VULKAN_BUILTIN_DESCRIPTOR_SET + ? IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_IMMUTABLE + : IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_PUSH_ONLY, + /*binding_count=*/1, &layout_binding, &layout)); + descriptor_set_layouts_[i] = layout; + } + + iree_status_t status = iree_ok_status(); + + // Create shader module. + VkShaderModule fill_unaligned_shader = VK_NULL_HANDLE; + if (iree_status_is_ok(status)) { + VkShaderModuleCreateInfo shader_create_info; + shader_create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; + shader_create_info.pNext = NULL; + shader_create_info.flags = 0; + shader_create_info.codeSize = builtin_shaders_spv_create()[0].size; + shader_create_info.pCode = + (const uint32_t*)builtin_shaders_spv_create()[0].data; + status = VK_RESULT_TO_STATUS(logical_device_->syms()->vkCreateShaderModule( + *logical_device_, &shader_create_info, logical_device_->allocator(), + &fill_unaligned_shader)); + } + + // Create pipeline layout. + if (iree_status_is_ok(status)) { + status = iree_hal_vulkan_native_executable_layout_create( + logical_device_, IREE_HAL_VULKAN_BUILTIN_PUSH_CONSTANT_COUNT / 4, + IREE_HAL_VULKAN_BUILTIN_DESCRIPTOR_SET_COUNT, descriptor_set_layouts_, + &executable_layout_); + } + + // Create pipeline. 
+ if (iree_status_is_ok(status)) { + VkComputePipelineCreateInfo pipeline_create_info; + pipeline_create_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + pipeline_create_info.pNext = NULL; + pipeline_create_info.flags = VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT; + pipeline_create_info.layout = + iree_hal_vulkan_native_executable_layout_handle(executable_layout_); + pipeline_create_info.basePipelineHandle = VK_NULL_HANDLE; + pipeline_create_info.basePipelineIndex = 0; + VkPipelineShaderStageCreateInfo* stage_create_info = + &pipeline_create_info.stage; + stage_create_info->sType = + VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + stage_create_info->pNext = NULL; + stage_create_info->flags = 0; + stage_create_info->stage = VK_SHADER_STAGE_COMPUTE_BIT; + stage_create_info->module = fill_unaligned_shader; + stage_create_info->pName = "main"; + stage_create_info->pSpecializationInfo = NULL; + status = + VK_RESULT_TO_STATUS(logical_device_->syms()->vkCreateComputePipelines( + *logical_device_, /*pipeline_cache=*/VK_NULL_HANDLE, + /*pipeline_count=*/1, &pipeline_create_info, + logical_device_->allocator(), &pipeline_)); + } + + // Destroy shader module now that the pipeline is created. 
+ if (fill_unaligned_shader != VK_NULL_HANDLE) { + logical_device_->syms()->vkDestroyShaderModule( + *logical_device_, fill_unaligned_shader, logical_device_->allocator()); + } + + return status; +} + +iree_status_t BuiltinExecutables::FillBufferUnaligned( + VkCommandBuffer command_buffer, DescriptorSetArena* descriptor_set_arena, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, const void* pattern, + iree_host_size_t pattern_length, const void* push_constants_to_restore) { + IREE_TRACE_SCOPE(); + + iree_hal_vulkan_builtin_fill_unaligned_constants_t constants; + switch (pattern_length) { + case 1: + constants.fill_pattern = *static_cast(pattern); + break; + case 2: + constants.fill_pattern = *static_cast(pattern); + break; + case 4: + constants.fill_pattern = *static_cast(pattern); + break; + default: + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "pattern length (%" PRIhsz + ") is not a power of two or is too large", + pattern_length); + } + + iree_hal_descriptor_set_binding_t binding; + binding.binding = 0; + binding.buffer = target_buffer; + binding.offset = 0; + binding.length = IREE_WHOLE_BUFFER; + IREE_RETURN_IF_ERROR(descriptor_set_arena->BindDescriptorSet( + command_buffer, executable_layout_, + IREE_HAL_VULKAN_BUILTIN_DESCRIPTOR_SET, /*binding_count=*/1, &binding)); + + logical_device_->syms()->vkCmdBindPipeline( + command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline_); + + constants.fill_pattern_width = pattern_length; + constants.fill_offset_bytes = target_offset; + constants.fill_length_bytes = length; + logical_device_->syms()->vkCmdPushConstants( + command_buffer, + iree_hal_vulkan_native_executable_layout_handle(executable_layout_), + VK_SHADER_STAGE_COMPUTE_BIT, /*offset=*/0, + sizeof(iree_hal_vulkan_builtin_fill_unaligned_constants_t), &constants); + + // TODO(scotttodd): insert memory barrier if we need to do dispatch<->dispatch + // synchronization. 
The barriers inserted normally by callers would be for + // transfer<->dispatch. + + logical_device_->syms()->vkCmdDispatch(command_buffer, 1, 1, 1); + + // Restore push constants. + logical_device_->syms()->vkCmdPushConstants( + command_buffer, + iree_hal_vulkan_native_executable_layout_handle(executable_layout_), + VK_SHADER_STAGE_COMPUTE_BIT, /*offset=*/0, + sizeof(iree_hal_vulkan_builtin_fill_unaligned_constants_t), + push_constants_to_restore); + + return iree_ok_status(); +} + +} // namespace vulkan +} // namespace hal +} // namespace iree diff --git a/iree/hal/vulkan/builtin_executables.h b/iree/hal/vulkan/builtin_executables.h new file mode 100644 index 000000000000..ea251027dfeb --- /dev/null +++ b/iree/hal/vulkan/builtin_executables.h @@ -0,0 +1,69 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_VULKAN_BUILTIN_EXECUTABLES_H_ +#define IREE_HAL_VULKAN_BUILTIN_EXECUTABLES_H_ + +#include + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/hal/vulkan/descriptor_set_arena.h" +#include "iree/hal/vulkan/dynamic_symbols.h" +#include "iree/hal/vulkan/handle_util.h" +#include "iree/hal/vulkan/util/ref_ptr.h" + +namespace iree { +namespace hal { +namespace vulkan { + +// The `maxBoundDescriptorSets` limit is 4 on many devices we support and we +// want to avoid conflicts with what the compiler uses, so we'll expect the +// compiler to have reserved the index 3 for our exclusive use. 
+#define IREE_HAL_VULKAN_BUILTIN_DESCRIPTOR_SET_COUNT 4 +#define IREE_HAL_VULKAN_BUILTIN_DESCRIPTOR_SET 3 + +#define IREE_HAL_VULKAN_BUILTIN_PUSH_CONSTANT_COUNT 16 + +class BuiltinExecutables { + public: + BuiltinExecutables(VkDeviceHandle* logical_device); + ~BuiltinExecutables(); + + const ref_ptr& syms() const { + return logical_device_->syms(); + } + + iree_status_t InitializeExecutables(); + + // Fills a buffer without 4 byte offset or length requirements. + // + // This only implements the unaligned edges of fills, vkCmdFillBuffer should + // be used for the aligned interior (if any). + // + // |push_constants_to_restore| will be pushed using vkCmdPushConstants over + // the bytes used by this call. + iree_status_t FillBufferUnaligned( + VkCommandBuffer command_buffer, DescriptorSetArena* descriptor_set_arena, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, const void* pattern, + iree_host_size_t pattern_length, const void* push_constants_to_restore); + + private: + VkDeviceHandle* logical_device_ = NULL; + + iree_hal_descriptor_set_layout_t* + descriptor_set_layouts_[IREE_HAL_VULKAN_BUILTIN_DESCRIPTOR_SET_COUNT] = { + NULL}; + iree_hal_executable_layout_t* executable_layout_ = NULL; + VkPipeline pipeline_ = VK_NULL_HANDLE; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_BUILTIN_EXECUTABLES_H_ diff --git a/iree/hal/vulkan/descriptor_set_arena.cc b/iree/hal/vulkan/descriptor_set_arena.cc index d6056691f187..cefa6bc39dcf 100644 --- a/iree/hal/vulkan/descriptor_set_arena.cc +++ b/iree/hal/vulkan/descriptor_set_arena.cc @@ -42,27 +42,32 @@ static void PopulateDescriptorSetWriteInfos( iree_hal_buffer_allocated_buffer(binding.buffer)); buffer_info.offset = iree_hal_buffer_byte_offset(binding.buffer) + binding.offset; - // Round up to a multiple of 32-bit. 32-bit is the most native bitwidth on - // GPUs; it has the best support compared to other bitwidths. 
We use VMA to - // manage GPU memory for us and VMA should already handled proper alignment - // when performing allocations; here we just need to provide the proper - // "view" to Vulkan drivers over the allocated memory. - // - // Note this is needed because we can see unusal buffers like tensor<3xi8>. - // Depending on GPU capabilities, this might not always be directly - // supported by the hardware. Under such circumstances, we need to emulate - // i8 support with i32. Shader CodeGen takes care of that: the shader will - // read the buffer as tensor and perform bit shifts to extract each - // byte and conduct computations. The extra additional byte is read but - // not really used by the shader. Here in application we need to match the - // ABI and provide the buffer as 32-bit aligned, otherwise the whole read by - // the shader is considered as out of bounds per the Vulkan spec. - // See https://github.com/google/iree/issues/2022#issuecomment-640617234 - // for more details. - buffer_info.range = iree_device_align( - std::min(binding.length, - iree_hal_buffer_byte_length(binding.buffer) - binding.offset), - 4); + if (binding.length == IREE_WHOLE_BUFFER) { + buffer_info.range = VK_WHOLE_SIZE; + } else { + // Round up to a multiple of 32-bit. 32-bit is the most native bitwidth on + // GPUs; it has the best support compared to other bitwidths. We use VMA + // to manage GPU memory for us and VMA should already handled proper + // alignment when performing allocations; here we just need to provide the + // proper "view" to Vulkan drivers over the allocated memory. + // + // Note this is needed because we can see unusal buffers like + // tensor<3xi8>. Depending on GPU capabilities, this might not always be + // directly supported by the hardware. Under such circumstances, we need + // to emulate i8 support with i32. 
Shader CodeGen takes care of that: the + // shader will read the buffer as tensor and perform bit shifts to + // extract each byte and conduct computations. The extra additional byte + // is read but not really used by the shader. Here in application we need + // to match the ABI and provide the buffer as 32-bit aligned, otherwise + // the whole read by the shader is considered as out of bounds per the + // Vulkan spec. See + // https://github.com/google/iree/issues/2022#issuecomment-640617234 for + // more details. + buffer_info.range = iree_device_align( + std::min(binding.length, iree_hal_buffer_byte_length(binding.buffer) - + binding.offset), + 4); + } auto& write_info = write_infos[i]; write_info.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; diff --git a/iree/hal/vulkan/direct_command_buffer.cc b/iree/hal/vulkan/direct_command_buffer.cc index 96f8eedd2ba5..734fac1b202e 100644 --- a/iree/hal/vulkan/direct_command_buffer.cc +++ b/iree/hal/vulkan/direct_command_buffer.cc @@ -50,6 +50,15 @@ typedef struct iree_hal_vulkan_direct_command_buffer_t { // This must remain valid until all in-flight submissions of the command // buffer complete. DescriptorSetGroup descriptor_set_group; + + BuiltinExecutables* builtin_executables; + + // Shadow copy of push constants used during normal operation, for restoring + // after builtin_executables uses vkCmdPushConstants. Size must be greater + // than or equal to the push constant memory used by builtin_executables. 
+ // TODO(scotttodd): use [maxPushConstantsSize - 16, maxPushConstantsSize] + // instead of [0, 16] to reduce frequency of updates + uint8_t push_constants_storage[IREE_HAL_VULKAN_BUILTIN_PUSH_CONSTANT_COUNT]; } iree_hal_vulkan_direct_command_buffer_t; extern const iree_hal_command_buffer_vtable_t @@ -71,6 +80,7 @@ iree_status_t iree_hal_vulkan_direct_command_buffer_allocate( iree_hal_queue_affinity_t queue_affinity, iree_hal_vulkan_tracing_context_t* tracing_context, iree::hal::vulkan::DescriptorPoolCache* descriptor_pool_cache, + iree::hal::vulkan::BuiltinExecutables* builtin_executables, iree_hal_command_buffer_t** out_command_buffer) { IREE_ASSERT_ARGUMENT(logical_device); IREE_ASSERT_ARGUMENT(command_pool); @@ -109,6 +119,8 @@ iree_status_t iree_hal_vulkan_direct_command_buffer_allocate( DescriptorSetArena(descriptor_pool_cache); new (&command_buffer->descriptor_set_group) DescriptorSetGroup(); + command_buffer->builtin_executables = builtin_executables; + *out_command_buffer = (iree_hal_command_buffer_t*)command_buffer; } else { command_pool->Free(handle); @@ -512,14 +524,47 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_fill_buffer( VkBuffer target_device_buffer = iree_hal_vulkan_vma_buffer_handle( iree_hal_buffer_allocated_buffer(target_buffer)); - // Note that fill only accepts 4-byte aligned values so we need to splat out - // our variable-length pattern. - target_offset += iree_hal_buffer_byte_offset(target_buffer); - uint32_t dword_pattern = - iree_hal_vulkan_splat_pattern(pattern, pattern_length); - command_buffer->syms->vkCmdFillBuffer(command_buffer->handle, - target_device_buffer, target_offset, - length, dword_pattern); + // vkCmdFillBuffer requires a 4 byte alignment for the offset, pattern, and + // length. We use a polyfill here that fills the unaligned start and end of + // fill operations, if needed. + + if (target_offset % 4 != 0 || length % 4 != 0) { + // TODO(scotttodd): only restore push constants that have been modified? 
+    //                  (this can pass uninitialized memory right now, which
+    //                   *should* be safe but is wasteful)
+    IREE_RETURN_IF_ERROR(
+        command_buffer->builtin_executables->FillBufferUnaligned(
+            command_buffer->handle, &(command_buffer->descriptor_set_arena),
+            target_buffer, target_offset, length, pattern, pattern_length,
+            command_buffer->push_constants_storage));
+
+    // Continue using vkCmdFillBuffer below, but only for the inner aligned
+    // portion of the fill operation (none if the fill sits inside one dword).
+    // For example:
+    //   original offset 2, length 8
+    //   aligned  offset 4, length 4
+    //   [0x00,0x00,0xAB,0xAB | 0xAB,0xAB,0xAB,0xAB | 0xAB,0xAB,0x00,0x00]
+    //              <-------> <---------------------> <------->
+    //              unaligned     vkCmdFillBuffer     unaligned
+    iree_device_size_t aligned_target_offset =
+        iree_device_align(target_offset, 4);
+    iree_device_size_t target_end = target_offset + length;
+    iree_device_size_t rounded_down_target_end = (target_end / 4) * 4;
+    length = rounded_down_target_end > aligned_target_offset
+                 ? rounded_down_target_end - aligned_target_offset : 0;
+    target_offset = aligned_target_offset;
+  }
+
+  if (length > 0) {
+    // Note that vkCmdFillBuffer only accepts 4-byte aligned values so we need
+    // to splat out our variable-length pattern.
+ target_offset += iree_hal_buffer_byte_offset(target_buffer); + uint32_t dword_pattern = + iree_hal_vulkan_splat_pattern(pattern, pattern_length); + command_buffer->syms->vkCmdFillBuffer(command_buffer->handle, + target_device_buffer, target_offset, + length, dword_pattern); + } return iree_ok_status(); } @@ -584,6 +629,15 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_push_constants( iree_hal_vulkan_direct_command_buffer_t* command_buffer = iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + // Keep a copy of the leading push constant bytes in push_constants_storage. + // The copy is clamped to the bytes that remain after |offset|. + iree_host_size_t storage_size = + IREE_ARRAYSIZE(command_buffer->push_constants_storage); + if (offset < storage_size) { + memcpy(command_buffer->push_constants_storage + offset, values, + std::min(values_length, storage_size - offset)); + } + command_buffer->syms->vkCmdPushConstants( command_buffer->handle, iree_hal_vulkan_native_executable_layout_handle(executable_layout), diff --git a/iree/hal/vulkan/direct_command_buffer.h b/iree/hal/vulkan/direct_command_buffer.h index cc1d097b4957..606e8595c019 100644 --- a/iree/hal/vulkan/direct_command_buffer.h +++ b/iree/hal/vulkan/direct_command_buffer.h @@ -9,6 +9,7 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +#include "iree/hal/vulkan/builtin_executables.h" #include "iree/hal/vulkan/descriptor_pool_cache.h" #include "iree/hal/vulkan/handle_util.h" #include "iree/hal/vulkan/tracing.h" @@ -26,6 +27,7 @@ iree_status_t iree_hal_vulkan_direct_command_buffer_allocate( iree_hal_queue_affinity_t queue_affinity, iree_hal_vulkan_tracing_context_t* tracing_context, iree::hal::vulkan::DescriptorPoolCache* descriptor_pool_cache, + iree::hal::vulkan::BuiltinExecutables* builtin_executables, iree_hal_command_buffer_t** out_command_buffer); // Returns the native Vulkan VkCommandBuffer handle. 
diff --git a/iree/hal/vulkan/vulkan_device.cc b/iree/hal/vulkan/vulkan_device.cc index 953abba52367..c5a78734004a 100644 --- a/iree/hal/vulkan/vulkan_device.cc +++ b/iree/hal/vulkan/vulkan_device.cc @@ -14,6 +14,7 @@ #include "iree/base/internal/math.h" #include "iree/base/tracing.h" #include "iree/hal/vulkan/api.h" +#include "iree/hal/vulkan/builtin_executables.h" #include "iree/hal/vulkan/command_queue.h" #include "iree/hal/vulkan/descriptor_pool_cache.h" #include "iree/hal/vulkan/direct_command_buffer.h" @@ -363,6 +364,8 @@ typedef struct iree_hal_vulkan_device_t { // Used only for emulated timeline semaphores. TimePointSemaphorePool* semaphore_pool; TimePointFencePool* fence_pool; + + BuiltinExecutables* builtin_executables; } iree_hal_vulkan_device_t; extern const iree_hal_device_vtable_t iree_hal_vulkan_device_vtable; @@ -621,6 +624,12 @@ static iree_status_t iree_hal_vulkan_device_create_internal( transfer_queue_set); } + if (iree_status_is_ok(status)) { + device->builtin_executables = + new BuiltinExecutables(device->logical_device); + status = device->builtin_executables->InitializeExecutables(); + } + if (iree_status_is_ok(status)) { *out_device = (iree_hal_device_t*)device; } else { @@ -647,6 +656,7 @@ static void iree_hal_vulkan_device_destroy(iree_hal_device_t* base_device) { // Now that no commands are outstanding we can release all resources that may // have been in use. 
+ delete device->builtin_executables; delete device->descriptor_pool_cache; delete device->semaphore_pool; delete device->fence_pool; @@ -930,6 +940,12 @@ static CommandQueue* iree_hal_vulkan_device_select_queue( iree_hal_vulkan_device_t* device, iree_hal_command_category_t command_categories, iree_hal_queue_affinity_t queue_affinity) { + // TODO(scotttodd): revisit queue selection logic and remove this + // * the unaligned buffer fill polyfill and tracing timestamp queries may + // both insert dispatches into command buffers that at compile time are + // expected to only contain transfer commands + // * we could set a bit at recording time if emulation or tracing is used + // and submit to the right queue based on that command_categories |= IREE_HAL_COMMAND_CATEGORY_DISPATCH; // TODO(benvanik): meaningful heuristics for affinity. We don't generate @@ -949,6 +965,12 @@ static iree_status_t iree_hal_vulkan_device_create_command_buffer( iree_hal_command_buffer_t** out_command_buffer) { iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + // TODO(scotttodd): revisit queue selection logic and remove this + // * the unaligned buffer fill polyfill and tracing timestamp queries may + // both insert dispatches into command buffers that at compile time are + // expected to only contain transfer commands + // * we could set a bit at recording time if emulation or tracing is used + // and submit to the right queue based on that command_categories |= IREE_HAL_COMMAND_CATEGORY_DISPATCH; // Select the command pool to used based on the types of commands used. 
@@ -974,7 +996,7 @@ static iree_status_t iree_hal_vulkan_device_create_command_buffer( return iree_hal_vulkan_direct_command_buffer_allocate( device->logical_device, command_pool, mode, command_categories, queue_affinity, queue->tracing_context(), device->descriptor_pool_cache, - out_command_buffer); + device->builtin_executables, out_command_buffer); } static iree_status_t iree_hal_vulkan_device_create_descriptor_set( diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt index 202a2d51820d..a8848849ced5 100644 --- a/iree/tools/CMakeLists.txt +++ b/iree/tools/CMakeLists.txt @@ -48,10 +48,6 @@ if(IREE_ENABLE_EMITC) set(IREE_EMITC_CONDITIONAL_DEP MLIREmitC ) - set(IREE_TRANSLATE_CONDITIONAL_DEPS - MLIREmitC - emitc::TranslateToCpp - ) endif() iree_cc_binary( @@ -359,7 +355,7 @@ if(${IREE_BUILD_COMPILER}) iree::compiler::Dialect::VM::Target::Bytecode iree::compiler::Dialect::VM::Target::init_targets iree::compiler::Translation::IREEVM - ${IREE_TRANSLATE_CONDITIONAL_DEPS} + ${IREE_EMITC_CONDITIONAL_DEP} PUBLIC ) diff --git a/iree/vm/value.h b/iree/vm/value.h index b292f9d5908f..445d80f6430b 100644 --- a/iree/vm/value.h +++ b/iree/vm/value.h @@ -58,6 +58,26 @@ typedef struct iree_vm_value_t { }; } iree_vm_value_t; +static inline iree_vm_value_t iree_vm_value_make_none(void) { + iree_vm_value_t result; + result.type = IREE_VM_VALUE_TYPE_NONE; // No payload is written for NONE. + return result; +} + +static inline iree_vm_value_t iree_vm_value_make_i8(int8_t value) { + iree_vm_value_t result; + result.type = IREE_VM_VALUE_TYPE_I8; + result.i8 = value; + return result; +} + +static inline iree_vm_value_t iree_vm_value_make_i16(int16_t value) { + iree_vm_value_t result; + result.type = IREE_VM_VALUE_TYPE_I16; + result.i16 = value; + return result; +} + static inline iree_vm_value_t iree_vm_value_make_i32(int32_t value) { iree_vm_value_t result; result.type = IREE_VM_VALUE_TYPE_I32; @@ -82,7 +102,7 @@ static inline int64_t iree_vm_value_get_i64(iree_vm_value_t *value) { return value->i64; } -static 
inline iree_vm_value_t iree_vm_value_make_f32(int32_t value) { +static inline iree_vm_value_t iree_vm_value_make_f32(float value) { iree_vm_value_t result; result.type = IREE_VM_VALUE_TYPE_F32; result.f32 = value;