Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add middleware check_witness #356

Merged
merged 3 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion halo2_debug/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ rand_chacha = "0.3"
rayon = "1.8"

[features]
vector-tests = []
vector-tests = []
178 changes: 178 additions & 0 deletions halo2_debug/src/check_witness.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
use crate::display::FDisp;
use halo2_middleware::circuit::{Any, CompiledCircuit, ExpressionMid, VarMid};
use halo2_middleware::ff::PrimeField;
use rand_chacha::ChaCha20Rng;
use rand_core::SeedableRng;
use std::collections::HashSet;

fn rotate(n: usize, offset: usize, rotation: i32) -> usize {
let offset = offset as i32 + rotation;
if offset < 0 {
(offset + n as i32) as usize
} else if offset >= n as i32 {
(offset - n as i32) as usize
} else {
offset as usize
}
}

struct Assignments<'a, F: PrimeField> {
public: &'a [Vec<F>],
witness: &'a [Vec<F>],
fixed: &'a [Vec<F>],
blinders: &'a [Vec<F>],
blinded: &'a [bool],
usable_rows: usize,
n: usize,
}

impl<'a, F: PrimeField> Assignments<'a, F> {
// Query a particular Column at an offset
fn query(&self, column_type: Any, column_index: usize, offset: usize) -> F {
match column_type {
Any::Instance => self.public[column_index][offset],
Any::Advice => {
if offset >= self.usable_rows && self.blinded[column_index] {
self.blinders[column_index][offset - self.usable_rows]
} else {
self.witness[column_index][offset]
}
}
Any::Fixed => self.fixed[column_index][offset],
}
}

// Evaluate an expression using the assingment data
fn eval(&self, expr: &ExpressionMid<F>, offset: usize) -> F {
expr.evaluate(
&|s| s,
&|v| match v {
VarMid::Query(q) => {
let offset = rotate(self.n, offset, q.rotation.0);
self.query(q.column_type, q.column_index, offset)
}
VarMid::Challenge(_c) => unimplemented!(),
},
&|ne| -ne,
&|a, b| a + b,
&|a, b| a * b,
)
}

// Evaluate multiple expressions and return the result as concatenated bytes from the field
// element representation.
fn eval_to_buf(&self, f_len: usize, exprs: &[ExpressionMid<F>], offset: usize) -> Vec<u8> {
let mut eval_buf = Vec::with_capacity(exprs.len() * f_len);
for eval in exprs.iter().map(|e| self.eval(e, offset)) {
eval_buf.extend_from_slice(eval.to_repr().as_ref())
}
eval_buf
}
}

