From 93879e15ea569b215ba1a5dbdd686e5e085fda74 Mon Sep 17 00:00:00 2001 From: Arvid Norberg Date: Sat, 3 Aug 2024 00:42:21 +0200 Subject: [PATCH] make GTElement also use parse_hex_string (just like PublicKey, Signature and PrivateKey). Add test cases to python tests --- crates/chia-bls/src/gtelement.rs | 32 +++++++++----------------------- tests/test_blspy_fidelity.py | 25 +++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/crates/chia-bls/src/gtelement.rs b/crates/chia-bls/src/gtelement.rs index 2375d2540..151b4226d 100644 --- a/crates/chia-bls/src/gtelement.rs +++ b/crates/chia-bls/src/gtelement.rs @@ -128,39 +128,25 @@ impl GTElement { mod pybindings { use super::*; + use crate::parse_hex::parse_hex_string; use chia_traits::{FromJsonDict, ToJsonDict}; - use pyo3::{exceptions::PyValueError, prelude::*}; + use pyo3::prelude::*; impl ToJsonDict for GTElement { fn to_json_dict(&self, py: Python<'_>) -> PyResult { let bytes = self.to_bytes(); - Ok(hex::encode(bytes).into_py(py)) + Ok(("0x".to_string() + &hex::encode(bytes)).into_py(py)) } } impl FromJsonDict for GTElement { fn from_json_dict(o: &Bound<'_, PyAny>) -> PyResult { - let s: String = o.extract()?; - if !s.starts_with("0x") { - return Err(PyValueError::new_err( - "bytes object is expected to start with 0x", - )); - } - let s = &s[2..]; - let buf = match hex::decode(s) { - Err(_) => { - return Err(PyValueError::new_err("invalid hex")); - } - Ok(v) => v, - }; - if buf.len() != Self::SIZE { - return Err(PyValueError::new_err(format!( - "GTElement, invalid length {} expected {}", - buf.len(), - Self::SIZE - ))); - } - Ok(Self::from_bytes(buf.as_slice().try_into().unwrap())) + Ok(Self::from_bytes( + parse_hex_string(o, Self::SIZE, "GTElement")? + .as_slice() + .try_into() + .unwrap(), + )) } } } diff --git a/tests/test_blspy_fidelity.py b/tests/test_blspy_fidelity.py index 0374797dc..9af92566c 100644 --- a/tests/test_blspy_fidelity.py +++ b/tests/test_blspy_fidelity.py @@ -2,6 +2,8 @@ import chia_rs from random import getrandbits import sys +from typing import Any, Type +import pytest def randbytes(n: int) -> bytes: @@ -185,6 +187,29 @@ def test_bls() -> None: # get_fingerprint() assert pk1.get_fingerprint() == pk2.get_fingerprint() + obj: Any + klass: Any + for obj, klass in [ + (pk2, G1Element), + (sig2, G2Element), + (sk2, PrivateKey), + (pair2, chia_rs.GTElement), + ]: + print(f"{klass}") + # to_json_dict + expected_json = "0x" + bytes(obj).hex() + assert obj.to_json_dict() == expected_json + # from_json_dict + assert obj == klass.from_json_dict(expected_json) + # binary blobs are also accepted in JSON dicts + assert obj == klass.from_json_dict(bytes(obj)) + # too short + with pytest.raises(ValueError, match="invalid length"): + obj2 = klass.from_json_dict(bytes(obj)[0:-1]) + # too long + with pytest.raises(ValueError, match="invalid length"): + obj2 = klass.from_json_dict(bytes(obj) + b"a") + # ------------------------------------- 8< ---------------------------------- #