Skip to content

Commit

Permalink
Merge pull request #516 from Chia-Network/pyo3-0.21.2
Browse files Browse the repository at this point in the history
Bump pyo3 to 0.21.2
  • Loading branch information
Rigidity committed May 17, 2024
2 parents 3765b12 + e224c5d commit b682686
Show file tree
Hide file tree
Showing 34 changed files with 281 additions and 222 deletions.
49 changes: 32 additions & 17 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/chia-bls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ hkdf = "0.12.0"
blst = { version = "0.3.11", git = "https://github.com/supranational/blst.git", rev = "0d46eefa45fc1e57aceb42bba0e84eab3a7a9725", features = ["portable"] }
hex = "0.4.3"
thiserror = "1.0.44"
pyo3 = { version = "0.19.0", features = ["multiple-pymethods"], optional = true }
pyo3 = { version = "0.21.2", features = ["multiple-pymethods"], optional = true }
arbitrary = { version = "1.3.0" , optional = true}
lru = "0.12.2"

Expand Down
2 changes: 1 addition & 1 deletion crates/chia-bls/fuzz/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ cargo-fuzz = true

[dependencies]
libfuzzer-sys = "0.4"
pyo3 = { version = ">=0.19.0", features = ["auto-initialize"]}
pyo3 = { version = "0.21.2", features = ["auto-initialize"]}

[dependencies.chia-bls]
path = ".."
Expand Down
47 changes: 30 additions & 17 deletions crates/chia-bls/fuzz/fuzz_targets/blspy-fidelity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ use chia_bls::secret_key::SecretKey;
use chia_bls::signature::{aggregate, sign};
use pyo3::types::{PyBytes, PyList, PyTuple};

