From 783965cfb84420ba8fbcbc52a57c71316c12b8ba Mon Sep 17 00:00:00 2001 From: Merry Date: Sun, 28 Jan 2024 14:13:15 +0000 Subject: [PATCH] oaknut: Implement DualCodeBlock and related support --- CMakeLists.txt | 1 + README.md | 40 ++++- include/oaknut/dual_code_block.hpp | 153 ++++++++++++++++++ .../oaknut/impl/arm64_encode_helpers.inc.hpp | 37 ++--- include/oaknut/oaknut.hpp | 124 +++++++------- tests/basic.cpp | 96 ++++++++--- tests/fpsimd.cpp | 24 +-- tests/general.cpp | 24 +-- tests/vector_code_gen.cpp | 10 +- 9 files changed, 380 insertions(+), 129 deletions(-) create mode 100644 include/oaknut/dual_code_block.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e8ec74..92a7850 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,6 +18,7 @@ endif() # Source project files set(header_files ${CMAKE_CURRENT_SOURCE_DIR}/include/oaknut/code_block.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/oaknut/dual_code_block.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/oaknut/feature_detection/cpu_feature.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/oaknut/feature_detection/feature_detection.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/oaknut/feature_detection/id_registers.hpp diff --git a/README.md b/README.md index 4ae7913..8e32760 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ EmittedFunction EmitExample(oaknut::CodeGenerator& code, int value) { using namespace oaknut::util; - EmittedFunction result = code.ptr(); + EmittedFunction result = code.xptr(); code.MOV(W0, value); code.RET(); @@ -32,7 +32,7 @@ EmittedFunction EmitExample(oaknut::CodeGenerator& code, int value) int main() { oaknut::CodeBlock mem{4096}; - oaknut::CodeGenerator code{mem.ptr()}; + oaknut::CodeGenerator code{mem.ptr(), mem.ptr()}; mem.unprotect(); @@ -47,12 +47,45 @@ int main() } ``` +CodeGenerator takes two pointers. The first pointer is the memory address to write to, and the second pointer is the memory address that the code will be executing from. 
This allows you to write to a buffer before copying to the final destination for execution, or to use dual-mapped memory blocks to avoid memory protection overhead. + +Below is an example of using the oaknut-provided utility header for dual-mapped memory blocks: + +```cpp +#include +#include +#include + +using EmittedFunction = ; + +int main() +{ + using namespace oaknut::util; + + oaknut::DualCodeBlock mem{4096}; + oaknut::CodeGenerator code{mem.wptr(), mem.xptr()}; + + const auto result = code.xptr(); + + code.MOV(W0, value); + code.RET(); + + mem.invalidate_all(); + + std::printf("%i\n", fn()); // Output: 42 + + return 0; +} +``` + ### Emit to `std::vector` If you wish to merely emit code into memory without executing it, or if you are developing a cross-compiler that is not running on an ARM64 device, you can use `oaknut::VectorCodeGenerator` instead. Provide `oaknut::VectorCodeGenerator` with a reference to a `std::vector` and it will append to that vector. +The second pointer argument represents the destination address the code will eventually be executed from. + Simple example: ```cpp @@ -64,7 +97,7 @@ Simple example: int main() { std::vector vec; - oaknut::VectorCodeGenerator code{vec}; + oaknut::VectorCodeGenerator code{vec, (uint32_t*)0x1000}; code.MOV(W0, 42); code.RET(); @@ -81,6 +114,7 @@ int main() | ------ | --------------------- | -------- | | `` | Yes | Provides `CodeGenerator` and `VectorCodeGenerator` for code emission, as well as the `oaknut::util` namespace. | | `` | No | Utility header that provides `CodeBlock`, allocates, alters permissions of, and invalidates executable memory. | +| `` | No | Utility header that provides `DualCodeBlock`, which allocates two mirrored memory blocks (with RW and RX permissions respectively). | | `` | Yes | Provides `OaknutException` which is thrown on an error. | | `` | Yes | Utility header that provides `CpuFeatures` which can be used to describe AArch64 features. 
| | `` | No | Utility header that provides `detect_features` and `read_id_registers` for determining available AArch64 features. | diff --git a/include/oaknut/dual_code_block.hpp b/include/oaknut/dual_code_block.hpp new file mode 100644 index 0000000..2f5edda --- /dev/null +++ b/include/oaknut/dual_code_block.hpp @@ -0,0 +1,153 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 merryhime +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +#if defined(_WIN32) +# define NOMINMAX +# include +#elif defined(__APPLE__) +# include +# include + +# include +# include +# include +# include +# include +#else +# define _GNU_SOURCE +# include +#endif + +namespace oaknut { + +class DualCodeBlock { +public: + explicit DualCodeBlock(std::size_t size) + : m_size(size) + { +#if defined(_WIN32) + m_wmem = m_xmem = (std::uint32_t*)VirtualAlloc(nullptr, size, MEM_COMMIT, PAGE_EXECUTE_READWRITE); + if (m_wmem == nullptr) + throw std::bad_alloc{}; +#elif defined(__APPLE__) + m_wmem = (std::uint32_t*)mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_ANON | MAP_PRIVATE, -1, 0); + if (m_wmem == MAP_FAILED) + throw std::bad_alloc{}; + + vm_prot_t cur_prot, max_prot; + kern_return_t ret = vm_remap(mach_task_self(), (vm_address_t*)&m_xmem, size, 0, VM_FLAGS_ANYWHERE | VM_FLAGS_RANDOM_ADDR, mach_task_self(), (mach_vm_address_t)m_wmem, false, &cur_prot, &max_prot, VM_INHERIT_NONE); + if (ret != KERN_SUCCESS) + throw std::bad_alloc{}; + + mprotect(m_xmem, size, PROT_READ | PROT_EXEC); +#else + fd = memfd_create("oaknut_dual_code_block", 0); + if (fd < 0) + throw std::bad_alloc{}; + + int ret = ftruncate(fd, size); + if (ret != 0) + throw std::bad_alloc{}; + + m_wmem = (std::uint32_t*)mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + m_xmem = (std::uint32_t*)mmap(nullptr, size, PROT_READ | PROT_EXEC, MAP_SHARED, fd, 0); + + if (m_wmem == MAP_FAILED || m_xmem == MAP_FAILED) + throw std::bad_alloc{}; +#endif + } + + ~DualCodeBlock() + { +#if 
defined(_WIN32) + VirtualFree((void*)m_xmem, 0, MEM_RELEASE); +#elif defined(__APPLE__) +#else + munmap(m_wmem, m_size); + munmap(m_xmem, m_size); + close(fd); +#endif + } + + DualCodeBlock(const DualCodeBlock&) = delete; + DualCodeBlock& operator=(const DualCodeBlock&) = delete; + DualCodeBlock(DualCodeBlock&&) = delete; + DualCodeBlock& operator=(DualCodeBlock&&) = delete; + + /// Pointer to executable mirror of memory (permissions: R-X) + std::uint32_t* xptr() const + { + return m_xmem; + } + + /// Pointer to writeable mirror of memory (permissions: RW-) + std::uint32_t* wptr() const + { + return m_wmem; + } + + /// Invalidate should be used with executable memory pointers. + void invalidate(std::uint32_t* mem, std::size_t size) + { +#if defined(__APPLE__) + sys_icache_invalidate(mem, size); +#elif defined(_WIN32) + FlushInstructionCache(GetCurrentProcess(), mem, size); +#else + static std::size_t icache_line_size = 0x10000, dcache_line_size = 0x10000; + + std::uint64_t ctr; + __asm__ volatile("mrs %0, ctr_el0" + : "=r"(ctr)); + + const std::size_t isize = icache_line_size = std::min(icache_line_size, 4 << ((ctr >> 0) & 0xf)); + const std::size_t dsize = dcache_line_size = std::min(dcache_line_size, 4 << ((ctr >> 16) & 0xf)); + + const std::uintptr_t end = (std::uintptr_t)mem + size; + + for (std::uintptr_t addr = ((std::uintptr_t)mem) & ~(dsize - 1); addr < end; addr += dsize) { + __asm__ volatile("dc cvau, %0" + : + : "r"(addr) + : "memory"); + } + __asm__ volatile("dsb ish\n" + : + : + : "memory"); + + for (std::uintptr_t addr = ((std::uintptr_t)mem) & ~(isize - 1); addr < end; addr += isize) { + __asm__ volatile("ic ivau, %0" + : + : "r"(addr) + : "memory"); + } + __asm__ volatile("dsb ish\nisb\n" + : + : + : "memory"); +#endif + } + + void invalidate_all() + { + invalidate(m_xmem, m_size); + } + +protected: +#if !defined(_WIN32) && !defined(__APPLE__) + int fd = -1; +#endif + std::uint32_t* m_xmem = nullptr; + std::uint32_t* m_wmem = nullptr; + std::size_t 
m_size = 0; +}; + +} // namespace oaknut diff --git a/include/oaknut/impl/arm64_encode_helpers.inc.hpp b/include/oaknut/impl/arm64_encode_helpers.inc.hpp index 130a30a..1767f67 100644 --- a/include/oaknut/impl/arm64_encode_helpers.inc.hpp +++ b/include/oaknut/impl/arm64_encode_helpers.inc.hpp @@ -112,8 +112,8 @@ std::uint32_t encode(AddrOffset v) { static_assert(std::popcount(splat) == size - align); - const auto encode_fn = [](std::uintptr_t current_addr, std::uintptr_t target) { - const std::ptrdiff_t diff = target - current_addr; + const auto encode_fn = [](std::ptrdiff_t current_offset, std::ptrdiff_t target_offset) { + const std::ptrdiff_t diff = target_offset - current_offset; return pdep(AddrOffset::encode(diff)); }; @@ -122,19 +122,16 @@ std::uint32_t encode(AddrOffset v) return pdep(encoding); }, [&](Label* label) -> std::uint32_t { - if (label->m_addr) { - return encode_fn(Policy::current_address(), *label->m_addr); + if (label->m_offset) { + return encode_fn(Policy::offset(), *label->m_offset); } - label->m_wbs.emplace_back(Label::Writeback{Policy::current_address(), ~splat, static_cast(encode_fn)}); + label->m_wbs.emplace_back(Label::Writeback{Policy::offset(), ~splat, static_cast(encode_fn)}); return 0u; }, - [&]([[maybe_unused]] const void* p) -> std::uint32_t { - if constexpr (Policy::has_absolute_addresses) { - return encode_fn(Policy::current_address(), reinterpret_cast(p)); - } else { - throw OaknutException{ExceptionType::RequiresAbsoluteAddressesContext}; - } + [&](const void* p) -> std::uint32_t { + const std::ptrdiff_t diff = reinterpret_cast(p) - Policy::template xptr(); + return pdep(AddrOffset::encode(diff)); }, }, v.m_payload); @@ -145,25 +142,21 @@ std::uint32_t encode(PageOffset v) { static_assert(std::popcount(splat) == size); - const auto encode_fn = [](std::uintptr_t current_addr, std::uintptr_t target) { - return pdep(PageOffset::encode(current_addr, target)); + const auto encode_fn = [](std::ptrdiff_t current_offset, std::ptrdiff_t 
target_offset) { + return pdep(PageOffset::encode(std::bit_cast(current_offset), std::bit_cast(target_offset))); }; return std::visit(detail::overloaded{ [&](Label* label) -> std::uint32_t { - if (label->m_addr) { - return encode_fn(Policy::current_address(), *label->m_addr); + if (label->m_offset) { + return encode_fn(Policy::offset(), *label->m_offset); } - label->m_wbs.emplace_back(Label::Writeback{Policy::current_address(), ~splat, static_cast(encode_fn)}); + label->m_wbs.emplace_back(Label::Writeback{Policy::offset(), ~splat, static_cast(encode_fn)}); return 0u; }, - [&]([[maybe_unused]] const void* p) -> std::uint32_t { - if constexpr (Policy::has_absolute_addresses) { - return encode_fn(Policy::current_address(), reinterpret_cast(p)); - } else { - throw OaknutException{ExceptionType::RequiresAbsoluteAddressesContext}; - } + [&](const void* p) -> std::uint32_t { + return pdep(PageOffset::encode(Policy::template xptr(), reinterpret_cast(p))); }, }, v.m_payload); diff --git a/include/oaknut/oaknut.hpp b/include/oaknut/oaknut.hpp index 39b9a03..aa80f81 100644 --- a/include/oaknut/oaknut.hpp +++ b/include/oaknut/oaknut.hpp @@ -30,58 +30,56 @@ struct Label { bool is_bound() const { - return m_addr.has_value(); + return m_offset.has_value(); } - template - T ptr() const + std::ptrdiff_t offset() const { - static_assert(std::is_pointer_v || std::is_same_v || std::is_same_v); - return reinterpret_cast(m_addr.value()); + return m_offset.value(); } private: template friend class BasicCodeGenerator; - explicit Label(std::uintptr_t addr) - : m_addr(addr) + explicit Label(std::ptrdiff_t offset) + : m_offset(offset) {} - using EmitFunctionType = std::uint32_t (*)(std::uintptr_t wb_addr, std::uintptr_t resolved_addr); + using EmitFunctionType = std::uint32_t (*)(std::ptrdiff_t wb_offset, std::ptrdiff_t resolved_offset); struct Writeback { - std::uintptr_t m_wb_addr; + std::ptrdiff_t m_wb_offset; std::uint32_t m_mask; EmitFunctionType m_fn; }; - std::optional m_addr; + 
std::optional m_offset; std::vector m_wbs; }; template class BasicCodeGenerator : public Policy { public: - BasicCodeGenerator(typename Policy::constructor_argument_type arg) - : Policy(arg) + BasicCodeGenerator(typename Policy::constructor_argument_type arg, std::uint32_t* xmem) + : Policy(arg, xmem) {} Label l() const { - return Label{Policy::current_address()}; + return Label{Policy::offset()}; } void l(Label& label) const { - if (label.m_addr) + if (label.is_bound()) throw OaknutException{ExceptionType::LabelRedefinition}; - const auto target_addr = Policy::current_address(); - label.m_addr = target_addr; + const auto target_offset = Policy::offset(); + label.m_offset = target_offset; for (auto& wb : label.m_wbs) { - const std::uint32_t value = wb.m_fn(wb.m_wb_addr, target_addr); - Policy::set_at_address(wb.m_wb_addr, value, wb.m_mask); + const std::uint32_t value = wb.m_fn(wb.m_wb_offset, target_offset); + Policy::set_at_offset(wb.m_wb_offset, value, wb.m_mask); } label.m_wbs.clear(); } @@ -160,17 +158,13 @@ class BasicCodeGenerator : public Policy { // Convenience function for moving pointers to registers void MOVP2R(XReg xd, const void* addr) { - if constexpr (Policy::has_absolute_addresses) { - const int64_t diff = reinterpret_cast(addr) - Policy::current_address(); - if (diff >= -0xF'FFFF && diff <= 0xF'FFFF) { - ADR(xd, addr); - } else if (PageOffset<21, 12>::valid(Policy::current_address(), reinterpret_cast(addr))) { - ADRL(xd, addr); - } else { - MOV(xd, reinterpret_cast(addr)); - } + const int64_t diff = reinterpret_cast(addr) - Policy::template xptr(); + if (diff >= -0xF'FFFF && diff <= 0xF'FFFF) { + ADR(xd, addr); + } else if (PageOffset<21, 12>::valid(Policy::template xptr(), reinterpret_cast(addr))) { + ADRL(xd, addr); } else { - throw OaknutException{ExceptionType::RequiresAbsoluteAddressesContext}; + MOV(xd, reinterpret_cast(addr)); } } @@ -179,7 +173,7 @@ class BasicCodeGenerator : public Policy { if (alignment < 4 || (alignment & (alignment - 
1)) != 0) throw OaknutException{ExceptionType::InvalidAlignment}; - while (Policy::template ptr() & (alignment - 1)) { + while (Policy::offset() & (alignment - 1)) { NOP(); } } @@ -209,23 +203,47 @@ class BasicCodeGenerator : public Policy { struct PointerCodeGeneratorPolicy { public: + std::ptrdiff_t offset() const + { + return (m_ptr - m_wmem) * sizeof(std::uint32_t); + } + + void set_offset(std::ptrdiff_t offset) + { + if ((offset % sizeof(std::uint32_t)) != 0) + throw OaknutException{ExceptionType::InvalidAlignment}; + m_ptr = m_wmem + offset / sizeof(std::uint32_t); + } + template - T ptr() const + T wptr() const { static_assert(std::is_pointer_v || std::is_same_v || std::is_same_v); return reinterpret_cast(m_ptr); } - void set_ptr(std::uint32_t* ptr_) + template + T xptr() const + { + static_assert(std::is_pointer_v || std::is_same_v || std::is_same_v); + return reinterpret_cast(m_xmem + (m_ptr - m_wmem)); + } + + void set_wptr(std::uint32_t* p) { - m_ptr = ptr_; + m_ptr = p; + } + + void set_xptr(std::uint32_t* p) + { + m_ptr = m_wmem + (p - m_xmem); } protected: using constructor_argument_type = std::uint32_t*; - PointerCodeGeneratorPolicy(std::uint32_t* ptr_) - : m_ptr(ptr_) + PointerCodeGeneratorPolicy(std::uint32_t* wmem, std::uint32_t* xmem) + : m_ptr(wmem), m_wmem(wmem), m_xmem(xmem) {} void append(std::uint32_t instruction) @@ -233,35 +251,37 @@ struct PointerCodeGeneratorPolicy { *m_ptr++ = instruction; } - static constexpr bool has_absolute_addresses = true; - - std::uintptr_t current_address() const - { - return reinterpret_cast(m_ptr); - } - - void set_at_address(std::uintptr_t addr, std::uint32_t value, std::uint32_t mask) const + void set_at_offset(std::ptrdiff_t offset, std::uint32_t value, std::uint32_t mask) const { - std::uint32_t* p = reinterpret_cast(addr); + std::uint32_t* p = m_wmem + offset / sizeof(std::uint32_t); *p = (*p & mask) | value; } private: std::uint32_t* m_ptr; + std::uint32_t* const m_wmem; + std::uint32_t* const m_xmem; }; 
struct VectorCodeGeneratorPolicy { public: std::ptrdiff_t offset() const { - return static_cast(m_vec.size() * sizeof(std::uint32_t)); + return m_vec.size() * sizeof(std::uint32_t); + } + + template + T xptr() const + { + static_assert(std::is_pointer_v || std::is_same_v || std::is_same_v); + return reinterpret_cast(m_xmem + m_vec.size()); } protected: using constructor_argument_type = std::vector&; - VectorCodeGeneratorPolicy(std::vector& vec) - : m_vec(vec) + VectorCodeGeneratorPolicy(std::vector& vec, std::uint32_t* xmem) + : m_vec(vec), m_xmem(xmem) {} void append(std::uint32_t instruction) @@ -269,21 +289,15 @@ struct VectorCodeGeneratorPolicy { m_vec.push_back(instruction); } - static constexpr bool has_absolute_addresses = false; - - std::uintptr_t current_address() const - { - return static_cast(m_vec.size() * sizeof(std::uint32_t)); - } - - void set_at_address(std::uintptr_t addr, std::uint32_t value, std::uint32_t mask) const + void set_at_offset(std::ptrdiff_t offset, std::uint32_t value, std::uint32_t mask) const { - std::uint32_t& p = m_vec[addr / sizeof(std::uint32_t)]; + std::uint32_t& p = m_vec[offset / sizeof(std::uint32_t)]; p = (p & mask) | value; } private: std::vector& m_vec; + std::uint32_t* const m_xmem; }; using CodeGenerator = BasicCodeGenerator; diff --git a/tests/basic.cpp b/tests/basic.cpp index 91bc846..38342ca 100644 --- a/tests/basic.cpp +++ b/tests/basic.cpp @@ -8,6 +8,7 @@ #include #include "oaknut/code_block.hpp" +#include "oaknut/dual_code_block.hpp" #include "oaknut/oaknut.hpp" #include "rand_int.hpp" @@ -17,7 +18,7 @@ using namespace oaknut::util; TEST_CASE("Basic Test") { CodeBlock mem{4096}; - CodeGenerator code{mem.ptr()}; + CodeGenerator code{mem.ptr(), mem.ptr()}; mem.unprotect(); @@ -31,14 +32,28 @@ TEST_CASE("Basic Test") REQUIRE(result == 42); } +TEST_CASE("Basic Test (Dual)") +{ + DualCodeBlock mem{4096}; + CodeGenerator code{mem.wptr(), mem.xptr()}; + + code.MOV(W0, 42); + code.RET(); + + mem.invalidate_all(); + + int 
result = ((int (*)())mem.xptr())(); + REQUIRE(result == 42); +} + TEST_CASE("Fibonacci") { CodeBlock mem{4096}; - CodeGenerator code{mem.ptr()}; + CodeGenerator code{mem.ptr(), mem.ptr()}; mem.unprotect(); - auto fib = code.ptr(); + auto fib = code.xptr(); Label start, end, zero, recurse; code.l(start); @@ -77,6 +92,49 @@ TEST_CASE("Fibonacci") REQUIRE(fib(9) == 34); } +TEST_CASE("Fibonacci (Dual)") +{ + DualCodeBlock mem{4096}; + CodeGenerator code{mem.wptr(), mem.xptr()}; + + auto fib = code.xptr(); + Label start, end, zero, recurse; + + code.l(start); + code.STP(X29, X30, SP, PRE_INDEXED, -32); + code.STP(X20, X19, SP, 16); + code.MOV(X29, SP); + code.MOV(W19, W0); + code.SUBS(W0, W0, 1); + code.B(LT, zero); + code.B(NE, recurse); + code.MOV(W0, 1); + code.B(end); + + code.l(zero); + code.MOV(W0, WZR); + code.B(end); + + code.l(recurse); + code.BL(start); + code.MOV(W20, W0); + code.SUB(W0, W19, 2); + code.BL(start); + code.ADD(W0, W0, W20); + + code.l(end); + code.LDP(X20, X19, SP, 16); + code.LDP(X29, X30, SP, POST_INDEXED, 32); + code.RET(); + + mem.invalidate_all(); + + REQUIRE(fib(0) == 0); + REQUIRE(fib(1) == 1); + REQUIRE(fib(5) == 5); + REQUIRE(fib(9) == 34); +} + TEST_CASE("Immediate generation (32-bit)", "[slow]") { CodeBlock mem{4096}; @@ -84,9 +142,9 @@ TEST_CASE("Immediate generation (32-bit)", "[slow]") for (int i = 0; i < 0x100000; i++) { const std::uint32_t value = RandInt(0, 0xffffffff); - CodeGenerator code{mem.ptr()}; + CodeGenerator code{mem.ptr(), mem.ptr()}; - auto f = code.ptr(); + auto f = code.xptr(); mem.unprotect(); code.MOV(W0, value); code.RET(); @@ -104,9 +162,9 @@ TEST_CASE("Immediate generation (64-bit)", "[slow]") for (int i = 0; i < 0x100000; i++) { const std::uint64_t value = RandInt(0, 0xffffffff'ffffffff); - CodeGenerator code{mem.ptr()}; + CodeGenerator code{mem.ptr(), mem.ptr()}; - auto f = code.ptr(); + auto f = code.xptr(); mem.unprotect(); code.MOV(X0, value); code.RET(); @@ -124,9 +182,9 @@ TEST_CASE("ADR", "[slow]") 
for (std::int64_t i = -1048576; i < 1048576; i++) { const std::intptr_t value = reinterpret_cast(mem.ptr()) + i; - CodeGenerator code{mem.ptr()}; + CodeGenerator code{mem.ptr(), mem.ptr()}; - auto f = code.ptr(); + auto f = code.xptr(); mem.unprotect(); code.ADR(X0, reinterpret_cast(value)); code.RET(); @@ -160,9 +218,9 @@ TEST_CASE("ADRP", "[slow]") const std::intptr_t value = reinterpret_cast(mem.ptr()) + diff; const std::uint64_t expect = static_cast(value) & ~static_cast(0xfff); - CodeGenerator code{mem.ptr()}; + CodeGenerator code{mem.ptr(), mem.ptr()}; - auto f = code.ptr(); + auto f = code.xptr(); mem.unprotect(); code.ADRP(X0, reinterpret_cast(value)); code.RET(); @@ -183,9 +241,9 @@ TEST_CASE("ADRL (near)") const std::int64_t diff = i; const std::intptr_t value = reinterpret_cast(mem_ptr) + diff; - CodeGenerator code{mem_ptr}; + CodeGenerator code{mem_ptr, mem_ptr}; - auto f = code.ptr(); + auto f = code.xptr(); mem.unprotect(); code.ADRL(X0, reinterpret_cast(value)); code.RET(); @@ -206,9 +264,9 @@ TEST_CASE("ADRL (far)", "[slow]") const std::int64_t diff = RandInt(-4294967296 + 100, 4294967295 - 100); const std::intptr_t value = reinterpret_cast(mem_ptr) + diff; - CodeGenerator code{mem_ptr}; + CodeGenerator code{mem_ptr, mem_ptr}; - auto f = code.ptr(); + auto f = code.xptr(); mem.unprotect(); code.ADRL(X0, reinterpret_cast(value)); code.RET(); @@ -230,9 +288,9 @@ TEST_CASE("MOVP2R (far)", "[slow]") std::numeric_limits::max()); const std::intptr_t value = reinterpret_cast(mem_ptr) + diff; - CodeGenerator code{mem_ptr}; + CodeGenerator code{mem_ptr, mem_ptr}; - auto f = code.ptr(); + auto f = code.xptr(); mem.unprotect(); code.MOVP2R(X0, reinterpret_cast(value)); code.RET(); @@ -252,9 +310,9 @@ TEST_CASE("MOVP2R (4GiB boundary)") const auto test = [&](std::int64_t diff) { const std::intptr_t value = reinterpret_cast(mem_ptr) + diff; - CodeGenerator code{mem_ptr}; + CodeGenerator code{mem_ptr, mem_ptr}; - auto f = code.ptr(); + auto f = code.xptr(); 
mem.unprotect(); code.MOVP2R(X0, reinterpret_cast(value)); code.RET(); diff --git a/tests/fpsimd.cpp b/tests/fpsimd.cpp index e0cb0e2..d164f8e 100644 --- a/tests/fpsimd.cpp +++ b/tests/fpsimd.cpp @@ -8,18 +8,18 @@ #include "oaknut/oaknut.hpp" -#define T(HEX, CMD) \ - TEST_CASE(#CMD) \ - { \ - using namespace oaknut; \ - using namespace oaknut::util; \ - \ - std::uint32_t result; \ - CodeGenerator code{&result}; \ - \ - code.CMD; \ - \ - REQUIRE(result == HEX); \ +#define T(HEX, CMD) \ + TEST_CASE(#CMD) \ + { \ + using namespace oaknut; \ + using namespace oaknut::util; \ + \ + std::uint32_t result; \ + CodeGenerator code{&result, &result}; \ + \ + code.CMD; \ + \ + REQUIRE(result == HEX); \ } T(0x5ee0bb61, ABS(D1, D27)) diff --git a/tests/general.cpp b/tests/general.cpp index 2caf465..0acb35f 100644 --- a/tests/general.cpp +++ b/tests/general.cpp @@ -8,18 +8,18 @@ #include "oaknut/oaknut.hpp" -#define T(HEX, CMD) \ - TEST_CASE(#CMD) \ - { \ - using namespace oaknut; \ - using namespace oaknut::util; \ - \ - std::uint32_t result; \ - CodeGenerator code{&result}; \ - \ - code.CMD; \ - \ - REQUIRE(result == HEX); \ +#define T(HEX, CMD) \ + TEST_CASE(#CMD) \ + { \ + using namespace oaknut; \ + using namespace oaknut::util; \ + \ + std::uint32_t result; \ + CodeGenerator code{&result, &result}; \ + \ + code.CMD; \ + \ + REQUIRE(result == HEX); \ } T(0x1a0f01c3, ADC(W3, W14, W15)) diff --git a/tests/vector_code_gen.cpp b/tests/vector_code_gen.cpp index baceeef..e06b135 100644 --- a/tests/vector_code_gen.cpp +++ b/tests/vector_code_gen.cpp @@ -18,14 +18,13 @@ using namespace oaknut::util; TEST_CASE("Basic Test (VectorCodeGenerator)") { + CodeBlock mem{4096}; std::vector vec; - VectorCodeGenerator code{vec}; + VectorCodeGenerator code{vec, mem.ptr()}; code.MOV(W0, 42); code.RET(); - CodeBlock mem{4096}; - mem.unprotect(); std::memcpy(mem.ptr(), vec.data(), vec.size() * sizeof(std::uint32_t)); mem.protect(); @@ -37,8 +36,9 @@ TEST_CASE("Basic Test (VectorCodeGenerator)") 
TEST_CASE("Fibonacci (VectorCodeGenerator)") { + CodeBlock mem{4096}; std::vector vec; - VectorCodeGenerator code{vec}; + VectorCodeGenerator code{vec, mem.ptr()}; Label start, end, zero, recurse; @@ -69,8 +69,6 @@ TEST_CASE("Fibonacci (VectorCodeGenerator)") code.LDP(X29, X30, SP, POST_INDEXED, 32); code.RET(); - CodeBlock mem{4096}; - mem.unprotect(); std::memcpy(mem.ptr(), vec.data(), vec.size() * sizeof(std::uint32_t)); mem.protect();