From f4f247a16382b8b43cdf81b3d9da1ff723820f61 Mon Sep 17 00:00:00 2001 From: nemo Date: Tue, 19 Oct 2021 12:13:15 -0400 Subject: [PATCH] feat: reduce conversions for CPU tree building differently --- .../src/stacked/vanilla/proof.rs | 63 ++++++++++++------- 1 file changed, 41 insertions(+), 22 deletions(-) diff --git a/storage-proofs-porep/src/stacked/vanilla/proof.rs b/storage-proofs-porep/src/stacked/vanilla/proof.rs index f39e6cfe5b..3d08bc2d94 100644 --- a/storage-proofs-porep/src/stacked/vanilla/proof.rs +++ b/storage-proofs-porep/src/stacked/vanilla/proof.rs @@ -38,7 +38,7 @@ use storage_proofs_core::{ use yastl::Pool; use crate::{ - encode::{decode, encode_fr}, + encode::{decode, encode}, stacked::vanilla::{ challenges::LayerChallenges, column::Column, @@ -829,6 +829,7 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr start: usize, end: usize| -> Result> { + use crate::encode::encode_fr; use fr32::bytes_into_fr; let mut layer_bytes = vec![0u8; (end - start) * std::mem::size_of::()]; @@ -836,7 +837,7 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr .read_range_into(start, end, &mut layer_bytes) .expect("failed to read layer bytes"); - let encoded_data = layer_bytes + let encoded_data: Vec<_> = layer_bytes .into_par_iter() .chunks(std::mem::size_of::()) .map(|chunk| { @@ -852,7 +853,7 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr ::Domain::try_from_bytes(data_node_bytes) .expect("try_from_bytes failed"); - let mut encoded_fr: Fr = key.into(); + let mut encoded_fr: Fr = key; let data_node_fr: Fr = data_node.into(); encode_fr(&mut encoded_fr, data_node_fr); let encoded_fr_repr = encoded_fr.to_repr(); @@ -884,7 +885,7 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr start: usize, end: usize| -> Result> { - let encoded_data = source + let encoded_data: Vec<_> = source .read_range(start..end)? .into_par_iter() .zip( @@ -897,18 +898,22 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr ::Domain::try_from_bytes(data_node_bytes) .expect("try from bytes failed"); - let mut encoded_fr: Fr = key.into(); - let data_node_fr: Fr = data_node.into(); - encode_fr(&mut encoded_fr, data_node_fr); - let encoded_fr_repr = encoded_fr.to_repr(); - data_node_bytes - .copy_from_slice(AsRef::<[u8]>::as_ref(&encoded_fr_repr)); + let key_elem = + ::Domain::try_from_bytes(&key.into_bytes()) + .expect("failed to convert key"); + let encoded_node = + encode::<::Domain>(key_elem, data_node); + data_node_bytes.copy_from_slice(AsRef::<[u8]>::as_ref(&encoded_node)); - encoded_fr + encoded_node }) .collect(); - Ok(encoded_data) + // Safety: this is safe because in this case we + // know the caller needs a list of domain + // elements, and we are faking the return value of + // the closure to avoid the extra type conversion + Ok(unsafe { std::mem::transmute(encoded_data) }) }, }; @@ -957,17 +962,22 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr ::Domain::try_from_bytes(data_node_bytes) .expect("try from bytes failed"); - let mut encoded_fr: Fr = key.into(); - let data_node_fr: Fr = data_node.into(); - encode_fr(&mut encoded_fr, data_node_fr); - let encoded_fr_repr = encoded_fr.to_repr(); - data_node_bytes.copy_from_slice(AsRef::<[u8]>::as_ref(&encoded_fr_repr)); + let key_elem = + ::Domain::try_from_bytes(&key.into_bytes()) + .expect("failed to convert key"); + let encoded_node = + encode::<::Domain>(key_elem, data_node); + data_node_bytes.copy_from_slice(AsRef::<[u8]>::as_ref(&encoded_node)); - encoded_fr + encoded_node }) .collect(); - Ok(encoded_data) + // Safety: this is safe because in this case we + // know the caller needs a list of domain + // elements, and we are faking the return value of + // the closure to avoid the extra type conversion + Ok(unsafe { std::mem::transmute(encoded_data) }) }, }; @@ -1181,8 +1191,17 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr let mut end = size / tree_count; for (i, config) in configs.iter().enumerate() { - let encoded_data = callback(&source, Some(data), start, end) - .expect("failed to prepare tree_r_last data"); + // Safety: this is a safe operation because while we + // require a vector of domain elements, the closure return + // type returns a vector of Fr elements. Since the + // callback is caller controlled, we know it's actually a + // vector of domain elements in this case. + let encoded_data: Vec<::Domain> = unsafe { + std::mem::transmute( + callback(&source, Some(data), start, end) + .expect("failed to prepare tree_r_last data"), + ) + }; info!( "building base tree_r_last with CPU {}/{}", @@ -1204,7 +1223,7 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr } LCTree::::from_par_iter_with_config( - encoded_data.into_par_iter().map(Into::into), + encoded_data, config.clone(), ) .with_context(|| format!("failed tree_r_last CPU {}/{}", i + 1, tree_count))?;