fn to_bytes(obj: &PyAny) -> &[u8] {
fn to_bytes(obj: &Bound<PyAny>) -> Vec<u8> {
obj.call_method0("__bytes__")
.unwrap()
.downcast::<PyBytes>()
.unwrap()
.as_bytes()
.to_vec()
}

fuzz_target!(|data: &[u8]| {
Expand All @@ -35,18 +36,21 @@ fuzz_target!(|data: &[u8]| {
print(sys.executable)
"#, None, None).unwrap();
*/
let blspy = py.import("blspy").unwrap();
let blspy = py.import_bound("blspy").unwrap();
let aug = blspy.getattr("AugSchemeMPL").unwrap();

// Generate key pair from seed
let rust_sk = SecretKey::from_seed(data);
let py_sk = aug
.call_method1("key_gen", PyTuple::new(py, [PyBytes::new(py, data)]))
.call_method1(
"key_gen",
PyTuple::new_bound(py, [PyBytes::new_bound(py, data)]),
)
.unwrap();

// convert to bytes
let rust_sk_bytes = rust_sk.to_bytes();
let py_sk_bytes = to_bytes(py_sk);
let py_sk_bytes = to_bytes(&py_sk);
assert_eq!(py_sk_bytes, rust_sk_bytes);

// get the public key
Expand All @@ -55,61 +59,70 @@ fuzz_target!(|data: &[u8]| {

// convert to bytes
let rust_pk_bytes = rust_pk.to_bytes();
let py_pk_bytes = to_bytes(py_pk);
let py_pk_bytes = to_bytes(&py_pk);
assert_eq!(py_pk_bytes, rust_pk_bytes);

let idx = u32::from_be_bytes(<[u8; 4]>::try_from(&data[0..4]).unwrap());
let rust_sk1 = rust_sk.derive_unhardened(idx);
let py_sk1 = aug
.call_method1(
"derive_child_sk_unhardened",
PyTuple::new(py, [py_sk, idx.to_object(py).as_ref(py)]),
PyTuple::new_bound(
py,
[py_sk.clone(), idx.to_object(py).bind(py).clone().into_any()],
),
)
.unwrap();
assert_eq!(to_bytes(py_sk1), rust_sk1.to_bytes());
assert_eq!(to_bytes(&py_sk1), rust_sk1.to_bytes());

let rust_pk1 = rust_pk.derive_unhardened(idx);
let py_pk1 = aug
.call_method1(
"derive_child_pk_unhardened",
PyTuple::new(py, [py_pk, idx.to_object(py).as_ref(py)]),
PyTuple::new_bound(py, [py_pk, idx.to_object(py).bind(py).clone().into_any()]),
)
.unwrap();
assert_eq!(to_bytes(py_pk1), rust_pk1.to_bytes());
assert_eq!(to_bytes(&py_pk1), rust_pk1.to_bytes());

// sign with the derived keys
let rust_sig1 = sign(&rust_sk1, data);
let py_sig1 = aug
.call_method1("sign", PyTuple::new(py, [py_sk1, PyBytes::new(py, data)]))
.call_method1(
"sign",
PyTuple::new_bound(py, [py_sk1, PyBytes::new_bound(py, data).into_any()]),
)
.unwrap();
assert_eq!(to_bytes(py_sig1), rust_sig1.to_bytes());
assert_eq!(to_bytes(&py_sig1), rust_sig1.to_bytes());

// derive hardened
let idx = u32::from_be_bytes(<[u8; 4]>::try_from(&data[4..8]).unwrap());
let rust_sk2 = rust_sk.derive_hardened(idx);
let py_sk2 = aug
.call_method1(
"derive_child_sk",
PyTuple::new(py, [py_sk, idx.to_object(py).as_ref(py)]),
PyTuple::new_bound(py, [py_sk, idx.to_object(py).bind(py).clone().into_any()]),
)
.unwrap();
assert_eq!(to_bytes(py_sk2), rust_sk2.to_bytes());
assert_eq!(to_bytes(&py_sk2), rust_sk2.to_bytes());

// sign with the derived keys
let rust_sig2 = sign(&rust_sk2, data);
let py_sig2 = aug
.call_method1("sign", PyTuple::new(py, [py_sk2, PyBytes::new(py, data)]))
.call_method1(
"sign",
PyTuple::new_bound(py, [py_sk2, PyBytes::new_bound(py, data).into_any()]),
)
.unwrap();
assert_eq!(to_bytes(py_sig2), rust_sig2.to_bytes());
assert_eq!(to_bytes(&py_sig2), rust_sig2.to_bytes());

// aggregate
let rust_agg = aggregate([rust_sig1, rust_sig2]);
let py_agg = aug
.call_method1(
"aggregate",
PyTuple::new(py, [PyList::new(py, [py_sig1, py_sig2])]),
PyTuple::new_bound(py, [PyList::new_bound(py, [py_sig1, py_sig2])]),
)
.unwrap();
assert_eq!(to_bytes(py_agg), rust_agg.to_bytes());
assert_eq!(to_bytes(&py_agg), rust_agg.to_bytes());
});
});
22 changes: 16 additions & 6 deletions crates/chia-bls/src/cached_bls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ use std::num::NonZeroUsize;
#[cfg(feature = "py-bindings")]
use pyo3::exceptions::PyValueError;
#[cfg(feature = "py-bindings")]
use pyo3::types::PyList;
use pyo3::pybacked::PyBackedBytes;
#[cfg(feature = "py-bindings")]
use pyo3::types::{PyAnyMethods, PyList};
#[cfg(feature = "py-bindings")]
use pyo3::{pyclass, pymethods, PyResult};

Expand Down Expand Up @@ -107,13 +109,21 @@ impl BLSCache {
#[pyo3(name = "aggregate_verify")]
pub fn py_aggregate_verify(
&mut self,
pks: &PyList,
msgs: &PyList,
pks: &pyo3::Bound<PyList>,
msgs: &pyo3::Bound<PyList>,
sig: &Signature,
) -> PyResult<bool> {
let pks_r = pks.iter().map(|item| item.extract::<PublicKey>().unwrap());
let msgs_r = msgs.iter().map(|item| item.extract::<&[u8]>().unwrap());
Ok(self.aggregate_verify(pks_r, msgs_r, sig))
let pks = pks
.iter()?
.map(|item| item?.extract())
.collect::<PyResult<Vec<PublicKey>>>()?;

let msgs = msgs
.iter()?
.map(|item| item?.extract())
.collect::<PyResult<Vec<PyBackedBytes>>>()?;

Ok(self.aggregate_verify(pks, msgs, sig))
}

#[pyo3(name = "len")]
Expand Down
4 changes: 3 additions & 1 deletion crates/chia-bls/src/gtelement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use chia_traits::to_json_dict::ToJsonDict;
#[cfg(feature = "py-bindings")]
use pyo3::exceptions::PyValueError;
#[cfg(feature = "py-bindings")]
use pyo3::types::PyAnyMethods;
#[cfg(feature = "py-bindings")]
use pyo3::{pyclass, pymethods, IntoPy, PyAny, PyObject, PyResult, Python};

#[cfg_attr(feature = "py-bindings", pyclass, derive(PyStreamable))]
Expand Down Expand Up @@ -116,7 +118,7 @@ impl ToJsonDict for GTElement {

#[cfg(feature = "py-bindings")]
impl FromJsonDict for GTElement {
fn from_json_dict(o: &PyAny) -> PyResult<Self> {
fn from_json_dict(o: &pyo3::Bound<PyAny>) -> PyResult<Self> {
let s: String = o.extract()?;
if !s.starts_with("0x") {
return Err(PyValueError::new_err(
Expand Down
12 changes: 7 additions & 5 deletions crates/chia-bls/src/public_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use chia_traits::from_json_dict::FromJsonDict;
#[cfg(feature = "py-bindings")]
use chia_traits::to_json_dict::ToJsonDict;
#[cfg(feature = "py-bindings")]
use pyo3::prelude::PyAnyMethods;
#[cfg(feature = "py-bindings")]
use pyo3::{pyclass, pymethods, IntoPy, PyAny, PyObject, PyResult, Python};

#[cfg_attr(
Expand Down Expand Up @@ -301,7 +303,7 @@ impl ToJsonDict for PublicKey {
}

#[cfg(feature = "py-bindings")]
pub fn parse_hex_string(o: &PyAny, len: usize, name: &str) -> PyResult<Vec<u8>> {
pub fn parse_hex_string(o: &pyo3::Bound<PyAny>, len: usize, name: &str) -> PyResult<Vec<u8>> {
use pyo3::exceptions::{PyTypeError, PyValueError};
if let Ok(s) = o.extract::<String>() {
let s = if let Some(st) = s.strip_prefix("0x") {
Expand Down Expand Up @@ -346,7 +348,7 @@ pub fn parse_hex_string(o: &PyAny, len: usize, name: &str) -> PyResult<Vec<u8>>

#[cfg(feature = "py-bindings")]
impl FromJsonDict for PublicKey {
fn from_json_dict(o: &PyAny) -> PyResult<Self> {
fn from_json_dict(o: &pyo3::Bound<PyAny>) -> PyResult<Self> {
Ok(Self::from_bytes(
parse_hex_string(o, 48, "PublicKey")?
.as_slice()
Expand Down Expand Up @@ -741,7 +743,7 @@ mod pytests {
let pk = sk.public_key();
Python::with_gil(|py| {
let string = pk.to_json_dict(py).expect("to_json_dict");
let pk2 = PublicKey::from_json_dict(string.as_ref(py)).unwrap();
let pk2 = PublicKey::from_json_dict(string.bind(py)).unwrap();
assert_eq!(pk, pk2);
});
}
Expand All @@ -757,8 +759,8 @@ mod pytests {
pyo3::prepare_freethreaded_python();
Python::with_gil(|py| {
let err =
PublicKey::from_json_dict(input.to_string().into_py(py).as_ref(py)).unwrap_err();
assert_eq!(err.value(py).to_string(), msg.to_string());
PublicKey::from_json_dict(input.to_string().into_py(py).bind(py)).unwrap_err();
assert_eq!(err.value_bound(py).to_string(), msg.to_string());
});
}
}
Loading

0 comments on commit b682686

Please sign in to comment.