Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimise some of the bounds checks #34

Merged
merged 6 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions reed-solomon-novelpoly/src/field/inc_afft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ pub struct AdditiveFFT {
}

/// Formal derivative of polynomial in the new?? basis
pub fn formal_derivative(cos: &mut [Additive], size: usize) {
for i in 1..size {
pub fn formal_derivative(cos: &mut [Additive]) {
for i in 1..cos.len() {
let length = ((i ^ (i - 1)) + 1) >> 1;
for j in (i - length)..i {
cos[j] ^= cos.get(j + length).copied().unwrap_or(Additive::ZERO);
}
}
let mut i = size;
let mut i = cos.len();
while i < FIELD_SIZE && i < cos.len() {
for j in 0..size {
for j in 0..cos.len() {
cos[j] ^= cos.get(j + i).copied().unwrap_or(Additive::ZERO);
}
i <<= 1;
Expand All @@ -32,9 +32,11 @@ pub fn formal_derivative(cos: &mut [Additive], size: usize) {

/// Formal derivative of polynomial in tweaked?? basis
#[allow(non_snake_case)]
pub fn tweaked_formal_derivative(codeword: &mut [Additive], n: usize) {
pub fn tweaked_formal_derivative(codeword: &mut [Additive]) {
#[cfg(b_is_not_one)]
let B = unsafe { &AFFT.B };
#[cfg(b_is_not_one)]
let n = codeword.len();

// We change nothing when multiplying by b from B.
#[cfg(b_is_not_one)]
Expand All @@ -44,7 +46,7 @@ pub fn tweaked_formal_derivative(codeword: &mut [Additive], n: usize) {
codeword[i + 1] = codeword[i + 1].mul(b);
}

formal_derivative(codeword, n);
formal_derivative(codeword);

// Again changes nothing by multiplying by b although b differs here.
#[cfg(b_is_not_one)]
Expand Down Expand Up @@ -86,21 +88,25 @@ fn b_is_one() {
// We're hunting for the differences and trying to understand the algorithm.

/// Inverse additive FFT in the "novel polynomial basis"
///
/// Free-function convenience wrapper: forwards `data`, `size` and `index`
/// unchanged to the `inverse_afft` method of the shared `AFFT` value.
#[inline(always)]
pub fn inverse_afft(data: &mut [Additive], size: usize, index: usize) {
// NOTE(review): `AFFT` is only reachable through `unsafe` here, presumably because it is
// a `static mut` — confirm it is fully initialised and never mutated while this runs.
unsafe { &AFFT }.inverse_afft(data, size, index)
}

/// Variant of the inverse additive FFT wrapper that is compiled only when both the
/// `avx` target feature and the crate's `avx` cargo feature are enabled.
///
/// Forwards `data`, `size` and `index` unchanged to the `inverse_afft_faster8`
/// method of the shared `AFFT` value.
#[cfg(all(target_feature = "avx", feature = "avx"))]
#[inline(always)]
pub fn inverse_afft_faster8(data: &mut [Additive], size: usize, index: usize) {
// NOTE(review): `AFFT` is only reachable through `unsafe` here, presumably because it is
// a `static mut` — confirm it is fully initialised and never mutated while this runs.
unsafe { &AFFT }.inverse_afft_faster8(data, size, index)
}

/// Additive FFT in the "novel polynomial basis"
///
/// Free-function convenience wrapper: forwards `data`, `size` and `index`
/// unchanged to the `afft` method of the shared `AFFT` value.
#[inline(always)]
pub fn afft(data: &mut [Additive], size: usize, index: usize) {
// NOTE(review): `AFFT` is only reachable through `unsafe` here, presumably because it is
// a `static mut` — confirm it is fully initialised and never mutated while this runs.
unsafe { &AFFT }.afft(data, size, index)
}

#[cfg(all(target_feature = "avx", feature = "avx"))]
#[inline(always)]
/// Additive FFT in the "novel polynomial basis"
pub fn afft_faster8(data: &mut [Additive], size: usize, index: usize) {
unsafe { &AFFT }.afft_faster8(data, size, index)
Expand Down Expand Up @@ -141,6 +147,8 @@ impl AdditiveFFT {
// After this, we start at depth (i of Algorithm 2) = (k of Algorithm 2) - 1
// and progress through FIELD_BITS-1 steps, obtaining \Psi_\beta(0,0).
let mut depart_no = 1_usize;
assert!(data.len() >= size);

while depart_no < size {
// if depart_no >= 8 {
// println!("\n\n\nplain/Round depart_no={depart_no}");
Expand All @@ -167,20 +175,16 @@ impl AdditiveFFT {
// if depart_no >= 8 && false{
// data[i + depart_no] ^= dbg!(data[dbg!(i)]);
// } else {

// TODO: Optimising bounds checks on this line will yield a great performance improvement.
data[i + depart_no] ^= data[i];
// }
}

// Algorithm 2 indexes the skew factor in line 5 page 6288
// by i and \omega_{j 2^{i+1}}, but not by r explicitly.
// We further explore this confusion below. (TODO)
let skew =
// if depart_no >= 8 && false {
// dbg!(self.skews[j + index - 1])
// } else {
self.skews[j + index - 1]
// }
;
let skew = self.skews[j + index - 1];

// It's reasonable to skip the loop if skew is zero, but doing so with
// all bits set requires justification. (TODO)
if skew.0 != ONEMASK {
Expand All @@ -191,8 +195,9 @@ impl AdditiveFFT {
// if depart_no >= 8 && false{
// data[i] ^= dbg!(dbg!(data[dbg!(i + depart_no)]).mul(skew));
// } else {

// TODO: Optimising bounds checks on this line will yield a great performance improvement.
data[i] ^= data[i + depart_no].mul(skew);
// }
}
}

Expand Down Expand Up @@ -270,6 +275,8 @@ impl AdditiveFFT {
// After this, we start at depth (i of Algorithm 1) = (k of Algorithm 1) - 1
// and progress through FIELD_BITS-1 steps, obtaining \Psi_\beta(0,0).
let mut depart_no = size >> 1_usize;
assert!(data.len() >= size);

while depart_no > 0 {
// Agrees with for loop (j of Algorithm 1) in (0..2^{k-i-1}) from line 5,
// except we've j in (depart_no..size).step_by(2*depart_no), meaning
Expand All @@ -291,6 +298,7 @@ impl AdditiveFFT {
// we think r actually appears but the skew factor repeats itself
// like in (19) in the proof of Lemma 4. (TODO)
// We should understand the rest of this basis story, like (8) too. (TODO)

let skew = self.skews[j + index - 1];

// It's reasonable to skip the loop if skew is zero, but doing so with
Expand All @@ -300,6 +308,8 @@ impl AdditiveFFT {
for i in (j - depart_no)..j {
// Line 6, explained by (28) page 6287, but
// adding depart_no acts like the r+2^i superscript.

// TODO: Optimising bounds checks on this line will yield a great performance improvement.
data[i] ^= data[i + depart_no].mul(skew);
}
}
Expand All @@ -308,6 +318,8 @@ impl AdditiveFFT {
for i in (j - depart_no)..j {
// Line 7, explained by (31) page 6287, but
// adding depart_no acts like the r+2^i superscript.

// TODO: Optimising bounds checks on this line will yield a great performance improvement.
data[i + depart_no] ^= data[i];
}

Expand Down Expand Up @@ -484,7 +496,7 @@ pub mod test_utils {
let data = gen_plain::<R>(size);
gen_faster8_from_plain(data)
}

#[cfg(all(target_feature = "avx", feature = "avx"))]
pub fn assert_plain_eq_faster8(plain: impl AsRef<[Additive]>, faster8: impl AsRef<[Additive]>) {
let plain = plain.as_ref();
Expand All @@ -502,7 +514,7 @@ mod afft_tests {
use super::super::*;
use super::super::test_utils::*;
use rand::rngs::SmallRng;

#[cfg(all(target_feature = "avx", feature = "avx"))]
#[test]
fn afft_output_plain_eq_faster8_size_16() {
Expand Down Expand Up @@ -544,7 +556,7 @@ mod afft_tests {
println!(">>>>");
assert_plain_eq_faster8(data_plain, data_faster8);
}

#[cfg(all(target_feature = "avx", feature = "avx"))]
#[test]
fn afft_output_plain_eq_faster8_impulse_data() {
Expand Down
23 changes: 10 additions & 13 deletions reed-solomon-novelpoly/src/field/inc_encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub fn encode_low(data: &[Additive], k: usize, codeword: &mut [Additive], n: usi
encode_low_plain(data, k, codeword, n);
}

#[cfg(not(target_feature = "avx"))]
#[cfg(not(all(target_feature = "avx", feature = "avx")))]
encode_low_plain(data, k, codeword, n);
}

Expand Down Expand Up @@ -37,12 +37,10 @@ pub fn encode_low_plain(data: &[Additive], k: usize, codeword: &mut [Additive],

for shift in (k..n).step_by(k) {
let codeword_at_shift = &mut codeword_skip_first_k[(shift - k)..shift];

// copy `M_topdash` to the position we are currently at, the n transform
codeword_at_shift.copy_from_slice(codeword_first_k);
// dbg!(&codeword_at_shift);
afft(codeword_at_shift, k, shift);
// let post = &codeword_at_shift;
// dbg!(post);
}

// restore `M` from the derived ones
Expand Down Expand Up @@ -79,11 +77,10 @@ pub fn encode_low_faster8(data: &[Additive], k: usize, codeword: &mut [Additive]

for shift in (k..n).step_by(k) {
let codeword_at_shift = &mut codeword_skip_first_k[(shift - k)..shift];

// copy `M_topdash` to the position we are currently at, the n transform
codeword_at_shift.copy_from_slice(codeword_first_k);

afft_faster8(codeword_at_shift, k, shift);
// let post = &codeword8x_at_shift;
}

// restore `M` from the derived ones
Expand All @@ -108,6 +105,8 @@ pub fn encode_high(data: &[Additive], k: usize, parity: &mut [Additive], mem: &m
//data: message array. parity: parity array. mem: buffer(size>= n-k)
//Encoding alg for k/n>0.5: parity is a power of two.
pub fn encode_high_plain(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) {
assert!(is_power_of_2(n));

let t: usize = n - k;

// mem_zero(&mut parity[0..t]);
Expand Down Expand Up @@ -158,7 +157,7 @@ pub fn encode_sub(bytes: &[u8], n: usize, k: usize) -> Result<Vec<Additive>> {
} else {
encode_sub_plain(bytes, n, k)
}
#[cfg(not(target_feature = "avx"))]
#[cfg(not(all(target_feature = "avx", feature = "avx")))]
encode_sub_plain(bytes, n, k)
}

Expand Down Expand Up @@ -194,13 +193,11 @@ pub fn encode_sub_plain(bytes: &[u8], n: usize, k: usize) -> Result<Vec<Additive
elm_data[i] = Additive(Elt::from_be_bytes([
bytes.get(2 * i).copied().unwrap_or_default(),
bytes.get(2 * i + 1).copied().unwrap_or_default(),
]))
]));
}

// update new data bytes with zero padded bytes
// `l` is now `GF(2^16)` symbols
let elm_len = elm_data.len();
assert_eq!(elm_len, n);

let mut codeword = elm_data.clone();
assert_eq!(codeword.len(), n);
Expand Down Expand Up @@ -243,9 +240,9 @@ pub fn encode_sub_faster8(bytes: &[u8], n: usize, k: usize) -> Result<Vec<Additi

for i in 0..((bytes_len + 1) / 2) {
elm_data[i] = Additive(Elt::from_be_bytes([
bytes.get(2 * i).map(|x| *x).unwrap_or_default(),
bytes.get(2 * i + 1).map(|x| *x).unwrap_or_default(),
]))
bytes.get(2 * i).copied().unwrap_or_default(),
bytes.get(2 * i + 1).copied().unwrap_or_default(),
]));
}

// update new data bytes with zero padded bytes
Expand Down
8 changes: 6 additions & 2 deletions reed-solomon-novelpoly/src/field/inc_log_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ impl Additive {

/// Multiplication friendly LOG form of f2e16
#[derive(Clone, Debug, Copy, Add, AddAssign, Sub, SubAssign, PartialEq, Eq)] // Default, PartialOrd,Ord
#[repr(transparent)]
pub struct Multiplier(pub Elt);

impl Multiplier {
Expand All @@ -81,13 +82,16 @@ impl std::fmt::Display for Multiplier {
/// Fast Walsh–Hadamard transform over modulo `ONEMASK`
#[inline(always)]
pub fn walsh(data: &mut [Multiplier], size: usize) {
#[cfg(all(target_feature = "avx", table_bootstrap_complete))]
#[cfg(all(target_feature = "avx", table_bootstrap_complete, feature = "avx"))]
walsh_faster8(data, size);
#[cfg(not(all(target_feature = "avx", table_bootstrap_complete)))]
#[cfg(not(all(target_feature = "avx", table_bootstrap_complete, feature = "avx")))]
walsh_plain(data, size);
}

#[inline(always)]
pub fn walsh_plain(data: &mut [Multiplier], size: usize) {
assert!(data.len() >= size);

let mask = ONEMASK as Wide;
let mut depart_no = 1_usize;
while depart_no < size {
Expand Down
8 changes: 6 additions & 2 deletions reed-solomon-novelpoly/src/field/inc_reconstruct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ pub(crate) fn decode_main(
assert!(n >= recover_up_to);
assert_eq!(erasure.len(), n);

for i in 0..n {
for i in 0..codeword.len() {
codeword[i] = if erasure[i] { Additive(0) } else { codeword[i].mul(log_walsh2[i]) };
}

inverse_afft(codeword, n, 0);

tweaked_formal_derivative(codeword, n);
tweaked_formal_derivative(codeword);

afft(codeword, n, 0);

Expand All @@ -89,6 +89,10 @@ pub(crate) fn decode_main(
// since this has only to be called once per reconstruction
pub fn eval_error_polynomial(erasure: &[bool], log_walsh2: &mut [Multiplier], n: usize) {
let z = std::cmp::min(n, erasure.len());
assert!(z <= erasure.len());
assert!(n <= log_walsh2.len());
assert!(z <= log_walsh2.len());

for i in 0..z {
log_walsh2[i] = Multiplier(erasure[i] as Elt);
}
Expand Down
2 changes: 1 addition & 1 deletion reed-solomon-novelpoly/src/novel_poly_basis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl CodeParams {
{
self.k >= (Additive8x::LANE << 1) && self.n % Additive8x::LANE == 0
}
#[cfg(not(target_feature = "avx"))]
#[cfg(not(all(target_feature = "avx", feature = "avx")))]
false
}

Expand Down
Loading