From 5fd7d18c92f050981995a4abfe68c3fa4621bfc2 Mon Sep 17 00:00:00 2001 From: Kert Date: Sat, 10 Apr 2021 15:13:02 -0700 Subject: [PATCH] Implement missing multiply with overflow checking (#5) * Adding doc links * Adding doc badge * Implement overflowing_mul * Add multiply test coverage * Update version, readme --- Cargo.toml | 3 +- README.md | 4 +- src/fixeduint.rs | 124 +++++++++++++++++++----------- src/machineword.rs | 55 ++++++++------ tests/mul_div.rs | 182 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 300 insertions(+), 68 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0e0992a..24604d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,8 @@ [package] name = "fixed-bigint" -version = "0.1.1" +version = "0.1.2" authors = ["kaidokert "] +documentation = "https://docs.rs/fixed-bigint" edition = "2018" description = """ Fixed-size big integer implementation for Rust diff --git a/README.md b/README.md index 0430ed4..b149534 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Fixed BigInt [![crate](https://img.shields.io/crates/v/fixed-bigint.svg)](https://crates.io/crates/fixed-bigint) +[![documentation](https://docs.rs/fixed-bigint/badge.svg)](https://docs.rs/fixed-bigint/) [![minimum rustc 1.51](https://img.shields.io/badge/rustc-1.51+-red.svg)](https://rust-lang.github.io/rfcs/2495-min-rust-version.html) [![build status](https://github.com/kaidokert/fixed-bigint-rs/actions/workflows/rust.yml/badge.svg)](https://github.com/kaidokert/fixed-bigint-rs/actions) @@ -16,7 +17,8 @@ The crate is written for `no_std` and `no_alloc` environments with option for pa The arithmetic operands ( +, -, .add() ) panic on overflow, just like native integer types. Panic-free alternatives like `overlowing_add` and `wrapping_add` are supported. _TODO list_: - * Implement missing checked_mul-div, wrapping_mul/div, overflowing_mul/div. + * Implement WrappingShl/Shr, CheckedShl/Shr + * Implement AddAssign, MulAssign and other xyzAssign operands, memory and speed improvement * Implement experimental `unchecked_math` operands, unchecked_mul, unchecked_div etc. * Probably needs its own error structs instead of reusing core::fmt::Error and core::num::ParseIntError * Decimal string to/from conversion, currently only binary and hex strings are supported. diff --git a/src/fixeduint.rs b/src/fixeduint.rs index c57c654..56be043 100644 --- a/src/fixeduint.rs +++ b/src/fixeduint.rs @@ -12,10 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use num_traits::{ - ops::overflowing::OverflowingAdd, ops::overflowing::OverflowingSub, Bounded, One, PrimInt, - ToPrimitive, Zero, -}; +use num_traits::ops::overflowing::{OverflowingAdd, OverflowingMul, OverflowingSub}; +use num_traits::{Bounded, One, PrimInt, ToPrimitive, Zero}; use core::convert::TryFrom; use core::fmt::Write; @@ -183,16 +181,6 @@ impl FixedUInt { Ok(()) } - fn from_doubleword(other: T::DoubleWord) -> Self { - let mut ret = Self::zero(); - ret.array[0] = T::from_double(other); - if N > 1 { - let tmp2 = other >> Self::WORD_BITS; - ret.array[1] = T::from_double(tmp2); - } - ret - } - // Here to avoid duplicating this in two traits fn saturating_add_impl(self, other: &Self) -> Self { let res = self.overflowing_add(&other); @@ -409,26 +397,83 @@ impl num_traits::Saturating for FixedUInt // #region Multiply/Divide -impl core::ops::Mul for FixedUInt { - type Output = Self; - fn mul(self, other: Self) -> >::Output { +impl num_traits::ops::overflowing::OverflowingMul + for FixedUInt +{ + fn overflowing_mul(&self, other: &Self) -> (Self, bool) { let mut ret = Self::zero(); - + let mut overflowed = false; + // Calculate N+1 rounds, to check for overflow + let max_rounds = N + 1; + let t_max = T::max_value().to_double(); for i in 0..N { - let mut row = Self::zero(); + let mut carry = T::DoubleWord::zero(); for j in 0..N { - if i + j < N { - let intermediate: T::DoubleWord = - self.array[j].to_double() * other.array[i].to_double(); - let mut f = Self::from_doubleword(intermediate); - let shiftq: usize = (Self::WORD_BITS * (i + j)) as usize; - f = f << shiftq; - row = f + row; + let round = i + j; + if round < max_rounds { + let mul_res = self.array[i].to_double() * other.array[j].to_double(); + let mut accumulator = T::DoubleWord::zero(); + if round < N { + accumulator = ret.array[round].to_double(); + } + accumulator = accumulator + mul_res + carry; + + if accumulator > t_max { + carry = accumulator >> Self::WORD_BITS; + accumulator = accumulator & t_max; + } else { + carry = T::DoubleWord::zero(); + } + if round < N { + ret.array[round] = T::from_double(accumulator); + } else { + overflowed = overflowed || !accumulator.is_zero(); + } } } - ret = ret + row; + if !carry.is_zero() { + overflowed = true; + } + } + (ret, overflowed) + } +} + +impl core::ops::Mul for FixedUInt { + type Output = Self; + fn mul(self, other: Self) -> >::Output { + let res = self.overflowing_mul(&other); + res.0 + } +} + +impl num_traits::WrappingMul for FixedUInt { + fn wrapping_mul(&self, other: &Self) -> Self { + self.overflowing_mul(&other).0 + } +} + +impl num_traits::CheckedMul for FixedUInt { + fn checked_mul(&self, other: &Self) -> Option { + let res = self.overflowing_mul(&other); + if res.1 { + None + } else { + Some(res.0) + } + } +} + +impl num_traits::ops::saturating::SaturatingMul + for FixedUInt +{ + fn saturating_mul(&self, other: &Self) -> Self { + let res = self.overflowing_mul(&other); + if res.1 { + Self::max_value() + } else { + res.0 } - ret } } @@ -467,15 +512,13 @@ impl core::ops::Div for FixedUInt { } } -impl num_traits::CheckedMul for FixedUInt { - fn checked_mul(&self, _: &Self) -> Option { - todo!() - } -} - impl num_traits::CheckedDiv for FixedUInt { - fn checked_div(&self, _: &Self) -> Option { - todo!() + fn checked_div(&self, other: &Self) -> Option { + if other.is_zero() { + None + } else { + Some(core::ops::Div::::div(*self, *other)) + } } } @@ -919,15 +962,6 @@ mod tests { assert_eq!(f.array, [4294967295]); } - #[test] - fn test_from_doubleword() { - let f = Bn8::from_doubleword(45); - assert_eq!(Some(45), f.to_u32()); - - let f = Bn8::from_doubleword(256); - assert_eq!(Some(256), f.to_u32()); - } - #[test] fn testsimple() { assert_eq!(Bn::::new(), Bn::::new()); diff --git a/src/machineword.rs b/src/machineword.rs index eff07d4..1fd3522 100644 --- a/src/machineword.rs +++ b/src/machineword.rs @@ -21,18 +21,20 @@ pub trait MachineWord: + num_traits::ops::overflowing::OverflowingSub + From + num_traits::WrappingShl + + OverflowingShl + + OverflowingShr + + core::fmt::Debug { type DoubleWord: num_traits::PrimInt + num_traits::Unsigned + num_traits::WrappingAdd - + num_traits::WrappingSub; + + num_traits::WrappingSub + + OverflowingShl; fn to_double(self) -> Self::DoubleWord; fn from_double(word: Self::DoubleWord) -> Self; // Todo: get rid of this, single use fn to_ne_bytes(self) -> [u8; 8]; - fn overflowing_shl(self, rhs: u32) -> (Self, bool); - fn overflowing_shr(self, rhs: u32) -> (Self, bool); } impl MachineWord for u8 { @@ -48,12 +50,6 @@ impl MachineWord for u8 { ret[0] = self; ret } - fn overflowing_shl(self, rhs: u32) -> (u8, bool) { - self.overflowing_shl(rhs) - } - fn overflowing_shr(self, rhs: u32) -> (u8, bool) { - self.overflowing_shr(rhs) - } } impl MachineWord for u16 { type DoubleWord = u32; @@ -69,12 +65,6 @@ impl MachineWord for u16 { halfslice.copy_from_slice(&self.to_ne_bytes()); ret } - fn overflowing_shl(self, rhs: u32) -> (u16, bool) { - self.overflowing_shl(rhs) - } - fn overflowing_shr(self, rhs: u32) -> (u16, bool) { - self.overflowing_shr(rhs) - } } impl MachineWord for u32 { type DoubleWord = u64; @@ -90,14 +80,37 @@ impl MachineWord for u32 { halfslice.copy_from_slice(&self.to_ne_bytes()); ret } - fn overflowing_shl(self, rhs: u32) -> (u32, bool) { - self.overflowing_shl(rhs) - } - fn overflowing_shr(self, rhs: u32) -> (u32, bool) { - self.overflowing_shr(rhs) - } } +// These should be in num_traits +pub trait OverflowingShl: Sized { + fn overflowing_shl(self, rhs: u32) -> (Self, bool); +} +pub trait OverflowingShr: Sized { + fn overflowing_shr(self, rhs: u32) -> (Self, bool); +} + +macro_rules! overflowing_shift_impl { + ($trait_name:ident, $method:ident, $t:ty) => { + impl $trait_name for $t { + #[inline] + fn $method(self, rhs: u32) -> ($t, bool) { + <$t>::$method(self, rhs) + } + } + }; +} + +overflowing_shift_impl!(OverflowingShl, overflowing_shl, u8); +overflowing_shift_impl!(OverflowingShl, overflowing_shl, u16); +overflowing_shift_impl!(OverflowingShl, overflowing_shl, u32); +overflowing_shift_impl!(OverflowingShl, overflowing_shl, u64); + +overflowing_shift_impl!(OverflowingShr, overflowing_shr, u8); +overflowing_shift_impl!(OverflowingShr, overflowing_shr, u16); +overflowing_shift_impl!(OverflowingShr, overflowing_shr, u32); +overflowing_shift_impl!(OverflowingShr, overflowing_shr, u64); + #[cfg(test)] mod tests { use super::*; diff --git a/tests/mul_div.rs b/tests/mul_div.rs index 7b74def..67d0007 100644 --- a/tests/mul_div.rs +++ b/tests/mul_div.rs @@ -119,3 +119,185 @@ fn test_rem() { test_rem_1::>(); test_rem_1::>(); } + +#[test] +fn test_overflowing_mul() { + fn test_8_bit< + INT: num_traits::PrimInt + + num_traits::ops::overflowing::OverflowingMul + + core::fmt::Debug + + core::convert::From, + >() { + let a: INT = 2.into(); + let b: INT = 3.into(); + let c: INT = 130.into(); + assert_eq!(a.overflowing_mul(&b), (6.into(), false)); + assert_eq!(a.overflowing_mul(&c), (4.into(), true)); + assert_eq!(Into::::into(128).overflowing_mul(&a), (0.into(), true)); + assert_eq!( + Into::::into(127).overflowing_mul(&a), + (254.into(), false) + ); + } + + test_8_bit::(); + test_8_bit::>(); + + fn test_16_bit< + INT: num_traits::PrimInt + + num_traits::ops::overflowing::OverflowingMul + + num_traits::WrappingMul + + num_traits::SaturatingMul + + core::fmt::Debug + + core::convert::From, + >() { + let a: INT = 2.into(); + let b: INT = 3.into(); + let c: INT = 32770.into(); + + assert_eq!(a.overflowing_mul(&b), (6.into(), false)); + assert_eq!(a.overflowing_mul(&c), (4.into(), true)); + assert_eq!(c.overflowing_mul(&a), (4.into(), true)); + assert_eq!( + Into::::into(32768).overflowing_mul(&a), + (0.into(), true) + ); + assert_eq!( + Into::::into(32767).overflowing_mul(&a), + (65534.into(), false) + ); + + let tests = [ + (0u16, 0u16, 0u16, false), + (2, 3, 6, false), + (2, 32767, 65534, false), + (2, 32768, 0, true), + (2, 32770, 4, true), + (255, 255, 65025, false), + (255, 256, 65280, false), + (256, 256, 0, true), + (256, 257, 256, true), + (257, 257, 513, true), + ]; + for (a, b, res, overflow) in &tests { + let ac: INT = (*a).into(); + let bc: INT = (*b).into(); + assert_eq!(ac.overflowing_mul(&bc), ((*res).into(), *overflow)); + assert_eq!(bc.overflowing_mul(&ac), ((*res).into(), *overflow)); + assert_eq!(ac.wrapping_mul(&bc), (*res).into()); + let checked = ac.checked_mul(&bc); + let saturating = ac.saturating_mul(&bc); + if *overflow { + assert_eq!(checked, None); + assert_eq!(saturating, INT::max_value()); + } else { + assert_eq!(checked, Some((*res).into())); + assert_eq!(saturating, (*res).into()) + } + } + } + + test_16_bit::(); + test_16_bit::>(); + test_16_bit::>(); + + fn test_32_bit< + INT: num_traits::PrimInt + + num_traits::ops::overflowing::OverflowingMul + + num_traits::WrappingMul + + num_traits::SaturatingMul + + core::fmt::Debug + + core::convert::From, + >() { + let a: INT = 2.into(); + let b: INT = 3.into(); + let c: INT = 2147483650.into(); + + assert_eq!(a.overflowing_mul(&b), (6.into(), false)); + assert_eq!(a.overflowing_mul(&c), (4.into(), true)); + assert_eq!(c.overflowing_mul(&a), (4.into(), true)); + let tests = [ + (0u32, 0u32, 0u32, false), + (2, 3, 6, false), + (2, 2_147_483_647, 4_294_967_294, false), + (2, 2_147_483_648, 0, true), + (2, 2_147_483_650, 4, true), + (65535, 65535, 4_294_836_225, false), + (65535, 65536, 4_294_901_760, false), + (65536, 65536, 0, true), + (65536, 65537, 65536, true), + (65537, 65537, 131073, true), + ]; + for (a, b, res, overflow) in &tests { + let ac: INT = (*a).into(); + let bc: INT = (*b).into(); + assert_eq!(ac.overflowing_mul(&bc), ((*res).into(), *overflow)); + assert_eq!(ac.wrapping_mul(&bc), (*res).into()); + let checked = ac.checked_mul(&bc); + let saturating = ac.saturating_mul(&bc); + if *overflow { + assert_eq!(checked, None); + assert_eq!(saturating, INT::max_value()); + } else { + assert_eq!(checked, Some((*res).into())); + assert_eq!(saturating, (*res).into()) + } + } + } + test_32_bit::(); + test_32_bit::>(); + test_32_bit::>(); +} + +#[test] +#[ignore] +fn test_full_range_mul() { + fn test_ref< + REF: num_traits::PrimInt + num_traits::ops::overflowing::OverflowingMul, + INT: num_traits::PrimInt + num_traits::ops::overflowing::OverflowingMul + core::fmt::Debug, + >() + where + INT: core::convert::From, + core::ops::Range: Iterator, + { + for i in REF::zero()..REF::max_value() { + for j in REF::zero()..REF::max_value() { + let ref_val = i.overflowing_mul(&j); + let lhs: INT = i.into(); + let rhs: INT = j.into(); + let int_val = lhs.overflowing_mul(&rhs); + assert_eq!(int_val, (ref_val.0.into(), ref_val.1)); + } + } + } + test_ref::>(); + test_ref::>(); + test_ref::>(); + // this would never finish, single-threaded + // test_ref::, u32>(); + // test_ref::, u32>(); + // test_ref::, u32>(); +} + +#[test] +fn test_checked_div() { + fn test< + INT: num_traits::PrimInt + core::fmt::Debug + num_traits::CheckedDiv + core::convert::From, + >() { + let a: INT = 2.into(); + let b: INT = 0.into(); + assert_eq!(Into::::into(128).checked_div(&a), Some(64.into())); + assert_eq!(Into::::into(128).checked_div(&b), None); + assert_eq!(Into::::into(0).checked_div(&b), None); + assert_eq!(Into::::into(0).checked_div(&a), Some(0.into())); + } + test::(); + test::>(); + test::(); + test::>(); + test::>(); + test::(); + test::>(); + test::>(); + test::>(); +}