Skip to content

Commit

Permalink
Add Mul<NonNativeFieldVar> for Group (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pratyush authored Dec 28, 2023
1 parent 6164009 commit 1ff3a90
Show file tree
Hide file tree
Showing 14 changed files with 392 additions and 294 deletions.
29 changes: 14 additions & 15 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,14 @@ jobs:
- name: Checkout
uses: actions/checkout@v3

- name: Install Rust (${{ matrix.rust }})
- name: Install Rust
uses: dtolnay/rust-toolchain@stable
id: toolchain-thumbv6m
with:
target: thumbv6m-none-eabi
- run: rustup override set ${{steps.toolchain-thumbv6m.outputs.name}}

- name: Install Rust ARM64 (${{ matrix.rust }})
- name: Install Rust ARM64
uses: dtolnay/rust-toolchain@stable
id: toolchain-aarch64
with:
Expand Down Expand Up @@ -152,12 +152,12 @@ jobs:
- ed_on_bls12_381
steps:
- name: Checkout curves
uses: actions/checkout@v2
uses: actions/checkout@v4
with:
repository: arkworks-rs/curves
repository: arkworks-rs/algebra

- name: Checkout r1cs-std
uses: actions/checkout@v2
uses: actions/checkout@v4
with:
path: r1cs-std

Expand All @@ -166,22 +166,21 @@ jobs:

- name: Patch cargo.toml
run: |
cd curves
if grep -q "\[patch.crates-io\]" Cargo.toml ; then
MATCH=$(awk '/\[patch.crates-io\]/{ print NR; exit }' Cargo.toml);
sed -i "$MATCH,\$d" Cargo.toml
fi
{
echo "[patch.crates-io]"
echo "ark-std = { git = 'https://github.com/arkworks-rs/std' }"
echo "ark-ec = { git = 'https://github.com/arkworks-rs/algebra' }"
echo "ark-ff = { git = 'https://github.com/arkworks-rs/algebra' }"
echo "ark-poly = { git = 'https://github.com/arkworks-rs/algebra' }"
echo "ark-ec = { path = '../ec' }"
echo "ark-ff = { path = '../ff' }"
echo "ark-poly = { path = '../poly' }"
echo "ark-relations = { git = 'https://github.com/arkworks-rs/snark' }"
echo "ark-serialize = { git = 'https://github.com/arkworks-rs/algebra' }"
echo "ark-algebra-bench-templates = { git = 'https://github.com/arkworks-rs/algebra' }"
echo "ark-algebra-test-templates = { git = 'https://github.com/arkworks-rs/algebra' }"
echo "ark-r1cs-std = { path = 'r1cs-std' }"
echo "ark-serialize = { path = '../serialize' }"
echo "ark-algebra-bench-templates = { path = '../bench-templates' }"
echo "ark-algebra-test-templates = { path = '../test-templates' }"
echo "ark-r1cs-std = { path = '../r1cs-std' }"
} >> Cargo.toml
- name: Test on ${{ matrix.curve }}
run: "cd ${{ matrix.curve }} && cargo test --features 'r1cs'"
cd ${{ matrix.curve }} && cargo test --features 'r1cs'
2 changes: 1 addition & 1 deletion src/bits/uint8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ impl<ConstraintF: Field> AllocVar<u8, ConstraintF> for UInt8<ConstraintF> {
/// `ConstraintF::MODULUS_BIT_SIZE - 1` chunks and converts each chunk, which is
/// assumed to be little-endian, to its `FpVar<ConstraintF>` representation.
/// This is the gadget counterpart to the `[u8]` implementation of
/// [ToConstraintField](ark_ff::ToConstraintField).
/// [`ToConstraintField`].
impl<ConstraintF: PrimeField> ToConstraintFieldGadget<ConstraintF> for [UInt8<ConstraintF>] {
#[tracing::instrument(target = "r1cs")]
fn to_constraint_field(&self) -> Result<Vec<FpVar<ConstraintF>>, SynthesisError> {
Expand Down
39 changes: 17 additions & 22 deletions src/fields/nonnative/allocated_field_var.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,13 @@ impl<TargetField: PrimeField, BaseField: PrimeField>
optimization_type,
);

let mut base_repr: <TargetField as PrimeField>::BigInt = TargetField::one().into_bigint();

// Convert 2^{(params.bits_per_limb - 1)} into the TargetField and then double
// the base This is because 2^{(params.bits_per_limb)} might indeed be
// larger than the target field's prime.
base_repr.muln((params.bits_per_limb - 1) as u32);
let mut base: TargetField = TargetField::from_bigint(base_repr).unwrap();
base = base + &base;
let base_repr = TargetField::ONE.into_bigint() << (params.bits_per_limb - 1) as u32;

let mut base = TargetField::from_bigint(base_repr).unwrap();
base.double_in_place();

let mut result = TargetField::zero();
let mut power = TargetField::one();
Expand Down Expand Up @@ -206,25 +205,21 @@ impl<TargetField: PrimeField, BaseField: PrimeField>
> BaseField::MODULUS_BIT_SIZE as usize - 1)
{
Reducer::reduce(&mut other)?;
surfeit = overhead!(other.num_of_additions_over_normal_form + BaseField::one()) + 1;
surfeit = overhead!(other.num_of_additions_over_normal_form + BaseField::ONE) + 1;
}

// Step 2: construct the padding
let mut pad_non_top_limb_repr: <BaseField as PrimeField>::BigInt =
BaseField::one().into_bigint();
let mut pad_top_limb_repr: <BaseField as PrimeField>::BigInt = pad_non_top_limb_repr;
let mut pad_non_top_limb = BaseField::ONE.into_bigint();
let mut pad_top_limb = pad_non_top_limb;

pad_non_top_limb_repr.muln((surfeit + params.bits_per_limb) as u32);
let pad_non_top_limb = BaseField::from_bigint(pad_non_top_limb_repr).unwrap();
pad_non_top_limb <<= (surfeit + params.bits_per_limb) as u32;
let pad_non_top_limb = BaseField::from_bigint(pad_non_top_limb).unwrap();

pad_top_limb_repr.muln(
(surfeit
+ (TargetField::MODULUS_BIT_SIZE as usize
- params.bits_per_limb * (params.num_limbs - 1))) as u32,
);
let pad_top_limb = BaseField::from_bigint(pad_top_limb_repr).unwrap();
pad_top_limb <<= (surfeit + TargetField::MODULUS_BIT_SIZE as usize
- params.bits_per_limb * (params.num_limbs - 1)) as u32;
let pad_top_limb = BaseField::from_bigint(pad_top_limb).unwrap();

let mut pad_limbs = Vec::new();
let mut pad_limbs = Vec::with_capacity(self.limbs.len());
pad_limbs.push(pad_top_limb);
for _ in 0..self.limbs.len() - 1 {
pad_limbs.push(pad_non_top_limb);
Expand All @@ -236,12 +231,12 @@ impl<TargetField: PrimeField, BaseField: PrimeField>
Self::get_limbs_representations(&pad_to_kp_gap, self.get_optimization_type())?;

// Step 4: the result is self + pad + pad_to_kp - other
let mut limbs = Vec::new();
let mut limbs = Vec::with_capacity(self.limbs.len());
for (i, ((this_limb, other_limb), pad_to_kp_limb)) in self
.limbs
.iter()
.zip(other.limbs.iter())
.zip(pad_to_kp_limbs.iter())
.zip(&other.limbs)
.zip(&pad_to_kp_limbs)
.enumerate()
{
if i != 0 {
Expand Down Expand Up @@ -341,7 +336,7 @@ impl<TargetField: PrimeField, BaseField: PrimeField>
&cur_bits[cur_bits.len() - params.bits_per_limb..],
); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
limbs.push(BaseField::from_bigint(cur_mod_r).unwrap());
cur.divn(params.bits_per_limb as u32);
cur >>= params.bits_per_limb as u32;
}

// then we reserve, so that the limbs are ``big limb first''
Expand Down
25 changes: 11 additions & 14 deletions src/fields/nonnative/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ impl<TargetField: PrimeField, BaseField: PrimeField> Reducer<TargetField, BaseFi
let mut cur = BaseField::one().into_bigint();
for _ in 0..num_limb_in_a_group {
array.push(BaseField::from_bigint(cur).unwrap());
cur.muln(shift_per_limb as u32);
cur <<= shift_per_limb as u32;
}

array
Expand Down Expand Up @@ -280,16 +280,13 @@ impl<TargetField: PrimeField, BaseField: PrimeField> Reducer<TargetField, BaseFi
for (group_id, (left_total_limb, right_total_limb, num_limb_in_this_group)) in
groupped_limb_pairs.iter().enumerate()
{
let mut pad_limb_repr: <BaseField as PrimeField>::BigInt =
BaseField::one().into_bigint();

pad_limb_repr.muln(
(surfeit
+ (bits_per_limb - shift_per_limb)
+ shift_per_limb * num_limb_in_this_group
+ 1
+ 1) as u32,
);
let mut pad_limb_repr = BaseField::ONE.into_bigint();

pad_limb_repr <<= (surfeit
+ (bits_per_limb - shift_per_limb)
+ shift_per_limb * num_limb_in_this_group
+ 1
+ 1) as u32;
let pad_limb = BaseField::from_bigint(pad_limb_repr).unwrap();

let left_total_limb_value = left_total_limb.value().unwrap_or_default();
Expand All @@ -298,12 +295,12 @@ impl<TargetField: PrimeField, BaseField: PrimeField> Reducer<TargetField, BaseFi
let mut carry_value =
left_total_limb_value + carry_in_value + pad_limb - right_total_limb_value;

let mut carry_repr = carry_value.into_bigint();
carry_repr.divn((shift_per_limb * num_limb_in_this_group) as u32);
let carry_repr =
carry_value.into_bigint() >> (shift_per_limb * num_limb_in_this_group) as u32;

carry_value = BaseField::from_bigint(carry_repr).unwrap();

let carry = FpVar::<BaseField>::new_witness(cs.clone(), || Ok(carry_value))?;
let carry = FpVar::new_witness(cs.clone(), || Ok(carry_value))?;

accumulated_extra += limbs_to_bigint(bits_per_limb, &[pad_limb]);

Expand Down
Loading

0 comments on commit 1ff3a90

Please sign in to comment.