Skip to content

Commit

Permalink
extend BLS cache with items() and update() to support the current use…
Browse files Browse the repository at this point in the history
… case of serializing and returning the cache entries from the TX validation worker process back to the main process
  • Loading branch information
arvidn committed May 22, 2024
1 parent 4e1689d commit 51f4800
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 8 deletions.
48 changes: 41 additions & 7 deletions crates/chia-bls/src/cached_bls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ use sha2::{Digest, Sha256};
use crate::{aggregate_verify_gt, hash_to_g2};
use crate::{GTElement, PublicKey, Signature};

/// This cache is a bit weird because it's trying to account for validating
/// mempool signatures versus block signatures. When validating block signatures,
/// there's not much point in caching the pairings because we're probably not going
/// to see them again unless there's a reorg. However, a spend in the mempool
/// is likely to reappear in a block later, so we can save having to do the pairing
/// again. So caching is primarily useful when synced and monitoring the mempool in real-time.
/// This is a cache of pairings of public keys and their corresponding message.
/// It accelerates aggregate verification when some public keys have already
/// been paired, and found in the cache.
/// We use it to cache pairings when validating transactions inserted into the
/// mempool, as many of those transactions are likely to show up in a full block
/// later. This makes it a lot cheaper to validate the full block.
/// However, validating a signature where we have no cached GT elements, the
/// aggregate_verify() primitive is faster. When long-syncing, that's
/// preferable.
#[cfg_attr(feature = "py-bindings", pyo3::pyclass(name = "BLSCache"))]
#[derive(Debug, Clone)]
pub struct BlsCache {
Expand Down Expand Up @@ -86,7 +89,7 @@ mod python {
pybacked::PyBackedBytes,
pymethods,
types::{PyAnyMethods, PyList},
Bound, PyResult,
Bound, PyObject, PyResult,
};

#[pymethods]
Expand Down Expand Up @@ -130,6 +133,37 @@ mod python {
pub fn py_len(&self) -> PyResult<usize> {
Ok(self.len())
}

#[pyo3(name = "items")]
pub fn py_items(&self, py: pyo3::Python) -> PyResult<PyObject> {
use pyo3::prelude::*;
use pyo3::types::PyBytes;
let ret = PyList::empty_bound(py);
for (key, value) in self.cache.iter() {
ret.append((
PyBytes::new_bound(py, key),
PyBytes::new_bound(py, &value.to_bytes()),
))?;
}
Ok(ret.into())
}

#[pyo3(name = "update")]
pub fn py_update(&mut self, other: &Bound<PyList>) -> PyResult<()> {
for item in other.borrow().iter()? {
let (key, value): (Vec<u8>, Vec<u8>) = item?.extract()?;
self.cache.put(
key.try_into()
.map_err(|_| PyValueError::new_err("invalid key"))?,
GTElement::from_bytes(
(&value[..])
.try_into()
.map_err(|_| PyValueError::new_err("invalid GTElement"))?,
),
);
}
Ok(())
}
}
}

Expand Down
38 changes: 37 additions & 1 deletion tests/test_blscache.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def test_cache_limit() -> None:


# old Python tests ported
# benchmark old vs new BLSCache
def test_cached_bls():
cached_bls = BLSCache()
n_keys = 10
Expand Down Expand Up @@ -115,6 +114,43 @@ def test_cached_bls():
assert cached_bls_old.aggregate_verify(pks_bytes, msgs, agg_sig, False, local_cache)


def test_cached_bls_flattening():
cached_bls = BLSCache()
n_keys = 10
seed = b"a" * 31
sks = [AugSchemeMPL.key_gen(seed + bytes([i])) for i in range(n_keys)]
pks = [sk.get_g1() for sk in sks]
aggsig = AugSchemeMPL.aggregate(
[AugSchemeMPL.sign(sk, b"foobar", pk) for sk, pk in zip(sks, pks)]
)

assert cached_bls.aggregate_verify(pks, [b"foobar"] * n_keys, aggsig)
assert len(cached_bls.items()) == n_keys
gts = [
bytes(pk.pair(AugSchemeMPL.g2_from_message(bytes(pk) + b"foobar")))
for pk in pks
]
for key, value in cached_bls.items():
assert isinstance(key, bytes)
assert isinstance(value, bytes)
assert value in gts
gts.remove(value)

cache_copy = BLSCache()
cache_copy.update(cached_bls.items())

assert len(cache_copy.items()) == n_keys
gts = [
bytes(pk.pair(AugSchemeMPL.g2_from_message(bytes(pk) + b"foobar")))
for pk in pks
]
for key, value in cache_copy.items():
assert isinstance(key, bytes)
assert isinstance(value, bytes)
assert value in gts
gts.remove(value)


def test_cached_bls_repeat_pk():
cached_bls = BLSCache()
n_keys = 400
Expand Down
2 changes: 2 additions & 0 deletions wheel/generate_type_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ class BLSCache:
def __init__(self, cache_size: Optional[int] = 50000) -> None: ...
def len(self) -> int: ...
def aggregate_verify(self, pks: List[G1Element], msgs: List[bytes], sig: G2Element) -> bool: ...
def items(self) -> List[Tuple[bytes, bytes]]: ...
def update(self, other: List[Tuple[bytes, bytes]]) -> None: ...
class AugSchemeMPL:
@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions wheel/python/chia_rs/chia_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class BLSCache:
def __init__(self, cache_size: Optional[int] = 50000) -> None: ...
def len(self) -> int: ...
def aggregate_verify(self, pks: List[G1Element], msgs: List[bytes], sig: G2Element) -> bool: ...
def items(self) -> List[Tuple[bytes, bytes]]: ...
def update(self, other: List[Tuple[bytes, bytes]]) -> None: ...

class AugSchemeMPL:
@staticmethod
Expand Down

0 comments on commit 51f4800

Please sign in to comment.