Skip to content

Commit

Permalink
Add eq.
Browse files Browse the repository at this point in the history
  • Loading branch information
thealmarty committed Jan 16, 2024
1 parent 83f5ac1 commit 6a7e00b
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 82 deletions.
17 changes: 11 additions & 6 deletions alu_u32/src/com/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@ pub struct Com32Cols<T> {
pub input_1: Word<T>,
pub input_2: Word<T>,

/// Boolean flags indicating which byte pair differs
pub byte_flag: [T; 3],

/// Bit decomposition of 256 + input_1 - input_2
pub bits: [T; 10],
/// When doing an equality test between two words, `x` and `y`, this holds the sum of
/// `(x_i - y_i)^2`, which is zero if and only if `x = y`.
pub diff: T,
/// The inverse of `diff`, or undefined if `diff = 0`.
pub diff_inv: T,
/// A boolean flag indicating whether `diff != 0`.
pub not_equal: T,

pub output: T,

pub multiplicity: T,
pub is_ne: T,
pub is_eq: T,
pub is_ne: T,
pub is_eq: T,
}

pub const NUM_COM_COLS: usize = size_of::<Com32Cols<u8>>();
Expand Down
102 changes: 77 additions & 25 deletions alu_u32/src/com/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use valida_cpu::MachineWithCpuChip;
use valida_machine::{
instructions, Chip, Instruction, Interaction, Operands, Word, MEMORY_CELL_BYTES,
};
use valida_opcodes::NE32;
use valida_opcodes::{EQ32, NE32};

use p3_air::VirtualPairCol;
use p3_field::PrimeField;
Expand All @@ -24,6 +24,8 @@ pub mod stark;
#[derive(Clone)]
pub enum Operation {
Ne32(Word<u8>, Word<u8>, Word<u8>), // (dst, src1, src2)
Eq32(Word<u8>, Word<u8>, Word<u8>), // (dst, src1, src2)
Eq32(Word<u8>, Word<u8>, Word<u8>), // (dst, src1, src2)
}

#[derive(Default)]
Expand Down Expand Up @@ -52,7 +54,20 @@ where
}

