Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for less-than instructions: LT32Chip STARK constraints and handling instructions with second operand immediate #166

Merged
merged 12 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion alu_u32/src/lt/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,26 @@ pub struct Lt32Cols<T> {
pub byte_flag: [T; 4],

/// Bit decomposition of 256 + input_1 - input_2
pub bits: [T; 10],
pub bits: [T; 9],

pub output: T,

pub multiplicity: T,

pub is_lt: T,
pub is_lte: T,
pub is_slt: T,
pub is_sle: T,

// inverse of input_1[i] - input_2[i] where i is the first byte that differs
pub diff_inv: T,

// bit decomposition of top bytes for input_1 and input_2
pub top_bits_1: [T; 8],
pub top_bits_2: [T; 8],

// boolean flag for whether the sign of the two inputs is different
pub different_signs: T,
}

pub const NUM_LT_COLS: usize = size_of::<Lt32Cols<u8>>();
Expand Down
238 changes: 97 additions & 141 deletions alu_u32/src/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ where
vec![
(LT_COL_MAP.is_lt, SC::Val::from_canonical_u32(LT32)),
(LT_COL_MAP.is_lte, SC::Val::from_canonical_u32(LTE32)),
(LT_COL_MAP.is_slt, SC::Val::from_canonical_u32(SLT32)),
(LT_COL_MAP.is_sle, SC::Val::from_canonical_u32(SLE32)),
],
SC::Val::zero(),
);
Expand Down Expand Up @@ -94,31 +96,38 @@ impl Lt32Chip {
match op {
Operation::Lt32(a, b, c) => {
cols.is_lt = F::one();
self.set_cols(cols, a, b, c);
self.set_cols(cols, false, a, b, c);
}
Operation::Lte32(a, b, c) => {
cols.is_lte = F::one();
self.set_cols(cols, a, b, c);
self.set_cols(cols, false, a, b, c);
}
Operation::Slt32(a, b, c) => {
// TODO: this is just a placeholder
cols.is_lt = F::one();
self.set_cols(cols, a, b, c);
cols.is_slt = F::one();
self.set_cols(cols, true, a, b, c);
}
Operation::Sle32(a, b, c) => {
// TODO: this is just a placeholder
cols.is_lte = F::one();
self.set_cols(cols, a, b, c);
cols.is_sle = F::one();
self.set_cols(cols, true, a, b, c);
}
}
row
}

