Skip to content

Commit

Permalink
Merge Fix missing destruction of rand_generator from cuda/hip
Browse files Browse the repository at this point in the history
This PR fixes the missing destruction of rand_generator from cuda/hip

Related PR: #1417
  • Loading branch information
yhmtsai authored Oct 12, 2023
2 parents 6b0acfb + c28aeed commit d3dd178
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 0 deletions.
6 changes: 6 additions & 0 deletions cuda/base/curand_bindings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ inline curandGenerator_t rand_generator(int64 seed,
}


inline void destroy(curandGenerator_t gen)
{
GKO_ASSERT_NO_CURAND_ERRORS(curandDestroyGenerator(gen));
}


#define GKO_BIND_CURAND_RANDOM_VECTOR(ValueType, CurandName) \
inline void rand_vector( \
curandGenerator_t& gen, int n, remove_complex<ValueType> mean, \
Expand Down
1 change: 1 addition & 0 deletions cuda/solver/idr_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ void initialize_subspace_vectors(std::shared_ptr<const DefaultExecutor> exec,
gen,
subspace_vectors->get_size()[0] * subspace_vectors->get_stride(),
0.0, 1.0, subspace_vectors->get_values());
curand::destroy(gen);
}
}
Expand Down
5 changes: 5 additions & 0 deletions hip/base/hiprand_bindings.hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ inline hiprandGenerator_t rand_generator(int64 seed,
return gen;
}

inline void destroy(hiprandGenerator_t gen)
{
GKO_ASSERT_NO_HIPRAND_ERRORS(hiprandDestroyGenerator(gen));
}


#define GKO_BIND_HIPRAND_RANDOM_VECTOR(ValueType, HiprandName) \
inline void rand_vector( \
Expand Down
1 change: 1 addition & 0 deletions hip/solver/idr_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ void initialize_subspace_vectors(std::shared_ptr<const DefaultExecutor> exec,
gen,
subspace_vectors->get_size()[0] * subspace_vectors->get_stride(),
0.0, 1.0, subspace_vectors->get_values());
hiprand::destroy(gen);
}
}

Expand Down

0 comments on commit d3dd178

Please sign in to comment.