Skip to content

Commit

Permalink
reduction count implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-barrett committed Aug 31, 2024
1 parent 6915aea commit 6a16b1b
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/lair/bytecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ pub struct Func<F> {
pub(crate) input_size: usize,
pub(crate) output_size: usize,
pub(crate) body: Block<F>,
pub(crate) rc: usize,
}

impl<F> Func<F> {
Expand Down
21 changes: 19 additions & 2 deletions src/lair/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,25 @@ impl<'a, F: PrimeField32> Shard<'a, F> {
self.queries.expect("Missing query record reference")
}

pub fn get_func_range(&self, func_index: usize) -> Range<usize> {
let num_func_queries = self.queries().func_queries[func_index].len();
pub fn get_func_range_rc(&self, func: &Func<F>, rc_index: usize) -> Range<usize> {
let num_func_queries = self.queries().func_queries[func.index].len();
let shard_idx = self.index as usize;
let shard_chunk = self.shard_config.max_shard_size as usize * func.rc;
let start = shard_idx * shard_chunk;
let end = ((shard_idx + 1) * shard_chunk).min(num_func_queries);
let len = (start..end).len();
if len % func.rc == 0 {
let chunk = len / func.rc;
start + chunk * rc_index..start + chunk * (rc_index + 1)
} else {
let chunk = (len / func.rc) + 1;
let end = start + chunk * (rc_index + 1);
start + chunk * rc_index..end.min(num_func_queries)
}
}

pub fn get_func_range(&self, func: &Func<F>) -> Range<usize> {
let num_func_queries = self.queries().func_queries[func.index].len();
let shard_idx = self.index as usize;
let max_shard_size = self.shard_config.max_shard_size as usize;
shard_idx * max_shard_size..((shard_idx + 1) * max_shard_size).min(num_func_queries)
Expand Down
3 changes: 2 additions & 1 deletion src/lair/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,10 @@ pub struct CasesE<K, F> {
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FuncE<F> {
pub name: Name,
pub invertible: bool,
pub partial: bool,
pub invertible: bool,
pub input_params: VarList,
pub output_size: usize,
pub body: BlockE<F>,
pub rc: usize,
}
31 changes: 23 additions & 8 deletions src/lair/lair_chip.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder};
use p3_field::{AbstractField, Field, PrimeField32};
use p3_matrix::{dense::RowMajorMatrix, Matrix};
Expand All @@ -20,7 +22,10 @@ use super::{
};

pub enum LairChip<'a, F, H: Chipset<F>> {
Func(FuncChip<'a, F, H>),
Func {
func_chip: Arc<FuncChip<'a, F, H>>,
rc_index: usize,
},
Mem(MemChip<F>),
Bytes(BytesChip<F>),
Entrypoint {
Expand Down Expand Up @@ -54,7 +59,7 @@ impl<'a, F: PrimeField32, H: Chipset<F>> EventLens<LairChip<'a, F, H>> for Shard
impl<'a, F: Field + Sync, H: Chipset<F>> BaseAir<F> for LairChip<'a, F, H> {
fn width(&self) -> usize {
match self {
Self::Func(func_chip) => func_chip.width(),
Self::Func { func_chip, .. } => func_chip.width(),
Self::Mem(mem_chip) => mem_chip.width(),
Self::Bytes(bytes_chip) => bytes_chip.width(),
Self::Entrypoint {
Expand All @@ -78,7 +83,7 @@ impl<'a, F: PrimeField32, H: Chipset<F>> MachineAir<F> for LairChip<'a, F, H> {

fn name(&self) -> String {
match self {
Self::Func(func_chip) => format!("Func[{}]", func_chip.func.name),
Self::Func { func_chip, .. } => format!("Func[{}]", func_chip.func.name),
Self::Mem(mem_chip) => format!("Mem[{}-wide]", mem_chip.len),
Self::Entrypoint { func_idx, .. } => format!("Entrypoint[{func_idx}]"),
// the following is required by sphinx
Expand All @@ -93,7 +98,10 @@ impl<'a, F: PrimeField32, H: Chipset<F>> MachineAir<F> for LairChip<'a, F, H> {
_: &mut Self::Record,
) -> RowMajorMatrix<F> {
match self {
Self::Func(func_chip) => func_chip.generate_trace(shard.events()),
Self::Func {
func_chip,
rc_index,
} => func_chip.generate_trace_rc(shard.events(), *rc_index),
Self::Mem(mem_chip) => mem_chip.generate_trace(shard.events()),
Self::Bytes(bytes_chip) => {
// TODO: Shard the byte events differently?
Expand All @@ -117,8 +125,8 @@ impl<'a, F: PrimeField32, H: Chipset<F>> MachineAir<F> for LairChip<'a, F, H> {

fn included(&self, shard: &Self::Record) -> bool {
match self {
Self::Func(func_chip) => {
let range = shard.get_func_range(func_chip.func.index);
Self::Func { func_chip, .. } => {
let range = shard.get_func_range(func_chip.func);
!range.is_empty()
}
Self::Mem(_mem_chip) => {
Expand Down Expand Up @@ -154,7 +162,7 @@ where
{
fn eval(&self, builder: &mut AB) {
match self {
Self::Func(func_chip) => func_chip.eval(builder),
Self::Func { func_chip, .. } => func_chip.eval(builder),
Self::Mem(mem_chip) => mem_chip.eval(builder),
Self::Bytes(bytes_chip) => bytes_chip.eval(builder),
Self::Entrypoint {
Expand Down Expand Up @@ -195,7 +203,14 @@ pub fn build_lair_chip_vector<'a, F: PrimeField32, H: Chipset<F>>(
let mut chip_vector = Vec::with_capacity(2 + toplevel.map.size() + MEM_TABLE_SIZES.len());
chip_vector.push(LairChip::entrypoint(func));
for func_chip in FuncChip::from_toplevel(toplevel) {
chip_vector.push(LairChip::Func(func_chip));
let func_chip = Arc::new(func_chip);
for rc_index in 0..func_chip.func.rc {
let func_chip = func_chip.clone();
chip_vector.push(LairChip::Func {
func_chip,
rc_index,
});
}
}
for mem_len in MEM_TABLE_SIZES {
chip_vector.push(LairChip::Mem(MemChip::new(mem_len)));
Expand Down
18 changes: 14 additions & 4 deletions src/lair/macros.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
#[macro_export]
macro_rules! func {
(fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
(#[RC = $rc:expr] $($x:tt)*) => { $crate::func_body!($rc, $($x)*) };
($($x:tt)*) => { $crate::func_body!(1, $($x)*) };
}

#[macro_export]
macro_rules! func_body {
($rc:expr, fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
$(let $in = $crate::var!($in $(, $in_size)?);)*
$crate::lair::expr::FuncE {
name: $crate::lair::Name(stringify!($name)),
Expand All @@ -9,9 +15,10 @@ macro_rules! func {
input_params: [$($crate::var!($in $(, $in_size)?)),*].into(),
output_size: $size,
body: $crate::block_init!($lair),
rc: $rc,
}
}};
(partial fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
($rc:expr, partial fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
$(let $in = $crate::var!($in $(, $in_size)?);)*
$crate::lair::expr::FuncE {
name: $crate::lair::Name(stringify!($name)),
Expand All @@ -20,9 +27,10 @@ macro_rules! func {
input_params: [$($crate::var!($in $(, $in_size)?)),*].into(),
output_size: $size,
body: $crate::block_init!($lair),
rc: $rc,
}
}};
(invertible fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
($rc:expr, invertible fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
$(let $in = $crate::var!($in $(, $in_size)?);)*
$crate::lair::expr::FuncE {
name: $crate::lair::Name(stringify!($name)),
Expand All @@ -31,9 +39,10 @@ macro_rules! func {
input_params: [$($crate::var!($in $(, $in_size)?)),*].into(),
output_size: $size,
body: $crate::block_init!($lair),
rc: $rc,
}
}};
(invertible partial fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
($rc:expr, invertible partial fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{
$(let $in = $crate::var!($in $(, $in_size)?);)*
$crate::lair::expr::FuncE {
name: $crate::lair::Name(stringify!($name)),
Expand All @@ -42,6 +51,7 @@ macro_rules! func {
input_params: [$($crate::var!($in $(, $in_size)?)),*].into(),
output_size: $size,
body: $crate::block_init!($lair),
rc: $rc,
}
}};
}
Expand Down
1 change: 1 addition & 0 deletions src/lair/toplevel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ impl<F: Field + Ord> FuncE<F> {
body,
input_size: self.input_params.total_size(),
output_size: self.output_size,
rc: self.rc,
}
}
}
Expand Down
66 changes: 65 additions & 1 deletion src/lair/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,74 @@ impl<'a, T> ColumnMutSlice<'a, T> {
}

impl<'a, F: PrimeField32, H: Chipset<F>> FuncChip<'a, F, H> {
/// Per-row parallel trace generation
pub fn generate_trace_rc(&self, shard: &Shard<'_, F>, rc_index: usize) -> RowMajorMatrix<F> {
let func_queries = &shard.queries().func_queries()[self.func.index];
let range = shard.get_func_range_rc(self.func, rc_index);
let offset = range.start;
let width = self.width();
let non_dummy_height = range.len();
let height = non_dummy_height.next_power_of_two();
let mut rows = vec![F::zero(); height * width];
// initializing nonces
rows.chunks_mut(width)
.enumerate()
.for_each(|(i, row)| row[0] = F::from_canonical_usize(i + offset));
let non_dummies = &mut rows[0..non_dummy_height * width];
non_dummies
.par_chunks_mut(width)
.enumerate()
.for_each(|(i, row)| {
let (args, result) = func_queries.get_index(i + offset).unwrap();
let index = &mut ColumnIndex::default();
let slice = &mut ColumnMutSlice::from_slice(row, self.layout_sizes);
let requires = result.requires.iter();
let mut depth_requires = result.depth_requires.iter();
let queries = shard.queries();
let query_map = &queries.func_queries()[self.func.index];
let lookup = query_map
.get(args)
.expect("Cannot find query result")
.provide;
let provide = lookup.into_provide();
result
.output
.as_ref()
.unwrap()
.iter()
.for_each(|&o| slice.push_output(index, o));
slice.push_aux(index, provide.last_nonce);
slice.push_aux(index, provide.last_count);
// provenance and range check
if self.func.partial {
let num_requires = (DEPTH_W / 2) + (DEPTH_W % 2);
let depth: [u8; DEPTH_W] = result.depth.to_le_bytes();
for b in depth {
slice.push_aux(index, F::from_canonical_u8(b));
}
for _ in 0..num_requires {
let lookup = depth_requires.next().expect("Not enough require hints");
slice.push_require(index, lookup.into_require());
}
}
self.func.populate_row(
args,
index,
slice,
queries,
requires,
self.toplevel,
result.depth,
depth_requires,
);
});
RowMajorMatrix::new(rows, width)
}

/// Per-row parallel trace generation
pub fn generate_trace(&self, shard: &Shard<'_, F>) -> RowMajorMatrix<F> {
let func_queries = &shard.queries().func_queries()[self.func.index];
let range = shard.get_func_range(self.func.index);
let range = shard.get_func_range(self.func);
let width = self.width();
let non_dummy_height = range.len();
let height = non_dummy_height.next_power_of_two();
Expand Down
11 changes: 11 additions & 0 deletions src/lurk/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ impl EvalErr {
}
}

const EVAL_RC: usize = 8;
const APPLY_RC: usize = 4;
const BINOP_RC: usize = 4;
const LOOKUP_RC: usize = 4;

pub fn lurk_main<F: AbstractField>() -> FuncE<F> {
func!(
partial fn lurk_main(full_expr_tag: [8], expr_digest: [8], env_digest: [8]): [16] {
Expand Down Expand Up @@ -301,6 +306,7 @@ fn ingress_builtin<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> Fun
input_params: [input_var].into(),
output_size: 1,
body: BlockE { ops, ctrl },
rc: 1,
}
}

Expand Down Expand Up @@ -435,6 +441,7 @@ fn egress_builtin<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> Func
input_params: [input_var].into(),
output_size: 8,
body: BlockE { ops, ctrl },
rc: 1,
}
}

Expand Down Expand Up @@ -565,6 +572,7 @@ pub fn big_num_lessthan<F>() -> FuncE<F> {

pub fn eval<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> FuncE<F> {
func!(
#[RC = EVAL_RC]
partial fn eval(expr_tag, expr, env): [2] {
// Constants, tags, etc
let t = builtins.index("t");
Expand Down Expand Up @@ -1131,6 +1139,7 @@ pub fn eval_begin<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> Func

pub fn eval_binop_num<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> FuncE<F> {
func!(
#[RC = BINOP_RC]
partial fn eval_binop_num(head, exp1_tag, exp1, exp2_tag, exp2, env): [2] {
let err_tag = Tag::Err;
let num_tag = Tag::Num;
Expand Down Expand Up @@ -1685,6 +1694,7 @@ pub fn eval_letrec<F: AbstractField + Ord>() -> FuncE<F> {

pub fn apply<F: AbstractField + Ord>() -> FuncE<F> {
func!(
#[RC = APPLY_RC]
partial fn apply(head_tag, head, args_tag, args, args_env): [2] {
// Constants, tags, etc
let err_tag = Tag::Err;
Expand Down Expand Up @@ -1765,6 +1775,7 @@ pub fn apply<F: AbstractField + Ord>() -> FuncE<F> {

pub fn env_lookup<F: AbstractField>() -> FuncE<F> {
func!(
#[RC = LOOKUP_RC]
fn env_lookup(x_digest: [8], env): [2] {
if !env {
let err_tag = Tag::Err;
Expand Down

0 comments on commit 6a16b1b

Please sign in to comment.