From c28aeedc5be0610e92e95d828679803a670fccaa Mon Sep 17 00:00:00 2001 From: "Yuhsiang M. Tsai" Date: Thu, 14 Sep 2023 15:08:47 +0200 Subject: [PATCH] destroy rand_generator --- cuda/base/curand_bindings.hpp | 6 ++++++ cuda/solver/idr_kernels.cu | 1 + hip/base/hiprand_bindings.hip.hpp | 5 +++++ hip/solver/idr_kernels.hip.cpp | 1 + 4 files changed, 13 insertions(+) diff --git a/cuda/base/curand_bindings.hpp b/cuda/base/curand_bindings.hpp index 429481ec9b6..4bf12dd9064 100644 --- a/cuda/base/curand_bindings.hpp +++ b/cuda/base/curand_bindings.hpp @@ -83,6 +83,12 @@ inline curandGenerator_t rand_generator(int64 seed, } +inline void destroy(curandGenerator_t gen) +{ + GKO_ASSERT_NO_CURAND_ERRORS(curandDestroyGenerator(gen)); +} + + #define GKO_BIND_CURAND_RANDOM_VECTOR(ValueType, CurandName) \ inline void rand_vector( \ curandGenerator_t& gen, int n, remove_complex mean, \ diff --git a/cuda/solver/idr_kernels.cu b/cuda/solver/idr_kernels.cu index 10e8a7b2fc3..7bfe56987f4 100644 --- a/cuda/solver/idr_kernels.cu +++ b/cuda/solver/idr_kernels.cu @@ -104,6 +104,7 @@ void initialize_subspace_vectors(std::shared_ptr exec, gen, subspace_vectors->get_size()[0] * subspace_vectors->get_stride(), 0.0, 1.0, subspace_vectors->get_values()); + curand::destroy(gen); } } diff --git a/hip/base/hiprand_bindings.hip.hpp b/hip/base/hiprand_bindings.hip.hpp index 14e144f6d84..dfef3bb84b4 100644 --- a/hip/base/hiprand_bindings.hip.hpp +++ b/hip/base/hiprand_bindings.hip.hpp @@ -87,6 +87,11 @@ inline hiprandGenerator_t rand_generator(int64 seed, return gen; } +inline void destroy(hiprandGenerator_t gen) +{ + GKO_ASSERT_NO_HIPRAND_ERRORS(hiprandDestroyGenerator(gen)); +} + #define GKO_BIND_HIPRAND_RANDOM_VECTOR(ValueType, HiprandName) \ inline void rand_vector( \ diff --git a/hip/solver/idr_kernels.hip.cpp b/hip/solver/idr_kernels.hip.cpp index 9e6f353abe4..1a3d2931897 100644 --- a/hip/solver/idr_kernels.hip.cpp +++ b/hip/solver/idr_kernels.hip.cpp @@ -106,6 +106,7 @@ void initialize_subspace_vectors(std::shared_ptr exec, gen, subspace_vectors->get_size()[0] * subspace_vectors->get_stride(), 0.0, 1.0, subspace_vectors->get_values()); + hiprand::destroy(gen); } }