Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduction count #140

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/lair/bytecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ pub struct Func<F> {
pub(crate) input_size: usize,
pub(crate) output_size: usize,
pub(crate) body: Block<F>,
pub(crate) rc: usize,
}

impl<F> Func<F> {
Expand Down
21 changes: 19 additions & 2 deletions src/lair/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,25 @@ impl<'a, F: PrimeField32> Shard<'a, F> {
self.queries.expect("Missing query record reference")
}

pub fn get_func_range(&self, func_index: usize) -> Range<usize> {
let num_func_queries = self.queries().func_queries[func_index].len();
pub fn get_func_range_rc(&self, func: &Func<F>, rc_index: usize) -> Range<usize> {
let num_func_queries = self.queries().func_queries[func.index].len();
let shard_idx = self.index as usize;
let shard_chunk = self.shard_config.max_shard_size as usize * func.rc;
let start = shard_idx * shard_chunk;
let end = ((shard_idx + 1) * shard_chunk).min(num_func_queries);
let len = (start..end).len();
if len % func.rc == 0 {
let chunk = len / func.rc;
start + chunk * rc_index..start + chunk * (rc_index + 1)
} else {
let chunk = (len / func.rc) + 1;
let end = start + chunk * (rc_index + 1);
start + chunk * rc_index..end.min(num_func_queries)
}
}

pub fn get_func_range(&self, func: &Func<F>) -> Range<usize> {
let num_func_queries = self.queries().func_queries[func.index].len();
let shard_idx = self.index as usize;
let max_shard_size = self.shard_config.max_shard_size as usize;
shard_idx * max_shard_size..((shard_idx + 1) * max_shard_size).min(num_func_queries)
Expand Down
3 changes: 2 additions & 1 deletion src/lair/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,10 @@ pub struct CasesE<K, F> {
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FuncE<F> {
pub name: Name,
pub invertible: bool,
pub partial: bool,
pub invertible: bool,
pub input_params: VarList,
pub output_size: usize,
pub body: BlockE<F>,
pub rc: usize,
}
31 changes: 23 additions & 8 deletions src/lair/lair_chip.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder};
use p3_field::{AbstractField, Field, PrimeField32};
use p3_matrix::{dense::RowMajorMatrix, Matrix};
Expand All @@ -20,7 +22,10 @@ use super::{
};

pub enum LairChip<'a, F, H: Chipset<F>> {
Func(FuncChip<'a, F, H>),
Func {
func_chip: Arc<FuncChip<'a, F, H>>,
rc_index: usize,
},
Mem(MemChip<F>),
Bytes(BytesChip<F>),
Entrypoint {
Expand Down Expand Up @@ -54,7 +59,7 @@ impl<'a, F: PrimeField32, H: Chipset<F>> EventLens<LairChip<'a, F, H>> for Shard
impl<'a, F: Field + Sync, H: Chipset<F>> BaseAir<F> for LairChip<'a, F, H> {
fn width(&self) -> usize {
match self {
Self::Func(func_chip) => func_chip.width(),
Self::Func { func_chip, .. } => func_chip.width(),
Self::Mem(mem_chip) => mem_chip.width(),
Self::Bytes(bytes_chip) => bytes_chip.width(),
Self::Entrypoint {
Expand All @@ -78,7 +83,7 @@ impl<'a, F: PrimeField32, H: Chipset<F>> MachineAir<F> for LairChip<'a, F, H> {

fn name(&self) -> String {
match self {
Self::Func(func_chip) => format!("Func[{}]", func_chip.func.name),
Self::Func { func_chip, .. } => format!("Func[{}]", func_chip.func.name),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m not sure if this is checked by the current verifier, but I think there might be a risk of collision if we include the same chip multiple times.

Self::Mem(mem_chip) => format!("Mem[{}-wide]", mem_chip.len),
Self::Entrypoint { func_idx, .. } => format!("Entrypoint[{func_idx}]"),
// the following is required by sphinx
Expand All @@ -93,7 +98,10 @@ impl<'a, F: PrimeField32, H: Chipset<F>> MachineAir<F> for LairChip<'a, F, H> {
_: &mut Self::Record,
) -> RowMajorMatrix<F> {
match self {
Self::Func(func_chip) => func_chip.generate_trace(shard.events()),
Self::Func {
func_chip,
rc_index,
} => func_chip.generate_trace_rc(shard.events(), *rc_index),
Self::Mem(mem_chip) => mem_chip.generate_trace(shard.events()),
Self::Bytes(bytes_chip) => {
// TODO: Shard the byte events differently?
Expand All @@ -117,8 +125,8 @@ impl<'a, F: PrimeField32, H: Chipset<F>> MachineAir<F> for LairChip<'a, F, H> {

fn included(&self, shard: &Self::Record) -> bool {
match self {
Self::Func(func_chip) => {
let range = shard.get_func_range(func_chip.func.index);
Self::Func { func_chip, .. } => {
let range = shard.get_func_range(func_chip.func);
!range.is_empty()
}
Self::Mem(_mem_chip) => {
Expand Down Expand Up @@ -154,7 +162,7 @@ where
{
fn eval(&self, builder: &mut AB) {
match self {
Self::Func(func_chip) => func_chip.eval(builder),
Self::Func { func_chip, .. } => func_chip.eval(builder),
Self::Mem(mem_chip) => mem_chip.eval(builder),
Self::Bytes(bytes_chip) => bytes_chip.eval(builder),
Self::Entrypoint {
Expand Down Expand Up @@ -195,7 +203,14 @@ pub fn build_lair_chip_vector<'a, F: PrimeField32, H: Chipset<F>>(
let mut chip_vector = Vec::with_capacity(2 + toplevel.map.size() + MEM_TABLE_SIZES.len());
chip_vector.push(LairChip::entrypoint(func));
for func_chip in FuncChip::from_toplevel(toplevel) {
chip_vector.push(LairChip::Func(func_chip));
let func_chip = Arc::new(func_chip);
for rc_index in 0..func_chip.func.rc {
let func_chip = func_chip.clone();
chip_vector.push(LairChip::Func {
func_chip,
rc_index,
});
}
}
for mem_len in MEM_TABLE_SIZES {
chip_vector.push(LairChip::Mem(MemChip::new(mem_len)));
Expand Down
18 changes: 14 additions & 4 deletions src/lair/macros.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
#[macro_export]
macro_rules! func {
(fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
(#[RC = $rc:expr] $($x:tt)*) => { $crate::func_body!($rc, $($x)*) };
($($x:tt)*) => { $crate::func_body!(1, $($x)*) };
}

#[macro_export]
macro_rules! func_body {
($rc:expr, fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
$(let $in = $crate::var!($in $(, $in_size)?);)*
$crate::lair::expr::FuncE {
name: $crate::lair::Name(stringify!($name)),
Expand All @@ -9,9 +15,10 @@ macro_rules! func {
input_params: [$($crate::var!($in $(, $in_size)?)),*].into(),
output_size: $size,
body: $crate::block_init!($lair),
rc: $rc,
}
}};
(partial fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
($rc:expr, partial fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
$(let $in = $crate::var!($in $(, $in_size)?);)*
$crate::lair::expr::FuncE {
name: $crate::lair::Name(stringify!($name)),
Expand All @@ -20,9 +27,10 @@ macro_rules! func {
input_params: [$($crate::var!($in $(, $in_size)?)),*].into(),
output_size: $size,
body: $crate::block_init!($lair),
rc: $rc,
}
}};
(invertible fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
($rc:expr, invertible fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
$(let $in = $crate::var!($in $(, $in_size)?);)*
$crate::lair::expr::FuncE {
name: $crate::lair::Name(stringify!($name)),
Expand All @@ -31,9 +39,10 @@ macro_rules! func {
input_params: [$($crate::var!($in $(, $in_size)?)),*].into(),
output_size: $size,
body: $crate::block_init!($lair),
rc: $rc,
}
}};
(invertible partial fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
($rc:expr, invertible partial fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
$(let $in = $crate::var!($in $(, $in_size)?);)*
$crate::lair::expr::FuncE {
name: $crate::lair::Name(stringify!($name)),
Expand All @@ -42,6 +51,7 @@ macro_rules! func {
input_params: [$($crate::var!($in $(, $in_size)?)),*].into(),
output_size: $size,
body: $crate::block_init!($lair),
rc: $rc,
}
}};
}
Expand Down
1 change: 1 addition & 0 deletions src/lair/toplevel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ impl<F: Field + Ord> FuncE<F> {
body,
input_size: self.input_params.total_size(),
output_size: self.output_size,
rc: self.rc,
}
}
}
Expand Down
66 changes: 65 additions & 1 deletion src/lair/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,74 @@ impl<'a, T> ColumnMutSlice<'a, T> {
}

impl<'a, F: PrimeField32, H: Chipset<F>> FuncChip<'a, F, H> {
/// Per-row parallel trace generation
pub fn generate_trace_rc(&self, shard: &Shard<'_, F>, rc_index: usize) -> RowMajorMatrix<F> {
let func_queries = &shard.queries().func_queries()[self.func.index];
let range = shard.get_func_range_rc(self.func, rc_index);
let offset = range.start;
let width = self.width();
let non_dummy_height = range.len();
let height = non_dummy_height.next_power_of_two();
let mut rows = vec![F::zero(); height * width];
// initializing nonces
rows.chunks_mut(width)
.enumerate()
.for_each(|(i, row)| row[0] = F::from_canonical_usize(i + offset));
let non_dummies = &mut rows[0..non_dummy_height * width];
non_dummies
.par_chunks_mut(width)
.enumerate()
.for_each(|(i, row)| {
let (args, result) = func_queries.get_index(i + offset).unwrap();
let index = &mut ColumnIndex::default();
let slice = &mut ColumnMutSlice::from_slice(row, self.layout_sizes);
let requires = result.requires.iter();
let mut depth_requires = result.depth_requires.iter();
let queries = shard.queries();
let query_map = &queries.func_queries()[self.func.index];
let lookup = query_map
.get(args)
.expect("Cannot find query result")
.provide;
let provide = lookup.into_provide();
result
.output
.as_ref()
.unwrap()
.iter()
.for_each(|&o| slice.push_output(index, o));
slice.push_aux(index, provide.last_nonce);
slice.push_aux(index, provide.last_count);
// provenance and range check
if self.func.partial {
let num_requires = (DEPTH_W / 2) + (DEPTH_W % 2);
let depth: [u8; DEPTH_W] = result.depth.to_le_bytes();
for b in depth {
slice.push_aux(index, F::from_canonical_u8(b));
}
for _ in 0..num_requires {
let lookup = depth_requires.next().expect("Not enough require hints");
slice.push_require(index, lookup.into_require());
}
}
self.func.populate_row(
args,
index,
slice,
queries,
requires,
self.toplevel,
result.depth,
depth_requires,
);
});
RowMajorMatrix::new(rows, width)
}

/// Per-row parallel trace generation
pub fn generate_trace(&self, shard: &Shard<'_, F>) -> RowMajorMatrix<F> {
let func_queries = &shard.queries().func_queries()[self.func.index];
let range = shard.get_func_range(self.func.index);
let range = shard.get_func_range(self.func);
let width = self.width();
let non_dummy_height = range.len();
let height = non_dummy_height.next_power_of_two();
Expand Down
11 changes: 11 additions & 0 deletions src/lurk/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ impl EvalErr {
}
}

const EVAL_RC: usize = 8;
const APPLY_RC: usize = 4;
const BINOP_RC: usize = 4;
const LOOKUP_RC: usize = 4;

pub fn lurk_main<F: AbstractField>() -> FuncE<F> {
func!(
partial fn lurk_main(full_expr_tag: [8], expr_digest: [8], env_digest: [8]): [16] {
Expand Down Expand Up @@ -301,6 +306,7 @@ fn ingress_builtin<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> Fun
input_params: [input_var].into(),
output_size: 1,
body: BlockE { ops, ctrl },
rc: 1,
}
}

Expand Down Expand Up @@ -435,6 +441,7 @@ fn egress_builtin<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> Func
input_params: [input_var].into(),
output_size: 8,
body: BlockE { ops, ctrl },
rc: 1,
}
}

Expand Down Expand Up @@ -565,6 +572,7 @@ pub fn big_num_lessthan<F>() -> FuncE<F> {

pub fn eval<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> FuncE<F> {
func!(
#[RC = EVAL_RC]
partial fn eval(expr_tag, expr, env): [2] {
// Constants, tags, etc
let t = builtins.index("t");
Expand Down Expand Up @@ -1131,6 +1139,7 @@ pub fn eval_begin<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> Func

pub fn eval_binop_num<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> FuncE<F> {
func!(
#[RC = BINOP_RC]
partial fn eval_binop_num(head, exp1_tag, exp1, exp2_tag, exp2, env): [2] {
let err_tag = Tag::Err;
let num_tag = Tag::Num;
Expand Down Expand Up @@ -1685,6 +1694,7 @@ pub fn eval_letrec<F: AbstractField + Ord>() -> FuncE<F> {

pub fn apply<F: AbstractField + Ord>() -> FuncE<F> {
func!(
#[RC = APPLY_RC]
partial fn apply(head_tag, head, args_tag, args, args_env): [2] {
// Constants, tags, etc
let err_tag = Tag::Err;
Expand Down Expand Up @@ -1765,6 +1775,7 @@ pub fn apply<F: AbstractField + Ord>() -> FuncE<F> {

pub fn env_lookup<F: AbstractField>() -> FuncE<F> {
func!(
#[RC = LOOKUP_RC]
fn env_lookup(x_digest: [8], env): [2] {
if !env {
let err_tag = Tag::Err;
Expand Down