diff --git a/src/lib/pubkey/pqcrystals/pqcrystals_helpers.h b/src/lib/pubkey/pqcrystals/pqcrystals_helpers.h index c86a9bcd2a1..4d69815d34f 100644 --- a/src/lib/pubkey/pqcrystals/pqcrystals_helpers.h +++ b/src/lib/pubkey/pqcrystals/pqcrystals_helpers.h @@ -13,8 +13,10 @@ #include #include +#include #include +#include #include namespace Botan { @@ -133,6 +135,91 @@ consteval static auto precompute_zetas(T q, T monty, T root_of_unity) { return result; } +namespace detail { + +/** + * Wraps any XOF to limit the number of bytes that can be produced to @p bound. + * When the bound is reached, the XOF will throw an Internal_Error. + */ +template + requires requires(XofT xof) { + { xof.template output<1>() } -> std::convertible_to>; + { xof.template output<42>() } -> std::convertible_to>; + } +class Bounded_XOF final { + private: + template + using MappedValueT = std::invoke_result_t>; + + public: + template + constexpr static auto default_transformer(const std::array& x) { + return x; + } + + template + constexpr static bool default_predicate(const T&) { + return true; + } + + public: + Bounded_XOF() + requires std::default_initializable + : m_bytes_consumed(0) {} + + explicit Bounded_XOF(XofT xof) : m_xof(xof), m_bytes_consumed(0) {} + + /** + * @returns the next byte from the XOF that fulfills @p predicate. + */ + template )> + requires std::invocable + constexpr auto next_byte(PredicateFnT&& predicate = default_predicate<1, uint8_t>) { + return next<1>([](const auto bytes) { return bytes[0]; }, std::forward(predicate)); + } + + /** + * Pulls the next @p bytes from the XOF and applies @p transformer to the + * output. The result is returned if @p predicate is fulfilled. + * @returns the transformed output of the XOF that fulfills @p predicate. + */ + template ), + typename PredicateFnT = decltype(default_predicate>)> + requires std::invocable> && + std::invocable> + constexpr auto next(MapFnT&& transformer = default_transformer, + PredicateFnT&& predicate = default_predicate>) { + while(true) { + auto output = transformer(take()); + if(predicate(output)) { + return output; + } + } + } + + private: + template + constexpr std::array take() { + m_bytes_consumed += bytes; + if(m_bytes_consumed > bound) { + throw Internal_Error("XOF consumed more bytes than allowed"); + } + return m_xof.template output(); + } + + private: + XofT m_xof; + size_t m_bytes_consumed; +}; + +} // namespace detail + +class XOF; + +template +using Bounded_XOF = detail::Bounded_XOF; + } // namespace Botan #endif diff --git a/src/tests/test_crystals.cpp b/src/tests/test_crystals.cpp index 10e81fc5a79..7d38781a750 100644 --- a/src/tests/test_crystals.cpp +++ b/src/tests/test_crystals.cpp @@ -499,9 +499,88 @@ std::vector test_encoding() { }; } +class MockedXOF { + public: + MockedXOF() : m_counter(0) {} + + template + auto output() { + std::array result; + for(uint8_t& byte : result) { + byte = static_cast(m_counter++); + } + return result; + } + + private: + size_t m_counter; +}; + +template +using Mocked_Bounded_XOF = Botan::detail::Bounded_XOF; + +std::vector test_bounded_xof() { + return { + CHECK("zero bound is reached immediately", + [](Test::Result& result) { + Mocked_Bounded_XOF<0> xof; + result.test_throws("output<1> throws", [&xof]() { xof.next_byte(); }); + }), + + CHECK("bounded XOF with small bound", + [](Test::Result& result) { + Mocked_Bounded_XOF<3> xof; + result.test_is_eq("next_byte() returns 0", xof.next_byte(), uint8_t(0)); + result.test_is_eq("next_byte() returns 1", xof.next_byte(), uint8_t(1)); + result.test_is_eq("next_byte() returns 2", xof.next_byte(), uint8_t(2)); + result.test_throws("next_byte() throws", [&xof]() { xof.next_byte(); }); + }), + + CHECK("filter bytes", + [](Test::Result& result) { + auto filter = [](uint8_t byte) { + //test + return byte % 2 == 1; + }; + + Mocked_Bounded_XOF<5> xof; + result.test_is_eq("next_byte() returns 1", xof.next_byte(filter), uint8_t(1)); + result.test_is_eq("next_byte() returns 3", xof.next_byte(filter), uint8_t(3)); + result.test_throws("next_byte() throws", [&]() { xof.next_byte(filter); }); + }), + + CHECK("map bytes", + [](Test::Result& result) { + auto map = [](auto bytes) { return Botan::load_be(bytes); }; + + Mocked_Bounded_XOF<17> xof; + result.test_is_eq("next returns 0x00010203", xof.next<4>(map), uint32_t(0x00010203)); + result.test_is_eq("next returns 0x04050607", xof.next<4>(map), uint32_t(0x04050607)); + result.test_is_eq("next returns 0x08090A0B", xof.next<4>(map), uint32_t(0x08090A0B)); + result.test_is_eq("next returns 0x0C0D0E0F", xof.next<4>(map), uint32_t(0x0C0D0E0F)); + result.test_throws("next() throws", [&]() { xof.next<4>(map); }); + }), + + CHECK("map and filter bytes", + [](Test::Result& result) { + auto map = [](std::array bytes) -> uint32_t { return bytes[0] + bytes[1] + bytes[2]; }; + auto filter = [](uint32_t number) { return number < 50; }; + + Mocked_Bounded_XOF<17> xof; + result.test_is_eq("next returns 3", xof.next<3>(map, filter), uint32_t(3)); + result.test_is_eq("next returns 12", xof.next<3>(map, filter), uint32_t(12)); + result.test_is_eq("next returns 21", xof.next<3>(map, filter), uint32_t(21)); + result.test_is_eq("next returns 30", xof.next<3>(map, filter), uint32_t(30)); + result.test_is_eq("next returns 39", xof.next<3>(map, filter), uint32_t(39)); + result.test_throws("next() throws", [&]() { xof.next<3>(map, filter); }); + }), + }; +} + } // namespace -BOTAN_REGISTER_TEST_FN("pubkey", "crystals", test_extended_euclidean_algorithm, test_polynomial_basics, test_encoding); +BOTAN_REGISTER_TEST_FN( + "pubkey", "crystals", test_extended_euclidean_algorithm, test_polynomial_basics, test_encoding, test_bounded_xof); } // namespace Botan_Tests