fn set_cols<F>(&self, cols: &mut Lt32Cols<F>, a: &Word<u8>, b: &Word<u8>, c: &Word<u8>)
where
fn set_cols<F>(
&self,
cols: &mut Lt32Cols<F>,
is_signed: bool,
a: &Word<u8>,
b: &Word<u8>,
c: &Word<u8>,
) where
F: PrimeField,
{
// Set the input columns
debug_assert_eq!(a.0.len(), 4);
debug_assert_eq!(b.0.len(), 4);
debug_assert_eq!(c.0.len(), 4);
cols.input_1 = b.transform(F::from_canonical_u8);
cols.input_2 = c.transform(F::from_canonical_u8);
cols.output = F::from_canonical_u8(a[3]);
Expand All @@ -127,51 +136,56 @@ impl Lt32Chip {
.into_iter()
.zip(c.into_iter())
.enumerate()
.find_map(|(n, (x, y))| if x == y { Some(n) } else { None })
.find_map(|(n, (x, y))| if x == y { None } else { Some(n) })
{
let z = 256u16 + b[n] as u16 - c[n] as u16;
for i in 0..10 {
for i in 0..9 {
cols.bits[i] = F::from_canonical_u16(z >> i & 1);
}
if n < 4 {
cols.byte_flag[n] = F::one();
}
cols.byte_flag[n] = F::one();
// b[n] != c[n] always here, so the difference is never zero.
cols.diff_inv = (cols.input_1[n] - cols.input_2[n]).inverse();
}
// compute (little-endian) bit decomposition of the top bytes
for i in 0..8 {
cols.top_bits_1[i] = F::from_canonical_u8(b[0] >> i & 1);
cols.top_bits_2[i] = F::from_canonical_u8(c[0] >> i & 1);
}
// check if sign bits agree and set different_signs accordingly
cols.different_signs = if is_signed {
if cols.top_bits_1[7] != cols.top_bits_2[7] {
F::one()
} else {
F::zero()
}
} else {
F::zero()
};

cols.multiplicity = F::one();
}
}

pub trait MachineWithLt32Chip<F: Field>: MachineWithCpuChip<F> {
fn lt_u32(&self) -> &Lt32Chip;
fn lt_u32_mut(&mut self) -> &mut Lt32Chip;
}

instructions!(
Lt32Instruction,
Lte32Instruction,
Slt32Instruction,
Sle32Instruction
);

impl<M, F> Instruction<M, F> for Lt32Instruction
where
M: MachineWithLt32Chip<F>,
F: Field,
{
const OPCODE: u32 = LT32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M, F>>::OPCODE;
fn execute_with_closure<M, E, F>(
state: &mut M,
ops: Operands<i32>,
opcode: u32,
comp: F,
) -> (Word<u8>, Word<u8>, Word<u8>)
where
M: MachineWithLt32Chip<E>,
E: Field,
F: Fn(Word<u8>, Word<u8>) -> bool,
{
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
let write_addr = (state.cpu().fp as i32 + ops.a()) as u32;
let src1: Word<u8> = if ops.d() == 1 {
let b = (ops.b() as u32).into();
imm = Some(b);
b
} else {
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
state
.mem_mut()
.read(clk, read_addr_1, true, pc, opcode, 0, "")
Expand All @@ -187,18 +201,49 @@ where
.read(clk, read_addr_2, true, pc, opcode, 1, "")
};

let dst = if src1 < src2 {
let dst = if comp(src1, src2) {
Word::from(1)
} else {
Word::from(0)
};
state.mem_mut().write(clk, write_addr, dst, true);

if ops.d() == 1 {
state.cpu_mut().push_left_imm_bus_op(imm, opcode, ops)
} else {
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
(dst, src1, src2)
}
}

pub trait MachineWithLt32Chip<F: Field>: MachineWithCpuChip<F> {
fn lt_u32(&self) -> &Lt32Chip;
fn lt_u32_mut(&mut self) -> &mut Lt32Chip;
}

instructions!(
Lt32Instruction,
Lte32Instruction,
Slt32Instruction,
Sle32Instruction
);

impl<M, F> Instruction<M, F> for Lt32Instruction
where
M: MachineWithLt32Chip<F>,
F: Field,
{
const OPCODE: u32 = LT32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M, F>>::OPCODE;
let comp = |a, b| a < b;
let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);
state
.lt_u32_mut()
.operations
.push(Operation::Lt32(dst, src1, src2));
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
}

Expand All @@ -211,43 +256,12 @@ where

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
let write_addr = (state.cpu().fp as i32 + ops.a()) as u32;
let src1: Word<u8> = if ops.d() == 1 {
let b = (ops.b() as u32).into();
imm = Some(b);
b
} else {
state
.mem_mut()
.read(clk, read_addr_1, true, pc, opcode, 0, "")
};
let src2: Word<u8> = if ops.is_imm() == 1 {
let c = (ops.c() as u32).into();
imm = Some(c);
c
} else {
let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32;
state
.mem_mut()
.read(clk, read_addr_2, true, pc, opcode, 1, "")
};

let dst = if src1 <= src2 {
Word::from(1)
} else {
Word::from(0)
};
state.mem_mut().write(clk, write_addr, dst, true);

let comp = |a, b| a <= b;
let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);
state
.lt_u32_mut()
.operations
.push(Operation::Lte32(dst, src1, src2));
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
}

Expand All @@ -260,45 +274,16 @@ where

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
let write_addr = (state.cpu().fp as i32 + ops.a()) as u32;
let src1: Word<u8> = if ops.d() == 1 {
let b = (ops.b() as u32).into();
imm = Some(b);
b
} else {
state
.mem_mut()
.read(clk, read_addr_1, true, pc, opcode, 0, "")
};
let src2: Word<u8> = if ops.is_imm() == 1 {
let c = (ops.c() as u32).into();
imm = Some(c);
c
} else {
let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32;
state
.mem_mut()
.read(clk, read_addr_2, true, pc, opcode, 1, "")
let comp = |a: Word<u8>, b: Word<u8>| {
let a_i: i32 = a.into();
let b_i: i32 = b.into();
a_i < b_i
};

let src1_i: i32 = src1.into();
let src2_i: i32 = src2.into();
let dst = if src1_i < src2_i {
Word::from(1)
} else {
Word::from(0)
};
state.mem_mut().write(clk, write_addr, dst, true);

let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);
state
.lt_u32_mut()
.operations
.push(Operation::Slt32(dst, src1, src2));
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
}

Expand All @@ -311,44 +296,15 @@ where

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
let write_addr = (state.cpu().fp as i32 + ops.a()) as u32;
let src1: Word<u8> = if ops.d() == 1 {
let b = (ops.b() as u32).into();
imm = Some(b);
b
} else {
state
.mem_mut()
.read(clk, read_addr_1, true, pc, opcode, 0, "")
};
let src2: Word<u8> = if ops.is_imm() == 1 {
let c = (ops.c() as u32).into();
imm = Some(c);
c
} else {
let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32;
state
.mem_mut()
.read(clk, read_addr_2, true, pc, opcode, 1, "")
let comp = |a: Word<u8>, b: Word<u8>| {
let a_i: i32 = a.into();
let b_i: i32 = b.into();
a_i <= b_i
};

let src1_i: i32 = src1.into();
let src2_i: i32 = src2.into();
let dst = if src1_i <= src2_i {
Word::from(1)
} else {
Word::from(0)
};
state.mem_mut().write(clk, write_addr, dst, true);

let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);
state
.lt_u32_mut()
.operations
.push(Operation::Sle32(dst, src1, src2));
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
}
Loading
Loading