From 9065c324ca862017c77b51185c0c0a6197f6c700 Mon Sep 17 00:00:00 2001 From: arvidn Date: Fri, 19 Apr 2024 16:42:22 +0200 Subject: [PATCH] add high-level function to validate a merkle set proof, given one item and the root hash --- crates/chia-consensus/benches/merkle-set.rs | 15 +++++++++++++-- crates/chia-consensus/src/merkle_tree.rs | 8 ++++++++ tests/test_merkle_set.py | 3 +++ wheel/src/api.rs | 9 ++++++++- 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/crates/chia-consensus/benches/merkle-set.rs b/crates/chia-consensus/benches/merkle-set.rs index 3621b6165..d92de79da 100644 --- a/crates/chia-consensus/benches/merkle-set.rs +++ b/crates/chia-consensus/benches/merkle-set.rs @@ -1,4 +1,4 @@ -use chia_consensus::merkle_tree::MerkleSet; +use chia_consensus::merkle_tree::{validate_merkle_proof, MerkleSet}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; @@ -45,7 +45,7 @@ fn run(c: &mut Criterion) { ); } - group.bench_function("deserialize_proof", |b| { + group.bench_function("parse_proof", |b| { b.iter(|| { let start = Instant::now(); for p in &proofs { @@ -54,6 +54,17 @@ fn run(c: &mut Criterion) { start.elapsed() }) }); + let root = &tree.get_root(); + use std::iter::zip; + group.bench_function("validate_merkle_proof", |b| { + b.iter(|| { + let start = Instant::now(); + for (p, leaf) in zip(&proofs, &leafs) { + assert!(validate_merkle_proof(&p, leaf, root)) + } + start.elapsed() + }) + }); } criterion_group!(merkle_set, run); diff --git a/crates/chia-consensus/src/merkle_tree.rs b/crates/chia-consensus/src/merkle_tree.rs index b6067f85b..5d5c577cf 100644 --- a/crates/chia-consensus/src/merkle_tree.rs +++ b/crates/chia-consensus/src/merkle_tree.rs @@ -309,6 +309,14 @@ fn pad_middles_for_proof_gen(proof: &mut Vec, left: &[u8; 32], right: &[u8; } } +pub fn validate_merkle_proof(proof: &[u8], item: &[u8; 32], root: &[u8; 32]) -> bool { + let tree = MerkleSet::from_proof(proof).expect("failed to parse proof"); + if tree.get_root() != *root { + return false; + } + matches!(tree.generate_proof(item), Ok(Some(_))) +} + #[cfg(feature = "py-bindings")] #[pymethods] impl MerkleSet { diff --git a/tests/test_merkle_set.py b/tests/test_merkle_set.py index c140ce230..91e903731 100644 --- a/tests/test_merkle_set.py +++ b/tests/test_merkle_set.py @@ -7,6 +7,7 @@ MerkleSet as RustMerkleSet, deserialize_proof as ru_deserialize_proof, compute_merkle_set_root, + validate_merkle_proof, ) from random import Random from merkle_set import ( @@ -29,6 +30,8 @@ def check_proof( assert included assert proof == proof2 + assert validate_merkle_proof(proof, item, root) + def check_tree(leafs: List[bytes32]) -> None: ru_tree = RustMerkleSet(leafs) diff --git a/wheel/src/api.rs b/wheel/src/api.rs index 6e2943850..fa55a9fca 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -12,7 +12,7 @@ use chia_consensus::gen::run_puzzle::run_puzzle as native_run_puzzle; use chia_consensus::gen::solution_generator::solution_generator as native_solution_generator; use chia_consensus::gen::solution_generator::solution_generator_backrefs as native_solution_generator_backrefs; use chia_consensus::merkle_set::compute_merkle_set_root as compute_merkle_root_impl; -use chia_consensus::merkle_tree::MerkleSet; +use chia_consensus::merkle_tree::{validate_merkle_proof, MerkleSet}; use chia_protocol::{ BlockRecord, Bytes32, ChallengeBlockInfo, ChallengeChainSubSlot, ClassgroupElement, Coin, CoinSpend, CoinState, CoinStateUpdate, EndOfSubSlotBundle, Foliage, FoliageBlockData, @@ -82,6 +82,12 @@ pub fn deserialize_proof(proof: &[u8]) -> PyResult { MerkleSet::from_proof(proof).map_err(|_| PyValueError::new_err("Invalid proof")) } +#[pyfunction] +#[pyo3(name = "validate_merkle_proof")] +pub fn py_validate_merkle_proof(proof: &[u8], item: Bytes32, root: Bytes32) -> bool { + validate_merkle_proof(proof, (&item).into(), (&root).into()) +} + #[pyfunction] pub fn tree_hash(py: Python, blob: PyBuffer) -> PyResult<&PyBytes> { if !blob.is_c_contiguous() { @@ -358,6 +364,7 @@ pub fn chia_rs(_py: Python, m: &PyModule) -> PyResult<()> { // merkle tree m.add_class::()?; m.add_function(wrap_pyfunction!(deserialize_proof, m)?)?; + m.add_function(wrap_pyfunction!(py_validate_merkle_proof, m)?)?; // clvm functions m.add("COND_ARGS_NIL", COND_ARGS_NIL)?;