From 8123302a772a9a53e1b01114c913c7b315b20a23 Mon Sep 17 00:00:00 2001
From: Josh Stone <cuviper@gmail.com>
Date: Sat, 4 May 2024 11:42:30 -0700
Subject: [PATCH] Generate the radix bases in `const` instead of `build.rs`

---
 build.rs               | 57 --------------------------
 src/biguint/convert.rs | 92 +++++++++++++++++++++++++++++++-----------
 2 files changed, 69 insertions(+), 80 deletions(-)

diff --git a/build.rs b/build.rs
index 278abf26..4bb21537 100644
--- a/build.rs
+++ b/build.rs
@@ -1,8 +1,4 @@
 use std::env;
-use std::error::Error;
-use std::fs::File;
-use std::io::Write;
-use std::path::Path;
 
 fn main() {
     let ptr_width = env::var("CARGO_CFG_TARGET_POINTER_WIDTH");
@@ -22,57 +18,4 @@ fn main() {
     }
 
     println!("cargo:rerun-if-changed=build.rs");
-
-    write_radix_bases().unwrap();
-}
-
-/// Write tables of the greatest power of each radix for the given bit size.  These are returned
-/// from `biguint::get_radix_base` to batch the multiplication/division of radix conversions on
-/// full `BigUint` values, operating on primitive integers as much as possible.
-///
-/// e.g. BASES_16[3] = (59049, 10) // 3¹⁰ fits in u16, but 3¹¹ is too big
-///      BASES_32[3] = (3486784401, 20)
-///      BASES_64[3] = (12157665459056928801, 40)
-///
-/// Powers of two are not included, just zeroed, as they're implemented with shifts.
-fn write_radix_bases() -> Result<(), Box<dyn Error>> {
-    let out_dir = env::var("OUT_DIR")?;
-    let dest_path = Path::new(&out_dir).join("radix_bases.rs");
-    let mut f = File::create(dest_path)?;
-
-    for &bits in &[16, 32, 64] {
-        let max = if bits < 64 {
-            (1 << bits) - 1
-        } else {
-            std::u64::MAX
-        };
-
-        writeln!(f, "#[deny(overflowing_literals)]")?;
-        writeln!(
-            f,
-            "pub(crate) static BASES_{bits}: [(u{bits}, usize); 257] = [",
-            bits = bits
-        )?;
-        for radix in 0u64..257 {
-            let (base, power) = if radix == 0 || radix.is_power_of_two() {
-                (0, 0)
-            } else {
-                let mut power = 1;
-                let mut base = radix;
-
-                while let Some(b) = base.checked_mul(radix) {
-                    if b > max {
-                        break;
-                    }
-                    base = b;
-                    power += 1;
-                }
-                (base, power)
-            };
-            writeln!(f, "    ({}, {}), // {}", base, power, radix)?;
-        }
-        writeln!(f, "];")?;
-    }
-
-    Ok(())
 }
