Skip to content

Commit

Permalink
feat: wrap settings for tree builders in a method
Browse files Browse the repository at this point in the history
  • Loading branch information
cryptonemo committed Nov 2, 2021
1 parent 5babbe6 commit 57955b5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 22 deletions.
31 changes: 18 additions & 13 deletions storage-proofs-porep/src/stacked/vanilla/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,20 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr
Ok(tree)
}

// Even if the column builder is enabled, the GPU column builder
// only supports Poseidon hashes.
pub fn use_gpu_column_builder() -> bool {
SETTINGS.use_gpu_column_builder
&& TypeId::of::<Tree::Hasher>() == TypeId::of::<PoseidonHasher>()
}

// Even if the tree builder is enabled, the GPU tree builder
// only supports Poseidon hashes.
pub fn use_gpu_tree_builder() -> bool {
SETTINGS.use_gpu_tree_builder
&& TypeId::of::<Tree::Hasher>() == TypeId::of::<PoseidonHasher>()
}

#[cfg(any(feature = "cuda", feature = "opencl"))]
fn generate_tree_c<ColumnArity, TreeArity>(
layers: usize,
Expand All @@ -445,9 +459,7 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr
ColumnArity: 'static + PoseidonArity,
TreeArity: PoseidonArity,
{
if SETTINGS.use_gpu_column_builder
&& TypeId::of::<Tree::Hasher>() == TypeId::of::<PoseidonHasher>()
{
if Self::use_gpu_column_builder() {
Self::generate_tree_c_gpu::<ColumnArity, TreeArity>(
layers,
nodes_count,
Expand Down Expand Up @@ -824,9 +836,7 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr
start: usize,
end: usize,
) -> Result<TreeRElementData<Tree>> {
if SETTINGS.use_gpu_tree_builder
&& TypeId::of::<Tree::Hasher>() == TypeId::of::<PoseidonHasher>()
{
if Self::use_gpu_tree_builder() {
use fr32::bytes_into_fr;

let mut layer_bytes = vec![0u8; (end - start) * std::mem::size_of::<Fr>()];
Expand Down Expand Up @@ -937,10 +947,7 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr
None => Self::prepare_tree_r_data,
};

// The GPU tree builder only support Poseidon hashes.
if SETTINGS.use_gpu_tree_builder
&& TypeId::of::<Tree::Hasher>() == TypeId::of::<PoseidonHasher>()
{
if Self::use_gpu_tree_builder() {
Self::generate_tree_r_last_gpu::<TreeArity>(
data,
nodes_count,
Expand Down Expand Up @@ -1535,9 +1542,7 @@ impl<'a, Tree: 'static + MerkleTreeTrait, G: 'static + Hasher> StackedDrg<'a, Tr
tree_count,
)?;

if SETTINGS.use_gpu_tree_builder
&& TypeId::of::<Tree::Hasher>() == TypeId::of::<PoseidonHasher>()
{
if Self::use_gpu_tree_builder() {
info!("generating tree r last using the GPU");
let max_gpu_tree_batch_size = SETTINGS.max_gpu_tree_batch_size as usize;

Expand Down
11 changes: 2 additions & 9 deletions storage-proofs-update/src/vanilla.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::any::TypeId;
use std::fs::{metadata, OpenOptions};
use std::iter::FromIterator;
use std::marker::PhantomData;
Expand All @@ -7,7 +6,7 @@ use std::path::{Path, PathBuf};
use anyhow::{ensure, Context, Error};
use blstrs::Scalar as Fr;
use ff::Field;
use filecoin_hashers::{poseidon::PoseidonHasher, HashFunction, Hasher};
use filecoin_hashers::{HashFunction, Hasher};
use fr32::{bytes_into_fr, fr_into_bytes_slice};
use generic_array::typenum::Unsigned;
use log::{info, trace};
Expand All @@ -31,7 +30,6 @@ use storage_proofs_core::{
},
parameter_cache::ParameterSetMetadata,
proof::ProofScheme,
settings::SETTINGS,
};
use storage_proofs_porep::stacked::{StackedDrg, TreeRElementData};

Expand Down Expand Up @@ -729,12 +727,7 @@ where
.read_range(start..end)
.expect("failed to read from source");

// Note: The TreeR type is already constrained to
// PoseidonHasher types, but for additional clarity, we add
// this check where it's not necessary.
if SETTINGS.use_gpu_tree_builder
&& TypeId::of::<TreeR::Hasher>() == TypeId::of::<PoseidonHasher>()
{
if StackedDrg::<TreeR, TreeDHasher>::use_gpu_tree_builder() {
Ok(TreeRElementData::FrList(
tree_data.into_par_iter().map(|x| x.into()).collect(),
))
Expand Down

0 comments on commit 57955b5

Please sign in to comment.