Skip to content

Commit

Permalink
Merge pull request #2939 from o1-labs/sai/batch-inverse-riscv
Browse files Browse the repository at this point in the history
riscv batch inverse
  • Loading branch information
svv232 authored Jan 8, 2025
2 parents 39a3b5d + 6f678de commit b947c28
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 57 deletions.
29 changes: 23 additions & 6 deletions o1vm/src/interpreters/riscv32im/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::{
Instruction::{self, IType, MType, RType, SBType, SType, SyscallType, UJType, UType},
RInstruction, SBInstruction, SInstruction, SyscallInstruction, UInstruction, UJInstruction,
},
INSTRUCTION_SET_SIZE, SCRATCH_SIZE,
INSTRUCTION_SET_SIZE, SCRATCH_SIZE, SCRATCH_SIZE_INVERSE,
};
use kimchi::circuits::{
berkeley_columns::BerkeleyChallengeTerm,
Expand All @@ -15,6 +15,7 @@ use strum::EnumCount;
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Column {
ScratchState(usize),
ScratchStateInverse(usize),
InstructionCounter,
Selector(usize),
}
Expand All @@ -26,13 +27,17 @@ impl From<Column> for usize {
assert!(i < SCRATCH_SIZE);
i
}
Column::InstructionCounter => SCRATCH_SIZE,
Column::ScratchStateInverse(i) => {
assert!(i < SCRATCH_SIZE_INVERSE);
SCRATCH_SIZE + i
}
Column::InstructionCounter => SCRATCH_SIZE + SCRATCH_SIZE_INVERSE,
Column::Selector(s) => {
assert!(
s < INSTRUCTION_SET_SIZE,
"There is only {INSTRUCTION_SET_SIZE}"
);
SCRATCH_SIZE + 1 + s
SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 1 + s
}
}
}
Expand All @@ -41,13 +46,21 @@ impl From<Column> for usize {
impl From<Instruction> for usize {
fn from(instr: Instruction) -> usize {
match instr {
RType(rtype) => SCRATCH_SIZE + 1 + rtype as usize,
IType(itype) => SCRATCH_SIZE + 1 + RInstruction::COUNT + itype as usize,
RType(rtype) => SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 1 + rtype as usize,
IType(itype) => {
SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 1 + RInstruction::COUNT + itype as usize
}
SType(stype) => {
SCRATCH_SIZE + 1 + RInstruction::COUNT + IInstruction::COUNT + stype as usize
SCRATCH_SIZE
+ SCRATCH_SIZE_INVERSE
+ 1
+ RInstruction::COUNT
+ IInstruction::COUNT
+ stype as usize
}
SBType(sbtype) => {
SCRATCH_SIZE
+ SCRATCH_SIZE_INVERSE
+ 1
+ RInstruction::COUNT
+ IInstruction::COUNT
Expand All @@ -56,6 +69,7 @@ impl From<Instruction> for usize {
}
UType(utype) => {
SCRATCH_SIZE
+ SCRATCH_SIZE_INVERSE
+ 1
+ RInstruction::COUNT
+ IInstruction::COUNT
Expand All @@ -65,6 +79,7 @@ impl From<Instruction> for usize {
}
UJType(ujtype) => {
SCRATCH_SIZE
+ SCRATCH_SIZE_INVERSE
+ 1
+ RInstruction::COUNT
+ IInstruction::COUNT
Expand All @@ -75,6 +90,7 @@ impl From<Instruction> for usize {
}
SyscallType(syscalltype) => {
SCRATCH_SIZE
+ SCRATCH_SIZE_INVERSE
+ 1
+ RInstruction::COUNT
+ IInstruction::COUNT
Expand All @@ -86,6 +102,7 @@ impl From<Instruction> for usize {
}
MType(mtype) => {
SCRATCH_SIZE
+ SCRATCH_SIZE_INVERSE
+ 1
+ RInstruction::COUNT
+ IInstruction::COUNT
Expand Down
20 changes: 10 additions & 10 deletions o1vm/src/interpreters/riscv32im/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use kimchi::circuits::{

pub struct Env<F: Field> {
pub scratch_state_idx: usize,
pub scratch_state_idx_inverse: usize,
pub lookups: Vec<Lookup<E<F>>>,
pub constraints: Vec<E<F>>,
pub selector: Option<E<F>>,
Expand All @@ -24,6 +25,7 @@ impl<Fp: Field> Default for Env<Fp> {
fn default() -> Self {
Self {
scratch_state_idx: 0,
scratch_state_idx_inverse: 0,
constraints: Vec::new(),
lookups: Vec::new(),
selector: None,
Expand All @@ -49,6 +51,12 @@ impl<Fp: Field> InterpreterEnv for Env<Fp> {
Column::ScratchState(scratch_idx)
}

fn alloc_scratch_inverse(&mut self) -> Self::Position {
let scratch_idx = self.scratch_state_idx_inverse;
self.scratch_state_idx_inverse += 1;
Column::ScratchStateInverse(scratch_idx)
}

type Variable = E<Fp>;

fn variable(&self, column: Self::Position) -> Self::Variable {
Expand Down Expand Up @@ -206,8 +214,8 @@ impl<Fp: Field> InterpreterEnv for Env<Fp> {
unsafe { self.test_zero(x, pos) }
};
let x_inv_or_zero = {
let pos = self.alloc_scratch();
unsafe { self.inverse_or_zero(x, pos) }
let pos = self.alloc_scratch_inverse();
self.variable(pos)
};
// If x = 0, then res = 1 and x_inv_or_zero = 0
// If x <> 0, then res = 0 and x_inv_or_zero = x^(-1)
Expand All @@ -216,14 +224,6 @@ impl<Fp: Field> InterpreterEnv for Env<Fp> {
res
}

unsafe fn inverse_or_zero(
&mut self,
_x: &Self::Variable,
position: Self::Position,
) -> Self::Variable {
self.variable(position)
}

fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable {
self.is_zero(&(x.clone() - y.clone()))
}
Expand Down
17 changes: 2 additions & 15 deletions o1vm/src/interpreters/riscv32im/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,8 @@ pub trait InterpreterEnv {
/// [crate::interpreters::riscv32im::SCRATCH_SIZE]
fn alloc_scratch(&mut self) -> Self::Position;

fn alloc_scratch_inverse(&mut self) -> Self::Position;

type Variable: Clone
+ std::ops::Add<Self::Variable, Output = Self::Variable>
+ std::ops::Sub<Self::Variable, Output = Self::Variable>
Expand Down Expand Up @@ -1061,21 +1063,6 @@ pub trait InterpreterEnv {
/// `x`.
unsafe fn test_zero(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable;

/// Returns `x^(-1)`, or `0` if `x` is `0`, storing the result in `position`.
///
/// # Safety
///
/// There are no constraints on the returned value; callers must assert the relationship with
/// `x`.
///
/// The value returned may be a placeholder; callers should be careful not to depend directly
/// on the value stored in the variable.
unsafe fn inverse_or_zero(
&mut self,
x: &Self::Variable,
position: Self::Position,
) -> Self::Variable;

fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable;

/// Returns 1 if `x` is equal to `y`, or 0 otherwise, storing the result in `position`.
Expand Down
1 change: 1 addition & 0 deletions o1vm/src/interpreters/riscv32im/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/// The minimal number of columns required for the VM
pub const SCRATCH_SIZE: usize = 39;
pub const SCRATCH_SIZE_INVERSE: usize = 1;

/// Number of instructions in the ISA
pub const INSTRUCTION_SET_SIZE: usize = 48;
Expand Down
7 changes: 6 additions & 1 deletion o1vm/src/interpreters/riscv32im/tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use super::{registers::Registers, witness::Env, INSTRUCTION_SET_SIZE, PAGE_SIZE, SCRATCH_SIZE};
use super::{
registers::Registers, witness::Env, INSTRUCTION_SET_SIZE, PAGE_SIZE, SCRATCH_SIZE,
SCRATCH_SIZE_INVERSE,
};
use crate::interpreters::riscv32im::{
constraints,
interpreter::{
Expand Down Expand Up @@ -49,6 +52,8 @@ pub fn dummy_env() -> Env<Fp> {
registers_write_index: Registers::default(),
scratch_state_idx: 0,
scratch_state: [Fp::zero(); SCRATCH_SIZE],
scratch_state_inverse_idx: 0,
scratch_state_inverse: [Fp::zero(); SCRATCH_SIZE_INVERSE],
halt: false,
selector: INSTRUCTION_SET_SIZE,
}
Expand Down
62 changes: 37 additions & 25 deletions o1vm/src/interpreters/riscv32im/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::{
SInstruction, SyscallInstruction, UInstruction, UJInstruction,
},
registers::Registers,
INSTRUCTION_SET_SIZE, SCRATCH_SIZE,
INSTRUCTION_SET_SIZE, SCRATCH_SIZE, SCRATCH_SIZE_INVERSE,
};
use crate::{
cannon::{State, PAGE_ADDRESS_MASK, PAGE_ADDRESS_SIZE, PAGE_SIZE},
Expand Down Expand Up @@ -46,6 +46,8 @@ pub struct Env<Fp> {
pub registers_write_index: Registers<u64>,
pub scratch_state_idx: usize,
pub scratch_state: [Fp; SCRATCH_SIZE],
pub scratch_state_inverse_idx: usize,
pub scratch_state_inverse: [Fp; SCRATCH_SIZE_INVERSE],
pub halt: bool,
pub selector: usize,
}
Expand All @@ -63,6 +65,12 @@ impl<Fp: Field> InterpreterEnv for Env<Fp> {
Column::ScratchState(scratch_idx)
}

fn alloc_scratch_inverse(&mut self) -> Self::Position {
let scratch_inverse_idx = self.scratch_state_inverse_idx;
self.scratch_state_inverse_idx += 1;
Column::ScratchStateInverse(scratch_inverse_idx)
}

type Variable = u64;

fn variable(&self, _column: Self::Position) -> Self::Variable {
Expand Down Expand Up @@ -277,44 +285,39 @@ impl<Fp: Field> InterpreterEnv for Env<Fp> {
res
}

unsafe fn inverse_or_zero(
&mut self,
x: &Self::Variable,
position: Self::Position,
) -> Self::Variable {
if *x == 0 {
self.write_column(position, 0);
0
} else {
self.write_field_column(position, Fp::from(*x).inverse().unwrap());
1 // Placeholder value
}
}

fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable {
// write the result
let pos = self.alloc_scratch();
let res = if *x == 0 { 1 } else { 0 };
self.write_column(pos, res);
// write the non deterministic advice inv_or_zero
let pos = self.alloc_scratch();
let inv_or_zero = if *x == 0 {
Fp::zero()
let pos = self.alloc_scratch_inverse();
if *x == 0 {
self.write_field_column(pos, Fp::zero());
} else {
Fp::inverse(&Fp::from(*x)).unwrap()
self.write_field_column(pos, Fp::from(*x));
};
self.write_field_column(pos, inv_or_zero);
// return the result
res
}

fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable {
// To avoid subtraction overflow in the witness interpreter for u32
if x > y {
self.is_zero(&(*x - *y))
// We replicate is_zero(x-y), but working on field elt,
// to avoid subtraction overflow in the witness interpreter for u32
let to_zero_test = Fp::from(*x) - Fp::from(*y);
let res = {
let pos = self.alloc_scratch();
let is_zero: u64 = if to_zero_test == Fp::zero() { 1 } else { 0 };
self.write_column(pos, is_zero);
is_zero
};
let pos = self.alloc_scratch_inverse();
if to_zero_test == Fp::zero() {
self.write_field_column(pos, Fp::zero());
} else {
self.is_zero(&(*y - *x))
}
self.write_field_column(pos, to_zero_test);
};
res
}

unsafe fn test_less_than(
Expand Down Expand Up @@ -684,6 +687,8 @@ impl<Fp: Field> Env<Fp> {
registers_write_index: Registers::default(),
scratch_state_idx: 0,
scratch_state: fresh_scratch_state(),
scratch_state_inverse_idx: 0,
scratch_state_inverse: fresh_scratch_state(),
halt: state.exited,
selector,
}
Expand Down Expand Up @@ -842,6 +847,7 @@ impl<Fp: Field> Env<Fp> {
/// Execute a single step in the RISCV32i program
pub fn step(&mut self) -> Instruction {
self.reset_scratch_state();
self.reset_scratch_state_inverse();
let (opcode, _instruction) = self.decode_instruction();

interpreter::interpret_instruction(self, opcode);
Expand All @@ -865,13 +871,19 @@ impl<Fp: Field> Env<Fp> {
self.selector = INSTRUCTION_SET_SIZE;
}

pub fn reset_scratch_state_inverse(&mut self) {
self.scratch_state_inverse_idx = 0;
self.scratch_state_inverse = fresh_scratch_state();
}

pub fn write_column(&mut self, column: Column, value: u64) {
self.write_field_column(column, value.into())
}

pub fn write_field_column(&mut self, column: Column, value: Fp) {
match column {
Column::ScratchState(idx) => self.scratch_state[idx] = value,
Column::ScratchStateInverse(idx) => self.scratch_state_inverse[idx] = value,
Column::InstructionCounter => panic!("Cannot overwrite the column {:?}", column),
Column::Selector(s) => self.selector = s,
}
Expand Down

0 comments on commit b947c28

Please sign in to comment.