/// Check that the wintess passes all the constraints defined by the circuit. Panics if any
/// constraint is not satisfied.
pub fn check_witness<F: PrimeField>(
circuit: &CompiledCircuit<F>,
k: u32,
blinding_rows: usize,
witness: &[Vec<F>],
public: &[Vec<F>],
) {
let n = 2usize.pow(k);
let usable_rows = n - blinding_rows;
let cs = &circuit.cs;

// Calculate blinding values
let mut rng = ChaCha20Rng::seed_from_u64(0xdeadbeef);
let mut blinders = vec![vec![F::ZERO; blinding_rows]; cs.num_advice_columns];
for column_blinders in blinders.iter_mut() {
for v in column_blinders.iter_mut() {
*v = F::random(&mut rng);
}
}

let mut blinded = vec![true; cs.num_advice_columns];
for advice_column_index in &cs.unblinded_advice_columns {
blinded[*advice_column_index] = false;
}

let assignments = Assignments {
public,
witness,
fixed: &circuit.preprocessing.fixed,
blinders: &blinders,
blinded: &blinded,
usable_rows,
n,
};

// Verify all gates
for (i, gate) in cs.gates.iter().enumerate() {
for offset in 0..n {
let res = assignments.eval(&gate.poly, offset);
if !res.is_zero_vartime() {
panic!(
"Unsatisfied gate {} \"{}\" at offset {}",
i, gate.name, offset
);
}
}
}

// Verify all copy constraints
for (lhs, rhs) in &circuit.preprocessing.permutation.copies {
let value_lhs = assignments.query(lhs.column.column_type, lhs.column.index, lhs.row);
let value_rhs = assignments.query(rhs.column.column_type, rhs.column.index, rhs.row);
if value_lhs != value_rhs {
panic!(
"Unsatisfied copy constraint ({:?},{:?}): {} != {}",
lhs,
rhs,
FDisp(&value_lhs),
FDisp(&value_rhs)
)
}
}

// Verify all lookups
let f_len = F::Repr::default().as_ref().len();
for (i, lookup) in cs.lookups.iter().enumerate() {
let mut virtual_table = HashSet::new();
for offset in 0..usable_rows {
let table_eval_buf = assignments.eval_to_buf(f_len, &lookup.table_expressions, offset);
virtual_table.insert(table_eval_buf);
}
for offset in 0..usable_rows {
let input_eval_buf = assignments.eval_to_buf(f_len, &lookup.input_expressions, offset);
if !virtual_table.contains(&input_eval_buf) {
panic!(
"Unsatisfied lookup {} \"{}\" at offset {}",
i, lookup.name, offset
);
}
}
}

// Verify all shuffles
for (i, shuffle) in cs.shuffles.iter().enumerate() {
let mut virtual_shuffle = Vec::with_capacity(usable_rows);
for offset in 0..usable_rows {
let shuffle_eval_buf =
assignments.eval_to_buf(f_len, &shuffle.shuffle_expressions, offset);
virtual_shuffle.push(shuffle_eval_buf);
}
let mut virtual_input = Vec::with_capacity(usable_rows);
for offset in 0..usable_rows {
let input_eval_buf = assignments.eval_to_buf(f_len, &shuffle.input_expressions, offset);
virtual_input.push(input_eval_buf);
}

virtual_shuffle.sort_unstable();
virtual_input.sort_unstable();

if virtual_input != virtual_shuffle {
panic!("Unsatisfied shuffle {} \"{}\"", i, shuffle.name);
}
}
}
7 changes: 5 additions & 2 deletions halo2_debug/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
mod check_witness;
pub mod display;

pub use check_witness::check_witness;

use rand_chacha::ChaCha20Rng;
use rand_core::SeedableRng;
use tiny_keccak::Hasher;
Expand Down Expand Up @@ -34,5 +39,3 @@ pub fn test_result<F: FnOnce() -> Vec<u8> + Send>(test: F, _expected: &str) -> V

result
}

pub mod display;
1 change: 1 addition & 0 deletions p3_frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ p3-keccak-air = { git = "https://github.com/Plonky3/Plonky3", rev = "7b5b8a6" }
p3-keccak = { git = "https://github.com/Plonky3/Plonky3", rev = "7b5b8a6" }
p3-util = { git = "https://github.com/Plonky3/Plonky3", rev = "7b5b8a6" }
rand = "0.8.5"
halo2_debug = { path = "../halo2_debug" }
63 changes: 6 additions & 57 deletions p3_frontend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
extern crate alloc;

use halo2_middleware::circuit::{
Any, Cell, ColumnMid, CompiledCircuit, ConstraintSystemMid, ExpressionMid, GateMid,
Preprocessing, QueryMid, VarMid,
Any, Cell, ColumnMid, ConstraintSystemMid, ExpressionMid, GateMid, Preprocessing, QueryMid,
VarMid,
};
use halo2_middleware::ff::{Field, PrimeField};
use halo2_middleware::permutation;
Expand Down Expand Up @@ -184,7 +184,7 @@ fn extract_copy_public<F: PrimeField + Hash>(
pub fn get_public_inputs<F: Field>(
preprocessing_info: &PreprocessingInfo,
size: usize,
witness: &[Option<Vec<F>>],
witness: &[Vec<F>],
) -> Vec<Vec<F>> {
if preprocessing_info.num_public_values == 0 {
return Vec::new();
Expand All @@ -196,7 +196,7 @@ pub fn get_public_inputs<F: Field>(
Location::LastRow => size - 1,
Location::Transition => unreachable!(),
};
public_inputs[*public_index] = witness[cell.0].as_ref().unwrap()[offset]
public_inputs[*public_index] = witness[cell.0][offset]
}
vec![public_inputs]
}
Expand Down Expand Up @@ -293,7 +293,7 @@ where
(cs, preprocessing_info)
}

