From aec600dd15194ab22488154c2fc650fbca75cb2f Mon Sep 17 00:00:00 2001 From: Aaron Feickert <66188213+AaronFeickert@users.noreply.github.com> Date: Mon, 12 Feb 2024 14:42:04 -0600 Subject: [PATCH] Simplify constant-time Gray decomposition --- src/gray.rs | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/src/gray.rs b/src/gray.rs index 7ef083e..15778ec 100644 --- a/src/gray.rs +++ b/src/gray.rs @@ -4,7 +4,7 @@ use alloc::{vec, vec::Vec}; use core::num::NonZeroU32; -use crypto_bigint::{NonZero, Uint}; +use crypto_bigint::{NonZero, U64}; /// An iterator for arbitrary-base Gray codes. #[allow(non_snake_case)] @@ -77,46 +77,37 @@ impl GrayIterator { /// Otherwise, returns the Gray code as a `u32` digit vector. #[allow(non_snake_case)] pub(crate) fn decompose(N: u32, M: u32, v: u32) -> Option> { - type U32 = Uint<1>; - if N <= 1 || M == 0 { return None; } - // Each of these `u32` values can fit into a single-limb `Uint`, regardless of target - let mut v_U32 = U32::from_u32(v); - let N_nonzero = NonZero::>::from_u32(NonZeroU32::new(N)?); + // Convert to constant-time-friendly `U64` + let mut v_U64 = U64::from_u32(v); + let N_nonzero = NonZero::::from_u32(NonZeroU32::new(N)?); // Get a base-`N` decomposition in constant time let mut base_N = Vec::with_capacity(M as usize); for _ in 0..M { - let (q, r) = v_U32.div_rem(&N_nonzero); + let (q, r) = v_U64.div_rem(&N_nonzero); base_N.push(r); - v_U32 = q; + v_U64 = q; } // Now get the Gray decomposition from the base-`N` decomposition - let mut shift = U32::ZERO; - let mut digits = vec![U32::ZERO; M as usize]; + let mut shift = U64::ZERO; + let mut digits = vec![U64::ZERO; M as usize]; for i in (0..M).rev() { digits[i as usize] = base_N[i as usize].saturating_add(&shift).rem(&N_nonzero); shift = shift.saturating_add(&N_nonzero).saturating_sub(&digits[i as usize]); } - // On a 32-bit target, the single word is already `u32` - #[cfg(target_pointer_width = "32")] - let digits_u32 = digits.iter().map(|d| d.as_words()[0]).collect::>(); - - // On a 64-bit target, the single word is `u64`, but should not overflow `u32` - #[cfg(target_pointer_width = "64")] - let digits_u32 = digits + // Get the digits as `u32` + digits .iter() - .map(|d| u32::try_from(d.as_words()[0]).ok()) - .collect::>>()?; - - Some(digits_u32) + .map(|d| u32::try_from(u64::from(*d)).ok()) + .collect::>>() } }