From 51f4800e9fe4df23d8421fff9979edd7dbfb624f Mon Sep 17 00:00:00 2001
From: arvidn
Date: Tue, 21 May 2024 13:31:49 +0200
Subject: [PATCH] extend BLS cache with items() and update() to support the
 current use case of serializing and returning the cache entries from the TX
 validation worker process back to the main process

---
 crates/chia-bls/src/cached_bls.rs | 48 ++++++++++++++++++++++++++-----
 tests/test_blscache.py            | 38 +++++++++++++++++++++++-
 wheel/generate_type_stubs.py      |  2 ++
 wheel/python/chia_rs/chia_rs.pyi  |  2 ++
 4 files changed, 82 insertions(+), 8 deletions(-)

diff --git a/crates/chia-bls/src/cached_bls.rs b/crates/chia-bls/src/cached_bls.rs
index 5b8324c34..4e4829f19 100644
--- a/crates/chia-bls/src/cached_bls.rs
+++ b/crates/chia-bls/src/cached_bls.rs
@@ -7,12 +7,15 @@ use sha2::{Digest, Sha256};
 use crate::{aggregate_verify_gt, hash_to_g2};
 use crate::{GTElement, PublicKey, Signature};
 
-/// This cache is a bit weird because it's trying to account for validating
-/// mempool signatures versus block signatures. When validating block signatures,
-/// there's not much point in caching the pairings because we're probably not going
-/// to see them again unless there's a reorg. However, a spend in the mempool
-/// is likely to reappear in a block later, so we can save having to do the pairing
-/// again. So caching is primarily useful when synced and monitoring the mempool in real-time.
+/// This is a cache of pairings of public keys and their corresponding messages.
+/// It accelerates aggregate verification when some of the public keys have
+/// already been paired and can be found in the cache.
+/// We use it to cache pairings when validating transactions inserted into the
+/// mempool, as many of those transactions are likely to show up in a full block
+/// later. This makes it a lot cheaper to validate the full block.
+/// However, when validating a signature with no cached GT elements, the plain
+/// aggregate_verify() primitive is faster. When long-syncing, that's
+/// preferable.
 #[cfg_attr(feature = "py-bindings", pyo3::pyclass(name = "BLSCache"))]
 #[derive(Debug, Clone)]
 pub struct BlsCache {
@@ -86,7 +89,7 @@ mod python {
         pybacked::PyBackedBytes,
         pymethods,
         types::{PyAnyMethods, PyList},
-        Bound, PyResult,
+        Bound, PyObject, PyResult,
     };
 
     #[pymethods]
@@ -130,6 +133,37 @@ mod python {
 
         pub fn py_len(&self) -> PyResult<usize> {
            Ok(self.len())
         }
+
+        #[pyo3(name = "items")]
+        pub fn py_items(&self, py: pyo3::Python<'_>) -> PyResult<PyObject> {
+            use pyo3::prelude::*;
+            use pyo3::types::PyBytes;
+            let ret = PyList::empty_bound(py);
+            for (key, value) in self.cache.iter() {
+                ret.append((
+                    PyBytes::new_bound(py, key),
+                    PyBytes::new_bound(py, &value.to_bytes()),
+                ))?;
+            }
+            Ok(ret.into())
+        }
+
+        #[pyo3(name = "update")]
+        pub fn py_update(&mut self, other: &Bound<'_, PyList>) -> PyResult<()> {
+            for item in other.borrow().iter()? {
+                let (key, value): (Vec<u8>, Vec<u8>) = item?.extract()?;
+                self.cache.put(
+                    key.try_into()
+                        .map_err(|_| PyValueError::new_err("invalid key"))?,
+                    GTElement::from_bytes(
+                        (&value[..])
+                            .try_into()
+                            .map_err(|_| PyValueError::new_err("invalid GTElement"))?,
+                    ),
+                );
+            }
+            Ok(())
+        }
     }
 }
diff --git a/tests/test_blscache.py b/tests/test_blscache.py
index 53fddcd1b..9cf9bce0b 100644
--- a/tests/test_blscache.py
+++ b/tests/test_blscache.py
@@ -59,7 +59,6 @@ def test_cache_limit() -> None:
 
 
 # old Python tests ported
-# benchmark old vs new BLSCache
 def test_cached_bls():
     cached_bls = BLSCache()
     n_keys = 10
@@ -115,6 +114,43 @@ def test_cached_bls():
     assert cached_bls_old.aggregate_verify(pks_bytes, msgs, agg_sig, False, local_cache)
 
 
+def test_cached_bls_flattening():
+    cached_bls = BLSCache()
+    n_keys = 10
+    seed = b"a" * 31
+    sks = [AugSchemeMPL.key_gen(seed + bytes([i])) for i in range(n_keys)]
+    pks = [sk.get_g1() for sk in sks]
+    aggsig = AugSchemeMPL.aggregate(
+        [AugSchemeMPL.sign(sk, b"foobar", pk) for sk, pk in zip(sks, pks)]
+    )
+
+    assert cached_bls.aggregate_verify(pks, [b"foobar"] * n_keys, aggsig)
+    assert len(cached_bls.items()) == n_keys
+    gts = [
+        bytes(pk.pair(AugSchemeMPL.g2_from_message(bytes(pk) + b"foobar")))
+        for pk in pks
+    ]
+    for key, value in cached_bls.items():
+        assert isinstance(key, bytes)
+        assert isinstance(value, bytes)
+        assert value in gts
+        gts.remove(value)
+
+    cache_copy = BLSCache()
+    cache_copy.update(cached_bls.items())
+
+    assert len(cache_copy.items()) == n_keys
+    gts = [
+        bytes(pk.pair(AugSchemeMPL.g2_from_message(bytes(pk) + b"foobar")))
+        for pk in pks
+    ]
+    for key, value in cache_copy.items():
+        assert isinstance(key, bytes)
+        assert isinstance(value, bytes)
+        assert value in gts
+        gts.remove(value)
+
+
 def test_cached_bls_repeat_pk():
     cached_bls = BLSCache()
     n_keys = 400
diff --git a/wheel/generate_type_stubs.py b/wheel/generate_type_stubs.py
index b0c9012d6..bd4e672aa 100644
--- a/wheel/generate_type_stubs.py
+++ b/wheel/generate_type_stubs.py
@@ -342,6 +342,8 @@ class BLSCache:
     def __init__(self, cache_size: Optional[int] = 50000) -> None: ...
     def len(self) -> int: ...
     def aggregate_verify(self, pks: List[G1Element], msgs: List[bytes], sig: G2Element) -> bool: ...
+    def items(self) -> List[Tuple[bytes, bytes]]: ...
+    def update(self, other: List[Tuple[bytes, bytes]]) -> None: ...
 
 class AugSchemeMPL:
     @staticmethod
diff --git a/wheel/python/chia_rs/chia_rs.pyi b/wheel/python/chia_rs/chia_rs.pyi
index 7b08a802e..4c28513f8 100644
--- a/wheel/python/chia_rs/chia_rs.pyi
+++ b/wheel/python/chia_rs/chia_rs.pyi
@@ -85,6 +85,8 @@ class BLSCache:
     def __init__(self, cache_size: Optional[int] = 50000) -> None: ...
     def len(self) -> int: ...
     def aggregate_verify(self, pks: List[G1Element], msgs: List[bytes], sig: G2Element) -> bool: ...
+    def items(self) -> List[Tuple[bytes, bytes]]: ...
+    def update(self, other: List[Tuple[bytes, bytes]]) -> None: ...
 
 class AugSchemeMPL:
     @staticmethod
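
Usage sketch (not part of the patch): the new items()/update() pair is meant to let a
transaction-validation worker flatten its cache into plain (bytes, bytes) tuples, ship
them back to the main process, and have the main process merge them into its own
BLSCache so later block validation can reuse the pairings. The helper names and the
pickle round-trip below are illustrative stand-ins for the real worker IPC in
chia-blockchain; only BLSCache, AugSchemeMPL, items() and update() come from the
existing API plus this patch.

import pickle

from chia_rs import AugSchemeMPL, BLSCache


def validate_in_worker(seed: bytes, n_keys: int) -> bytes:
    # Worker side: verify an aggregate signature (which populates the cache),
    # then flatten the cache entries into (bytes, bytes) pairs. pickle stands
    # in here for whatever serialization the worker IPC actually uses.
    cache = BLSCache()
    sks = [AugSchemeMPL.key_gen(seed + bytes([i])) for i in range(n_keys)]
    pks = [sk.get_g1() for sk in sks]
    sig = AugSchemeMPL.aggregate(
        [AugSchemeMPL.sign(sk, b"foobar", pk) for sk, pk in zip(sks, pks)]
    )
    assert cache.aggregate_verify(pks, [b"foobar"] * n_keys, sig)
    return pickle.dumps(cache.items())


def merge_in_main(main_cache: BLSCache, payload: bytes) -> None:
    # Main-process side: merge the worker's pairings into the long-lived cache.
    main_cache.update(pickle.loads(payload))


main_cache = BLSCache()
merge_in_main(main_cache, validate_in_worker(b"a" * 31, 10))
assert len(main_cache.items()) == 10

The point of flattening to bytes is that GT elements are expensive to compute but cheap
to serialize, so shipping them across the process boundary is much cheaper than redoing
the pairings when the same spends show up again in a block.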