pub fn trace_to_wit<F: Field>(k: u32, trace: RowMajorMatrix<FWrap<F>>) -> Vec<Option<Vec<F>>> {
pub fn trace_to_wit<F: Field>(k: u32, trace: RowMajorMatrix<FWrap<F>>) -> Vec<Vec<F>> {
let n = 2usize.pow(k);
let num_columns = trace.width;
let mut witness = vec![vec![F::ZERO; n]; num_columns];
Expand All @@ -302,56 +302,5 @@ pub fn trace_to_wit<F: Field>(k: u32, trace: RowMajorMatrix<FWrap<F>>) -> Vec<Op
witness[column_index][row_offset] = row[column_index].0;
}
}
witness.into_iter().map(Some).collect()
}

// TODO: Move to middleware
pub fn check_witness<F: Field>(
circuit: &CompiledCircuit<F>,
k: u32,
witness: &[Option<Vec<F>>],
public: &[Vec<F>],
) {
let n = 2usize.pow(k);
let cs = &circuit.cs;
let preprocessing = &circuit.preprocessing;
// TODO: Simulate blinding rows
// Verify all gates
for (i, gate) in cs.gates.iter().enumerate() {
for offset in 0..n {
let res = gate.poly.evaluate(
&|s| s,
&|v| match v {
VarMid::Query(q) => {
let offset = offset as i32 + q.rotation.0;
// TODO: Try to do mod n with a rust function
let offset = if offset < 0 {
(offset + n as i32) as usize
} else if offset >= n as i32 {
(offset - n as i32) as usize
} else {
offset as usize
};
match q.column_type {
Any::Instance => public[q.column_index][offset],
Any::Advice => witness[q.column_index].as_ref().unwrap()[offset],
Any::Fixed => preprocessing.fixed[q.column_index][offset],
}
}
VarMid::Challenge(_c) => unimplemented!(),
},
&|ne| -ne,
&|a, b| a + b,
&|a, b| a * b,
);
if !res.is_zero_vartime() {
println!(
"Unsatisfied gate {} \"{}\" at offset {}",
i, gate.name, offset
);
panic!("KO");
}
}
}
println!("Check witness: OK");
witness
}
13 changes: 9 additions & 4 deletions p3_frontend/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ use halo2_backend::{
Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer,
},
};
use halo2_debug::check_witness;
use halo2_debug::test_rng;
use halo2_middleware::circuit::CompiledCircuit;
use halo2_middleware::zal::impls::H2cEngine;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
use p3_air::Air;
use p3_frontend::{
check_witness, compile_circuit_cs, compile_preprocessing, get_public_inputs, trace_to_wit,
CompileParams, FWrap, SymbolicAirBuilder,
compile_circuit_cs, compile_preprocessing, get_public_inputs, trace_to_wit, CompileParams,
FWrap, SymbolicAirBuilder,
};
use p3_matrix::dense::RowMajorMatrix;
use std::time::Instant;
Expand Down Expand Up @@ -50,8 +51,12 @@ where
let witness = trace_to_wit(k, trace);
let pis = get_public_inputs(&preprocessing_info, size, &witness);

check_witness(&compiled_circuit, k, &witness, &pis);
(compiled_circuit, witness, pis)
check_witness(&compiled_circuit, k, 5, &witness, &pis);
(
compiled_circuit,
witness.into_iter().map(Some).collect(),
pis,
)
}

pub(crate) fn setup_prove_verify(
Expand Down
Loading