diff --git a/chia-protocol/src/classgroup.rs b/chia-protocol/src/classgroup.rs index 00cd42588..5ebb37c60 100644 --- a/chia-protocol/src/classgroup.rs +++ b/chia-protocol/src/classgroup.rs @@ -7,18 +7,42 @@ use pyo3::prelude::*; streamable_struct!(ClassgroupElement { data: Bytes100 }); -#[cfg(feature = "py-bindings")] -#[cfg_attr(feature = "py-bindings", pymethods)] impl ClassgroupElement { - #[staticmethod] pub fn get_default_element() -> ClassgroupElement { let mut data = [0_u8; 100]; data[0] = 0x08; ClassgroupElement { data: data.into() } } - #[staticmethod] - pub fn get_size(_constants: pyo3::PyObject) -> i32 { + pub fn get_size() -> i32 { 100 } } + +#[cfg(feature = "py-bindings")] +#[pymethods] +impl ClassgroupElement { + #[staticmethod] + pub fn create(bytes: &[u8]) -> ClassgroupElement { + if bytes.len() == 100 { + ClassgroupElement { data: bytes.into() } + } else { + assert!(bytes.len() < 100); + let mut data = [0_u8; 100]; + data[..bytes.len()].copy_from_slice(bytes); + ClassgroupElement { data: data.into() } + } + } + + #[staticmethod] + #[pyo3(name = "get_default_element")] + pub fn py_get_default_element() -> ClassgroupElement { + Self::get_default_element() + } + + #[staticmethod] + #[pyo3(name = "get_size")] + pub fn py_get_size() -> i32 { + Self::get_size() + } +} diff --git a/wheel/chia_rs.pyi b/wheel/chia_rs.pyi index 3e62e4a96..fef6bd3df 100644 --- a/wheel/chia_rs.pyi +++ b/wheel/chia_rs.pyi @@ -350,6 +350,10 @@ class Handshake: class ClassgroupElement: data: bytes100 + @staticmethod + def get_default_element() -> ClassgroupElement: ... + @staticmethod + def get_size() -> int: ... def __init__( self, data: bytes100 diff --git a/wheel/generate_type_stubs.py b/wheel/generate_type_stubs.py index 16613792f..d9d94e541 100644 --- a/wheel/generate_type_stubs.py +++ b/wheel/generate_type_stubs.py @@ -139,7 +139,15 @@ def parse_rust_source(filename: str) -> List[Tuple[str, List[str]]]: return ret -extra_members = {"Coin": ["def name(self) -> bytes32: ..."]} +extra_members = { + "Coin": [ + "def name(self) -> bytes32: ...", + ], + "ClassgroupElement": [ + "@staticmethod\n def get_default_element() -> ClassgroupElement: ...", + "@staticmethod\n def get_size() -> int: ...", + ], +} classes = [] for f in sorted(glob(str(input_dir / "*.rs"))):