diff --git a/crates/chia-consensus/benches/merkle-set.rs b/crates/chia-consensus/benches/merkle-set.rs index 4dd22b08a..7b6448968 100644 --- a/crates/chia-consensus/benches/merkle-set.rs +++ b/crates/chia-consensus/benches/merkle-set.rs @@ -1,4 +1,4 @@ -use chia_consensus::merkle_tree::MerkleSet; +use chia_consensus::merkle_tree::{validate_merkle_proof, MerkleSet}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; @@ -9,8 +9,9 @@ fn run(c: &mut Criterion) { let mut rng = SmallRng::seed_from_u64(1337); - let mut leafs = Vec::<[u8; 32]>::with_capacity(1000); - for _ in 0..1000 { + const NUM_LEAFS: usize = 1000; + let mut leafs = Vec::<[u8; 32]>::with_capacity(NUM_LEAFS); + for _ in 0..NUM_LEAFS { let mut item = [0_u8; 32]; rng.fill(&mut item); leafs.push(item); @@ -24,7 +25,10 @@ fn run(c: &mut Criterion) { }) }); - let tree = MerkleSet::from_leafs(&mut leafs); + // build the tree from the first half of the leafs. The second half are + // examples of leafs *not* included in the tree, to also cover + // proofs-of-exclusion + let tree = MerkleSet::from_leafs(&mut leafs[0..NUM_LEAFS / 2]); group.bench_function("generate_proof", |b| { b.iter(|| { @@ -45,7 +49,7 @@ fn run(c: &mut Criterion) { ); } - group.bench_function("deserialize_proof", |b| { + group.bench_function("parse_proof", |b| { b.iter(|| { let start = Instant::now(); for p in &proofs { @@ -54,6 +58,18 @@ fn run(c: &mut Criterion) { start.elapsed() }) }); + let root = &tree.get_root(); + use std::iter::zip; + group.bench_function("validate_merkle_proof", |b| { + b.iter(|| { + let start = Instant::now(); + for (p, leaf) in zip(&proofs, &leafs) { + let _ = + black_box(validate_merkle_proof(&p, leaf, root).expect("expect valid proof")); + } + start.elapsed() + }) + }); } criterion_group!(merkle_set, run); diff --git a/crates/chia-consensus/fuzz/fuzz_targets/deserialize-proof.rs b/crates/chia-consensus/fuzz/fuzz_targets/deserialize-proof.rs index 89b6ce2c6..dcc0dfef8 100644 --- a/crates/chia-consensus/fuzz/fuzz_targets/deserialize-proof.rs +++ b/crates/chia-consensus/fuzz/fuzz_targets/deserialize-proof.rs @@ -1,8 +1,13 @@ #![no_main] +use chia_consensus::merkle_tree::{validate_merkle_proof, MerkleSet}; +use hex_literal::hex; use libfuzzer_sys::fuzz_target; -use chia_consensus::merkle_tree::MerkleSet; - fuzz_target!(|data: &[u8]| { let _r = MerkleSet::from_proof(data); + let dummy: [u8; 32] = hex!("cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc"); + assert!(!matches!( + validate_merkle_proof(data, &dummy, &dummy), + Ok(true) + )); }); diff --git a/crates/chia-consensus/fuzz/fuzz_targets/merkle-set.rs b/crates/chia-consensus/fuzz/fuzz_targets/merkle-set.rs index 16594d53e..b93201654 100644 --- a/crates/chia-consensus/fuzz/fuzz_targets/merkle-set.rs +++ b/crates/chia-consensus/fuzz/fuzz_targets/merkle-set.rs @@ -1,5 +1,6 @@ #![no_main] -use chia_consensus::merkle_tree::MerkleSet; +use chia_consensus::merkle_tree::{validate_merkle_proof, MerkleSet}; +use clvmr::sha2::{Digest, Sha256}; use libfuzzer_sys::fuzz_target; fuzz_target!(|data: &[u8]| { @@ -12,18 +13,26 @@ fuzz_target!(|data: &[u8]| { } let tree = MerkleSet::from_leafs(&mut leafs); + let root = tree.get_root(); - for item in &leafs { - let (true, proof) = tree.generate_proof(item).expect("failed to generate proof") else { - panic!("item is expected to exist"); - }; + // this is a leaf that's *not* in the tree, to also cover + // proofs-of-exclusion + let mut hasher = Sha256::new(); + hasher.update(data); + leafs.push(hasher.finalize().into()); + + for (idx, item) in leafs.iter().enumerate() { + let expect_included = idx < num_leafs; + let (included, proof) = tree.generate_proof(item).expect("failed to generate proof"); + assert_eq!(included, expect_included); let rebuilt = MerkleSet::from_proof(&proof).expect("failed to parse proof"); + let (included, _junk) = rebuilt + .generate_proof(item) + .expect("failed to validate proof"); + assert_eq!(rebuilt.get_root(), root); + assert_eq!(included, expect_included); assert!( - rebuilt - .generate_proof(item) - .expect("failed to validate proof") - .0 + validate_merkle_proof(&proof, item, &root).expect("proof failed") == expect_included ); - assert_eq!(rebuilt.get_root(), tree.get_root()); } }); diff --git a/crates/chia-consensus/src/merkle_tree.rs b/crates/chia-consensus/src/merkle_tree.rs index a9668d35f..7078f1e59 100644 --- a/crates/chia-consensus/src/merkle_tree.rs +++ b/crates/chia-consensus/src/merkle_tree.rs @@ -333,6 +333,21 @@ fn pad_middles_for_proof_gen(proof: &mut Vec, left: &[u8; 32], right: &[u8; } } +// returns true if the item is included in the tree with the specified root, +// given the proof, or false if it's not included in the tree. +// If neither can be proven, it fails with SetError +pub fn validate_merkle_proof( + proof: &[u8], + item: &[u8; 32], + root: &[u8; 32], +) -> Result { + let tree = MerkleSet::from_proof(proof)?; + if tree.get_root() != *root { + return Err(SetError); + } + Ok(tree.generate_proof(item)?.0) +} + #[cfg(feature = "py-bindings")] #[pymethods] impl MerkleSet { @@ -607,7 +622,7 @@ mod tests { assert_eq!(rebuilt.get_root(), root); let (included, new_proof) = rebuilt.generate_proof(&item).unwrap(); assert!(included); - assert_eq!(new_proof, vec![]); + assert_eq!(new_proof, Vec::::new()); assert_eq!(rebuilt.get_root(), root); } @@ -622,7 +637,7 @@ mod tests { let rebuilt = MerkleSet::from_proof(&proof).expect("failed to parse proof"); let (included, new_proof) = rebuilt.generate_proof(&item).unwrap(); assert!(!included); - assert_eq!(new_proof, vec![]); + assert_eq!(new_proof, Vec::::new()); assert_eq!(rebuilt.get_root(), root); } } diff --git a/tests/test_merkle_set.py b/tests/test_merkle_set.py index 0702a7d7d..efe7f167d 100644 --- a/tests/test_merkle_set.py +++ b/tests/test_merkle_set.py @@ -5,35 +5,34 @@ import time from chia_rs import ( MerkleSet as RustMerkleSet, - deserialize_proof as ru_deserialize_proof, compute_merkle_set_root, + confirm_included_already_hashed as ru_confirm_included_already_hashed, + confirm_not_included_already_hashed as ru_confirm_not_included_already_hashed, ) from random import Random from merkle_set import ( MerkleSet as PythonMerkleSet, - deserialize_proof as py_deserialize_proof, + confirm_included_already_hashed as py_confirm_included_already_hashed, + confirm_not_included_already_hashed as py_confirm_not_included_already_hashed, ) from chia_rs.sized_bytes import bytes32 def check_proof( proof: bytes, - deserialize: Callable[[bytes], Any], + confirm_included_already_hashed: Callable[[bytes32, bytes32, bytes], bool], + confirm_not_included_already_hashed: Callable[[bytes32, bytes32, bytes], bool], *, root: bytes32, item: bytes32, expect_included: bool = True, ) -> None: - proof_tree = deserialize(proof) - assert proof_tree.get_root() == root - included, junk = proof_tree.is_included_already_hashed(item) - assert included == expect_included - - # the rust implementation does not round-trip proofs of exclusions. - # doing so requires additional complexity (and cost). - # rust deliberately generates an empty proof from a tree generated from a - # proof - assert junk == b"" or junk == proof + if expect_included: + assert confirm_included_already_hashed(root, item, proof) + assert not confirm_not_included_already_hashed(root, item, proof) + else: + assert not confirm_included_already_hashed(root, item, proof) + assert confirm_not_included_already_hashed(root, item, proof) def check_tree(leafs: List[bytes32]) -> None: @@ -51,8 +50,8 @@ def check_tree(leafs: List[bytes32]) -> None: assert py_proof == ru_proof proof = ru_proof - check_proof(proof, py_deserialize_proof, root=root, item=item) - check_proof(proof, ru_deserialize_proof, root=root, item=item) + check_proof(proof, py_confirm_included_already_hashed, py_confirm_not_included_already_hashed, root=root, item=item) + check_proof(proof, ru_confirm_included_already_hashed, ru_confirm_not_included_already_hashed, root=root, item=item) for i in range(256): item = bytes32([i] + [2] * 31) @@ -63,12 +62,8 @@ def check_tree(leafs: List[bytes32]) -> None: assert py_proof == ru_proof proof = ru_proof - check_proof( - proof, py_deserialize_proof, root=root, item=item, expect_included=False - ) - check_proof( - proof, ru_deserialize_proof, root=root, item=item, expect_included=False - ) + check_proof(proof, py_confirm_included_already_hashed, py_confirm_not_included_already_hashed, root=root, item=item, expect_included=False) + check_proof(proof, ru_confirm_included_already_hashed, ru_confirm_not_included_already_hashed, root=root, item=item, expect_included=False) def h(b: str) -> bytes32: diff --git a/wheel/generate_type_stubs.py b/wheel/generate_type_stubs.py index 2218b85f6..b28ac22ed 100644 --- a/wheel/generate_type_stubs.py +++ b/wheel/generate_type_stubs.py @@ -289,6 +289,18 @@ def deserialize_proof( proof: bytes ) -> MerkleSet: ... +def confirm_included_already_hashed( + root: bytes32, + item: bytes32, + proof: bytes, +) -> bool: ... + +def confirm_not_included_already_hashed( + root: bytes32, + item: bytes32, + proof: bytes, +) -> bool: ... + COND_ARGS_NIL: int = ... NO_UNKNOWN_CONDS: int = ... STRICT_ARGS_COUNT: int = ... diff --git a/wheel/python/chia_rs/chia_rs.pyi b/wheel/python/chia_rs/chia_rs.pyi index d4f5f04d4..c0ea87f70 100644 --- a/wheel/python/chia_rs/chia_rs.pyi +++ b/wheel/python/chia_rs/chia_rs.pyi @@ -37,6 +37,18 @@ def deserialize_proof( proof: bytes ) -> MerkleSet: ... +def confirm_included_already_hashed( + root: bytes32, + item: bytes32, + proof: bytes, +) -> bool: ... + +def confirm_not_included_already_hashed( + root: bytes32, + item: bytes32, + proof: bytes, +) -> bool: ... + COND_ARGS_NIL: int = ... NO_UNKNOWN_CONDS: int = ... STRICT_ARGS_COUNT: int = ... diff --git a/wheel/src/api.rs b/wheel/src/api.rs index 6e2943850..faecf1da7 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -12,7 +12,7 @@ use chia_consensus::gen::run_puzzle::run_puzzle as native_run_puzzle; use chia_consensus::gen::solution_generator::solution_generator as native_solution_generator; use chia_consensus::gen::solution_generator::solution_generator_backrefs as native_solution_generator_backrefs; use chia_consensus::merkle_set::compute_merkle_set_root as compute_merkle_root_impl; -use chia_consensus::merkle_tree::MerkleSet; +use chia_consensus::merkle_tree::{validate_merkle_proof, MerkleSet}; use chia_protocol::{ BlockRecord, Bytes32, ChallengeBlockInfo, ChallengeChainSubSlot, ClassgroupElement, Coin, CoinSpend, CoinState, CoinStateUpdate, EndOfSubSlotBundle, Foliage, FoliageBlockData, @@ -78,8 +78,24 @@ pub fn compute_merkle_set_root<'p>( } #[pyfunction] -pub fn deserialize_proof(proof: &[u8]) -> PyResult { - MerkleSet::from_proof(proof).map_err(|_| PyValueError::new_err("Invalid proof")) +pub fn confirm_included_already_hashed( + root: Bytes32, + item: Bytes32, + proof: &[u8], +) -> PyResult { + validate_merkle_proof(proof, (&item).into(), (&root).into()) + .map_err(|_| PyValueError::new_err("Invalid proof")) +} + +#[pyfunction] +pub fn confirm_not_included_already_hashed( + root: Bytes32, + item: Bytes32, + proof: &[u8], +) -> PyResult { + validate_merkle_proof(proof, (&item).into(), (&root).into()) + .map_err(|_| PyValueError::new_err("Invalid proof")) + .map(|r| !r) } #[pyfunction] @@ -357,7 +373,8 @@ pub fn chia_rs(_py: Python, m: &PyModule) -> PyResult<()> { // merkle tree m.add_class::()?; - m.add_function(wrap_pyfunction!(deserialize_proof, m)?)?; + m.add_function(wrap_pyfunction!(confirm_included_already_hashed, m)?)?; + m.add_function(wrap_pyfunction!(confirm_not_included_already_hashed, m)?)?; // clvm functions m.add("COND_ARGS_NIL", COND_ARGS_NIL)?;