diff --git a/src/biguint/convert.rs b/src/biguint/convert.rs
index 59df82b9..7e4dee57 100644
--- a/src/biguint/convert.rs
+++ b/src/biguint/convert.rs
@@ -119,7 +119,7 @@ fn from_radix_digits_be(v: &[u8], radix: u32) -> BigUint {
 
     let mut data = Vec::with_capacity(big_digits.to_usize().unwrap_or(0));
 
-    let (base, power) = get_radix_base(radix, big_digit::BITS);
+    let (base, power) = get_radix_base(radix);
     let radix = radix as BigDigit;
 
     let r = v.len() % power;
@@ -688,7 +688,7 @@ pub(super) fn to_radix_digits_le(u: &BigUint, radix: u32) -> Vec<u8> {
 
     let mut digits = u.clone();
 
-    let (base, power) = get_radix_base(radix, big_digit::HALF_BITS);
+    let (base, power) = get_half_radix_base(radix);
     let radix = radix as BigDigit;
 
     // For very large numbers, the O(n²) loop of repeated `div_rem_digit` dominates the
@@ -783,33 +783,79 @@ pub(crate) fn to_str_radix_reversed(u: &BigUint, radix: u32) -> Vec<u8> {
     res
 }
 
-/// Returns the greatest power of the radix for the given bit size
+/// Returns the greatest power of the radix for the `BigDigit` bit size
 #[inline]
-fn get_radix_base(radix: u32, bits: u8) -> (BigDigit, usize) {
-    mod gen {
-        include! { concat!(env!("OUT_DIR"), "/radix_bases.rs") }
-    }
+fn get_radix_base(radix: u32) -> (BigDigit, usize) {
+    static BASES: [(BigDigit, usize); 257] = generate_radix_bases(big_digit::MAX);
+    debug_assert!(!radix.is_power_of_two());
+    debug_assert!((3..256).contains(&radix));
+    BASES[radix as usize]
+}
 
-    debug_assert!(
-        2 <= radix && radix <= 256,
-        "The radix must be within 2...256"
-    );
+/// Returns the greatest power of the radix for half the `BigDigit` bit size
+#[inline]
+fn get_half_radix_base(radix: u32) -> (BigDigit, usize) {
+    static BASES: [(BigDigit, usize); 257] = generate_radix_bases(big_digit::HALF);
     debug_assert!(!radix.is_power_of_two());
-    debug_assert!(bits <= big_digit::BITS);
+    debug_assert!((3..256).contains(&radix));
+    BASES[radix as usize]
+}
 
-    match bits {
-        16 => {
-            let (base, power) = gen::BASES_16[radix as usize];
-            (base as BigDigit, power)
+/// Generate tables of the greatest power of each radix that is less that the given maximum. These
+/// are returned from `get_radix_base` to batch the multiplication/division of radix conversions on
+/// full `BigUint` values, operating on primitive integers as much as possible.
+///
+/// e.g. BASES_16[3] = (59049, 10) // 3¹⁰ fits in u16, but 3¹¹ is too big
+///      BASES_32[3] = (3486784401, 20)
+///      BASES_64[3] = (12157665459056928801, 40)
+///
+/// Powers of two are not included, just zeroed, as they're implemented with shifts.
+const fn generate_radix_bases(max: BigDigit) -> [(BigDigit, usize); 257] {
+    let mut bases = [(0, 0); 257];
+
+    let mut radix: BigDigit = 3;
+    while radix < 256 {
+        if !radix.is_power_of_two() {
+            let mut power = 1;
+            let mut base = radix;
+
+            while let Some(b) = base.checked_mul(radix) {
+                if b > max {
+                    break;
+                }
+                base = b;
+                power += 1;
+            }
+            bases[radix as usize] = (base, power)
         }
-        32 => {
-            let (base, power) = gen::BASES_32[radix as usize];
-            (base as BigDigit, power)
+        radix += 1;
+    }
+
+    bases
+}
+
+#[test]
+fn test_radix_bases() {
+    for radix in 3u32..256 {
+        if !radix.is_power_of_two() {
+            let (base, power) = get_radix_base(radix);
+            let radix = BigDigit::try_from(radix).unwrap();
+            let power = u32::try_from(power).unwrap();
+            assert_eq!(base, radix.pow(power));
+            assert!(radix.checked_pow(power + 1).is_none());
         }
-        64 => {
-            let (base, power) = gen::BASES_64[radix as usize];
-            (base as BigDigit, power)
+    }
+}
+
+#[test]
+fn test_half_radix_bases() {
+    for radix in 3u32..256 {
+        if !radix.is_power_of_two() {
+            let (base, power) = get_half_radix_base(radix);
+            let radix = BigDigit::try_from(radix).unwrap();
+            let power = u32::try_from(power).unwrap();
+            assert_eq!(base, radix.pow(power));
+            assert!(radix.pow(power + 1) > big_digit::HALF);
         }
-        _ => panic!("Invalid bigdigit size"),
     }
 }