Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: unify AddressValue, put send and receive on builder #1010

Merged
merged 3 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 16 additions & 46 deletions recursion/core-v2/src/alu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@ use p3_field::Field;
use p3_field::PrimeField32;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use sp1_core::air::AirInteraction;
use sp1_core::air::MachineAir;
use sp1_core::air::SP1AirBuilder;
use sp1_core::lookup::InteractionKind;
use sp1_core::utils::pad_to_power_of_two;
use sp1_derive::AlignedBorrow;
use std::borrow::BorrowMut;

use crate::*;
use crate::{builder::SP1RecursionAirBuilder, *};

pub const NUM_FIELD_ALU_COLS: usize = core::mem::size_of::<FieldAluCols<u8>>();

Expand All @@ -23,13 +20,13 @@ pub const NUM_FIELD_ALU_COLS: usize = core::mem::size_of::<FieldAluCols<u8>>();

// 26 columns
// pub struct ExtensionFieldALU {
// pub in1: AddressValue<F>,
// pub in2: AddressValue<F>,
// pub in1: AddressValue<A, V>,
// pub in2: AddressValue<A, V>,
// pub sum: Extension<F>,
// pub diff: Extension<F>,
// pub product: Extension<F>,
// pub quotient: Extension<F>,
// pub out: AddressValue<F>,
// pub out: AddressValue<A, V>,
// pub is_add: Bool<F>,
// pub is_diff: Bool<F>,
// pub is_mul: Bool<F>,
Expand All @@ -42,9 +39,9 @@ pub struct FieldAluChip {}
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct FieldAluCols<F: Copy> {
pub in1: AddressValueBase<F>,
pub in2: AddressValueBase<F>,
pub out: AddressValueBase<F>,
pub in1: AddressValue<F, F>,
pub in2: AddressValue<F, F>,
pub out: AddressValue<F, F>,
pub sum: F,
pub diff: F,
pub product: F,
Expand Down Expand Up @@ -145,24 +142,9 @@ impl<F: PrimeField32> MachineAir<F> for FieldAluChip {

impl<AB> Air<AB> for FieldAluChip
where
AB: SP1AirBuilder,
AB: SP1RecursionAirBuilder,
{
fn eval(&self, builder: &mut AB) {
// TODO improve types to remove all this boilerplate
let encode = |avb: AddressValueBase<AB::Var>| -> Vec<AB::Expr> {
let AddressValueBase { addr, val } = avb;
let av: AddressValue<AB::Expr> = AddressValue {
addr: addr.into(),
val: Block([
val.into(),
AB::F::zero().into(),
AB::F::zero().into(),
AB::F::zero().into(),
]),
};
av.iter().cloned().collect::<Vec<_>>()
};

let main = builder.main();
let local = main.row_slice(0);
let local: &FieldAluCols<AB::Var> = (*local).borrow();
Expand Down Expand Up @@ -191,23 +173,11 @@ where
// local.is_real is 0 or 1
// builder.assert_zero(local.is_real * (AB::Expr::one() - local.is_real));

builder.receive(AirInteraction::new(
encode(local.in1),
local.is_real.into(), // is_real should be 0 or 1
InteractionKind::Memory,
));

builder.receive(AirInteraction::new(
encode(local.in2),
local.is_real.into(), // is_real should be 0 or 1
InteractionKind::Memory,
));

builder.send(AirInteraction::new(
encode(local.out),
local.mult.into(),
InteractionKind::Memory,
));
builder.receive_single(local.in1, local.is_real);

builder.receive_single(local.in2, local.is_real);

builder.send_single(local.out, local.mult);
}
}

Expand All @@ -227,9 +197,9 @@ mod tests {

let shard = ExecutionRecord::<F> {
alu_events: vec![AluEvent {
out: AddressValueBase::new(F::zero(), F::one()),
in1: AddressValueBase::new(F::zero(), F::one()),
in2: AddressValueBase::new(F::zero(), F::one()),
out: AddressValue::new(F::zero(), F::one()),
in1: AddressValue::new(F::zero(), F::one()),
in2: AddressValue::new(F::zero(), F::one()),
mult: F::zero(),
opcode: Opcode::AddF,
}],
Expand Down
Loading