fn global_receives(&self, machine: &M) -> Vec<Interaction<M::F>> {
let opcode = VirtualPairCol::constant(M::F::from_canonical_u32(NE32));
let opcode = VirtualPairCol::new_main(
vec![
(COM_COL_MAP.is_ne, M::F::from_canonical_u32(NE32)),
(COM_COL_MAP.is_eq, M::F::from_canonical_u32(EQ32)),
],
M::F::zero(),
);
let opcode = VirtualPairCol::new_main(
vec![
(COM_COL_MAP.is_ne, M::F::from_canonical_u32(NE32)),
(COM_COL_MAP.is_eq, M::F::from_canonical_u32(EQ32)),
],
M::F::zero(),
);
let input_1 = COM_COL_MAP.input_1.0.map(VirtualPairCol::single_main);
let input_2 = COM_COL_MAP.input_2.0.map(VirtualPairCol::single_main);
let output = (0..MEMORY_CELL_BYTES - 1)
Expand All @@ -64,9 +79,12 @@ where
fields.extend(input_2);
fields.extend(output);

let is_real = VirtualPairCol::sum_main(vec![COM_COL_MAP.is_ne, COM_COL_MAP.is_eq]);

let receive = Interaction {
fields,
count: VirtualPairCol::single_main(COM_COL_MAP.multiplicity),
count: is_real,
count: is_real,
argument_index: machine.general_bus(),
};
vec![receive]
Expand All @@ -82,25 +100,11 @@ impl Com32Chip {
let cols: &mut Com32Cols<F> = unsafe { transmute(&mut row) };

match op {
Operation::Ne32(dst, src1, src2) => {
if let Some(n) = src1
.into_iter()
.zip(src2.into_iter())
.enumerate()
.find_map(|(n, (x, y))| if x == y { Some(n) } else { None })
{
let z = 256u16 + src1[n] as u16 - src2[n] as u16;
for i in 0..10 {
cols.bits[i] = F::from_canonical_u16(z >> i & 1);
}
if n < 3 {
cols.byte_flag[n] = F::one();
}
}
cols.input_1 = src1.transform(F::from_canonical_u8);
cols.input_2 = src2.transform(F::from_canonical_u8);
cols.output = F::from_canonical_u8(dst[3]);
cols.multiplicity = F::one();
Operation::Ne32(_, _, _) => {
cols.is_ne = F::one();
}
Operation::Eq32(_, _, _) => {
cols.is_eq = F::one();
}
}
row
Expand All @@ -112,7 +116,8 @@ pub trait MachineWithCom32Chip: MachineWithCpuChip {
fn com_u32_mut(&mut self) -> &mut Com32Chip;
}

instructions!(Ne32Instruction);
instructions!(Ne32Instruction, Eq32Instruction);
instructions!(Ne32Instruction, Eq32Instruction);

impl<M> Instruction<M> for Ne32Instruction
where
Expand All @@ -127,14 +132,18 @@ where
let mut imm: Option<Word<u8>> = None;
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
let write_addr = (state.cpu().fp as i32 + ops.a()) as u32;
let src1 = state.mem_mut().read(clk, read_addr_1, true, pc, opcode, 0, "");
let src1 = state
.mem_mut()
.read(clk, read_addr_1, true, pc, opcode, 0, "");
let src2 = if ops.is_imm() == 1 {
let c = (ops.c() as u32).into();
imm = Some(c);
c
} else {
let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32;
state.mem_mut().read(clk, read_addr_2, true, pc, opcode, 1, "")
state
.mem_mut()
.read(clk, read_addr_2, true, pc, opcode, 1, "")
};

let dst = if src1 != src2 {
Expand All @@ -148,6 +157,49 @@ where
.com_u32_mut()
.operations
.push(Operation::Ne32(dst, src1, src2));
state
.com_u32_mut()
.operations
.push(Operation::Eq32(dst, src1, src2));
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
}


impl<M> Instruction<M> for Eq32Instruction
where
M: MachineWithCom32Chip,
{
const OPCODE: u32 = EQ32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
let write_addr = (state.cpu().fp as i32 + ops.a()) as u32;
let src1 = state.mem_mut().read(clk, read_addr_1, true, pc, opcode, 0, "");
let src2 = if ops.is_imm() == 1 {
let c = (ops.c() as u32).into();
imm = Some(c);
c
} else {
let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32;
state.mem_mut().read(clk, read_addr_2, true, pc, opcode, 1, "")
};

let dst = if src1 == src2 {
Word::from(1)
} else {
Word::from(0)
};
state.mem_mut().write(clk, write_addr, dst, true);

state
.com_u32_mut()
.operations
.push(Operation::Eq32(dst, src1, src2));
state
.cpu_mut()
.push_bus_op(imm, opcode, ops);
Expand Down
66 changes: 20 additions & 46 deletions alu_u32/src/com/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,51 +22,25 @@ where
let main = builder.main();
let local: &Com32Cols<AB::Var> = main.row_slice(0).borrow();

let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512].map(AB::Expr::from_canonical_u32);

let bit_comp: AB::Expr = local
.bits
.into_iter()
.zip(base_2.iter().cloned())
.map(|(bit, base)| bit * base)
.sum();

// Check bit decomposition of z = 256 + input_1[n] - input_2[n], where
// n is the most significant byte that differs between inputs
for i in 0..3 {
builder
.when_ne(local.byte_flag[i], AB::Expr::one())
.assert_eq(local.input_1[i], local.input_2[i]);

builder.when(local.byte_flag[i]).assert_eq(
AB::Expr::from_canonical_u32(256) + local.input_1[i] - local.input_2[i],
bit_comp.clone(),
);

builder.assert_bool(local.byte_flag[i]);
}

// Check final byte (if no other byte flags were set)
let flag_sum = local.byte_flag[0] + local.byte_flag[1] + local.byte_flag[2];
builder.assert_bool(flag_sum.clone());
builder
.when_ne(local.multiplicity, AB::Expr::zero())
.when_ne(flag_sum, AB::Expr::one())
.assert_eq(
AB::Expr::from_canonical_u32(256) + local.input_1[3] - local.input_2[3],
bit_comp.clone(),
);

// Output constraints
builder.when(local.bits[8]).assert_zero(local.output);
builder
.when_ne(local.multiplicity, AB::Expr::zero())
.when_ne(local.bits[8], AB::Expr::one())
.assert_one(local.output);

// Check bit decomposition
for bit in local.bits.into_iter() {
builder.assert_bool(bit);
}
// Check if the first two operand values are equal, in case we're doing a conditional branch.
// (when is_imm == 1, the second read value is guaranteed to be an immediate value)
builder.assert_eq(
local.diff,
local
.input_1
.into_iter()
.zip(local.input_2)
.map(|(a, b)| (a - b) * (a - b))
.sum::<AB::Expr>(),
);
builder.assert_bool(local.not_equal);
builder.assert_eq(local.not_equal, local.diff * local.diff_inv);

builder.assert_bool(local.is_ne);
builder.assert_bool(local.is_eq);
builder.assert_bool(local.is_ne + local.is_eq);

builder.when(local.is_ne).assert_one(local.not_equal);
builder.when(local.is_eq).assert_zero(local.not_equal);
}
}
2 changes: 1 addition & 1 deletion alu_u32/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ extern crate alloc;

pub mod add;
pub mod bitwise;
pub mod com;
pub mod div;
pub mod lt;
pub mod com;
pub mod mul;
pub mod shift;
pub mod sub;
7 changes: 6 additions & 1 deletion basic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ use valida_alu_u32::{
And32Instruction, Bitwise32Chip, MachineWithBitwise32Chip, Or32Instruction,
Xor32Instruction,
},
com::{Com32Chip, Eq32Instruction, MachineWithCom32Chip, Ne32Instruction},
div::{Div32Chip, Div32Instruction, MachineWithDiv32Chip, SDiv32Instruction},
lt::{Lt32Chip, Lt32Instruction, MachineWithLt32Chip},
com::{Com32Chip, Ne32Instruction, MachineWithCom32Chip},
com::{Com32Chip, Ne32Instruction, Eq32Instruction, MachineWithCom32Chip},
mul::{
MachineWithMul32Chip, Mul32Chip, Mul32Instruction, Mulhs32Instruction, Mulhu32Instruction,
},
Expand Down Expand Up @@ -89,6 +90,10 @@ pub struct BasicMachine<F: PrimeField32 + TwoAdicField> {
lt32: Lt32Instruction,
#[instruction(com_u32)]
ne32: Ne32Instruction,
#[instruction(com_u32)]
eq32: Eq32Instruction,
#[instruction(com_u32)]
eq32: Eq32Instruction,
#[instruction(bitwise_u32)]
and32: And32Instruction,
#[instruction(bitwise_u32)]
Expand Down
8 changes: 5 additions & 3 deletions opcodes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ pub const AND32: u32 = 107;
pub const OR32: u32 = 108;
pub const XOR32: u32 = 109;
pub const NE32: u32 = 111;
pub const MULHU32 : u32 = 112; //TODO
pub const SRA32 : u32 = 113; //TODO
pub const MULHS32 : u32 =114; //TODO
pub const MULHU32: u32 = 112;
pub const SRA32: u32 = 113;
pub const MULHS32: u32 = 114;
pub const LTE32: u32 = 115; //TODO
pub const EQ32: u32 = 116;

/// NATIVE FIELD
pub const ADD: u32 = 200;
Expand Down

0 comments on commit 6a7e00b

Please sign in to comment.