diff --git a/o1vm/src/interpreters/riscv32im/column.rs b/o1vm/src/interpreters/riscv32im/column.rs index 871db4a546..b506a6128d 100644 --- a/o1vm/src/interpreters/riscv32im/column.rs +++ b/o1vm/src/interpreters/riscv32im/column.rs @@ -4,7 +4,7 @@ use super::{ Instruction::{self, IType, MType, RType, SBType, SType, SyscallType, UJType, UType}, RInstruction, SBInstruction, SInstruction, SyscallInstruction, UInstruction, UJInstruction, }, - INSTRUCTION_SET_SIZE, SCRATCH_SIZE, + INSTRUCTION_SET_SIZE, SCRATCH_SIZE, SCRATCH_SIZE_INVERSE, }; use kimchi::circuits::{ berkeley_columns::BerkeleyChallengeTerm, @@ -15,6 +15,7 @@ use strum::EnumCount; #[derive(Clone, Debug, PartialEq, Eq)] pub enum Column { ScratchState(usize), + ScratchStateInverse(usize), InstructionCounter, Selector(usize), } @@ -26,13 +27,17 @@ impl From for usize { assert!(i < SCRATCH_SIZE); i } - Column::InstructionCounter => SCRATCH_SIZE, + Column::ScratchStateInverse(i) => { + assert!(i < SCRATCH_SIZE_INVERSE); + SCRATCH_SIZE + i + } + Column::InstructionCounter => SCRATCH_SIZE + SCRATCH_SIZE_INVERSE, Column::Selector(s) => { assert!( s < INSTRUCTION_SET_SIZE, "There is only {INSTRUCTION_SET_SIZE}" ); - SCRATCH_SIZE + 1 + s + SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 1 + s } } } @@ -41,13 +46,21 @@ impl From for usize { impl From for usize { fn from(instr: Instruction) -> usize { match instr { - RType(rtype) => SCRATCH_SIZE + 1 + rtype as usize, - IType(itype) => SCRATCH_SIZE + 1 + RInstruction::COUNT + itype as usize, + RType(rtype) => SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 1 + rtype as usize, + IType(itype) => { + SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 1 + RInstruction::COUNT + itype as usize + } SType(stype) => { - SCRATCH_SIZE + 1 + RInstruction::COUNT + IInstruction::COUNT + stype as usize + SCRATCH_SIZE + + SCRATCH_SIZE_INVERSE + + 1 + + RInstruction::COUNT + + IInstruction::COUNT + + stype as usize } SBType(sbtype) => { SCRATCH_SIZE + + SCRATCH_SIZE_INVERSE + 1 + RInstruction::COUNT + IInstruction::COUNT @@ -56,6 +69,7 @@ impl From for usize { } UType(utype) => { SCRATCH_SIZE + + SCRATCH_SIZE_INVERSE + 1 + RInstruction::COUNT + IInstruction::COUNT @@ -65,6 +79,7 @@ impl From for usize { } UJType(ujtype) => { SCRATCH_SIZE + + SCRATCH_SIZE_INVERSE + 1 + RInstruction::COUNT + IInstruction::COUNT @@ -75,6 +90,7 @@ impl From for usize { } SyscallType(syscalltype) => { SCRATCH_SIZE + + SCRATCH_SIZE_INVERSE + 1 + RInstruction::COUNT + IInstruction::COUNT @@ -86,6 +102,7 @@ impl From for usize { } MType(mtype) => { SCRATCH_SIZE + + SCRATCH_SIZE_INVERSE + 1 + RInstruction::COUNT + IInstruction::COUNT diff --git a/o1vm/src/interpreters/riscv32im/constraints.rs b/o1vm/src/interpreters/riscv32im/constraints.rs index e45099490e..6d467ebab5 100644 --- a/o1vm/src/interpreters/riscv32im/constraints.rs +++ b/o1vm/src/interpreters/riscv32im/constraints.rs @@ -15,6 +15,7 @@ use kimchi::circuits::{ pub struct Env { pub scratch_state_idx: usize, + pub scratch_state_idx_inverse: usize, pub lookups: Vec>>, pub constraints: Vec>, pub selector: Option>, @@ -24,6 +25,7 @@ impl Default for Env { fn default() -> Self { Self { scratch_state_idx: 0, + scratch_state_idx_inverse: 0, constraints: Vec::new(), lookups: Vec::new(), selector: None, @@ -49,6 +51,12 @@ impl InterpreterEnv for Env { Column::ScratchState(scratch_idx) } + fn alloc_scratch_inverse(&mut self) -> Self::Position { + let scratch_idx = self.scratch_state_idx_inverse; + self.scratch_state_idx_inverse += 1; + Column::ScratchStateInverse(scratch_idx) + } + type Variable = E; fn variable(&self, column: Self::Position) -> Self::Variable { @@ -206,8 +214,8 @@ impl InterpreterEnv for Env { unsafe { self.test_zero(x, pos) } }; let x_inv_or_zero = { - let pos = self.alloc_scratch(); - unsafe { self.inverse_or_zero(x, pos) } + let pos = self.alloc_scratch_inverse(); + self.variable(pos) }; // If x = 0, then res = 1 and x_inv_or_zero = 0 // If x <> 0, then res = 0 and x_inv_or_zero = x^(-1) @@ -216,14 +224,6 @@ impl InterpreterEnv for Env { res } - unsafe fn inverse_or_zero( - &mut self, - _x: &Self::Variable, - position: Self::Position, - ) -> Self::Variable { - self.variable(position) - } - fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable { self.is_zero(&(x.clone() - y.clone())) } diff --git a/o1vm/src/interpreters/riscv32im/interpreter.rs b/o1vm/src/interpreters/riscv32im/interpreter.rs index 91f100e911..fc7776086b 100644 --- a/o1vm/src/interpreters/riscv32im/interpreter.rs +++ b/o1vm/src/interpreters/riscv32im/interpreter.rs @@ -561,6 +561,8 @@ pub trait InterpreterEnv { /// [crate::interpreters::riscv32im::SCRATCH_SIZE] fn alloc_scratch(&mut self) -> Self::Position; + fn alloc_scratch_inverse(&mut self) -> Self::Position; + type Variable: Clone + std::ops::Add + std::ops::Sub @@ -1061,21 +1063,6 @@ pub trait InterpreterEnv { /// `x`. unsafe fn test_zero(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable; - /// Returns `x^(-1)`, or `0` if `x` is `0`, storing the result in `position`. - /// - /// # Safety - /// - /// There are no constraints on the returned value; callers must assert the relationship with - /// `x`. - /// - /// The value returned may be a placeholder; callers should be careful not to depend directly - /// on the value stored in the variable. - unsafe fn inverse_or_zero( - &mut self, - x: &Self::Variable, - position: Self::Position, - ) -> Self::Variable; - fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable; /// Returns 1 if `x` is equal to `y`, or 0 otherwise, storing the result in `position`. diff --git a/o1vm/src/interpreters/riscv32im/mod.rs b/o1vm/src/interpreters/riscv32im/mod.rs index ed6dd1b38a..dade8fad6f 100644 --- a/o1vm/src/interpreters/riscv32im/mod.rs +++ b/o1vm/src/interpreters/riscv32im/mod.rs @@ -1,5 +1,6 @@ /// The minimal number of columns required for the VM pub const SCRATCH_SIZE: usize = 39; +pub const SCRATCH_SIZE_INVERSE: usize = 1; /// Number of instructions in the ISA pub const INSTRUCTION_SET_SIZE: usize = 48; diff --git a/o1vm/src/interpreters/riscv32im/tests.rs b/o1vm/src/interpreters/riscv32im/tests.rs index 08c1634c5b..3fada10b6b 100644 --- a/o1vm/src/interpreters/riscv32im/tests.rs +++ b/o1vm/src/interpreters/riscv32im/tests.rs @@ -1,4 +1,7 @@ -use super::{registers::Registers, witness::Env, INSTRUCTION_SET_SIZE, PAGE_SIZE, SCRATCH_SIZE}; +use super::{ + registers::Registers, witness::Env, INSTRUCTION_SET_SIZE, PAGE_SIZE, SCRATCH_SIZE, + SCRATCH_SIZE_INVERSE, +}; use crate::interpreters::riscv32im::{ constraints, interpreter::{ @@ -49,6 +52,8 @@ pub fn dummy_env() -> Env { registers_write_index: Registers::default(), scratch_state_idx: 0, scratch_state: [Fp::zero(); SCRATCH_SIZE], + scratch_state_inverse_idx: 0, + scratch_state_inverse: [Fp::zero(); SCRATCH_SIZE_INVERSE], halt: false, selector: INSTRUCTION_SET_SIZE, } diff --git a/o1vm/src/interpreters/riscv32im/witness.rs b/o1vm/src/interpreters/riscv32im/witness.rs index 256de81687..228a16779a 100644 --- a/o1vm/src/interpreters/riscv32im/witness.rs +++ b/o1vm/src/interpreters/riscv32im/witness.rs @@ -7,7 +7,7 @@ use super::{ SInstruction, SyscallInstruction, UInstruction, UJInstruction, }, registers::Registers, - INSTRUCTION_SET_SIZE, SCRATCH_SIZE, + INSTRUCTION_SET_SIZE, SCRATCH_SIZE, SCRATCH_SIZE_INVERSE, }; use crate::{ cannon::{State, PAGE_ADDRESS_MASK, PAGE_ADDRESS_SIZE, PAGE_SIZE}, @@ -46,6 +46,8 @@ pub struct Env { pub registers_write_index: Registers, pub scratch_state_idx: usize, pub scratch_state: [Fp; SCRATCH_SIZE], + pub scratch_state_inverse_idx: usize, + pub scratch_state_inverse: [Fp; SCRATCH_SIZE_INVERSE], pub halt: bool, pub selector: usize, } @@ -63,6 +65,12 @@ impl InterpreterEnv for Env { Column::ScratchState(scratch_idx) } + fn alloc_scratch_inverse(&mut self) -> Self::Position { + let scratch_inverse_idx = self.scratch_state_inverse_idx; + self.scratch_state_inverse_idx += 1; + Column::ScratchStateInverse(scratch_inverse_idx) + } + type Variable = u64; fn variable(&self, _column: Self::Position) -> Self::Variable { @@ -277,44 +285,39 @@ impl InterpreterEnv for Env { res } - unsafe fn inverse_or_zero( - &mut self, - x: &Self::Variable, - position: Self::Position, - ) -> Self::Variable { - if *x == 0 { - self.write_column(position, 0); - 0 - } else { - self.write_field_column(position, Fp::from(*x).inverse().unwrap()); - 1 // Placeholder value - } - } - fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable { // write the result let pos = self.alloc_scratch(); let res = if *x == 0 { 1 } else { 0 }; self.write_column(pos, res); // write the non deterministic advice inv_or_zero - let pos = self.alloc_scratch(); - let inv_or_zero = if *x == 0 { - Fp::zero() + let pos = self.alloc_scratch_inverse(); + if *x == 0 { + self.write_field_column(pos, Fp::zero()); } else { - Fp::inverse(&Fp::from(*x)).unwrap() + self.write_field_column(pos, Fp::from(*x)); }; - self.write_field_column(pos, inv_or_zero); // return the result res } fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable { - // To avoid subtraction overflow in the witness interpreter for u32 - if x > y { - self.is_zero(&(*x - *y)) + // We replicate is_zero(x-y), but working on field elt, + // to avoid subtraction overflow in the witness interpreter for u32 + let to_zero_test = Fp::from(*x) - Fp::from(*y); + let res = { + let pos = self.alloc_scratch(); + let is_zero: u64 = if to_zero_test == Fp::zero() { 1 } else { 0 }; + self.write_column(pos, is_zero); + is_zero + }; + let pos = self.alloc_scratch_inverse(); + if to_zero_test == Fp::zero() { + self.write_field_column(pos, Fp::zero()); } else { - self.is_zero(&(*y - *x)) - } + self.write_field_column(pos, to_zero_test); + }; + res } unsafe fn test_less_than( @@ -684,6 +687,8 @@ impl Env { registers_write_index: Registers::default(), scratch_state_idx: 0, scratch_state: fresh_scratch_state(), + scratch_state_inverse_idx: 0, + scratch_state_inverse: fresh_scratch_state(), halt: state.exited, selector, } @@ -842,6 +847,7 @@ impl Env { /// Execute a single step in the RISCV32i program pub fn step(&mut self) -> Instruction { self.reset_scratch_state(); + self.reset_scratch_state_inverse(); let (opcode, _instruction) = self.decode_instruction(); interpreter::interpret_instruction(self, opcode); @@ -865,6 +871,11 @@ impl Env { self.selector = INSTRUCTION_SET_SIZE; } + pub fn reset_scratch_state_inverse(&mut self) { + self.scratch_state_inverse_idx = 0; + self.scratch_state_inverse = fresh_scratch_state(); + } + pub fn write_column(&mut self, column: Column, value: u64) { self.write_field_column(column, value.into()) } @@ -872,6 +883,7 @@ impl Env { pub fn write_field_column(&mut self, column: Column, value: Fp) { match column { Column::ScratchState(idx) => self.scratch_state[idx] = value, + Column::ScratchStateInverse(idx) => self.scratch_state_inverse[idx] = value, Column::InstructionCounter => panic!("Cannot overwrite the column {:?}", column), Column::Selector(s) => self.selector = s, }