Skip to content

Commit

Permalink
Port GPU kernels for SVD to the FFI.
Browse files Browse the repository at this point in the history
Unlike the other GPU linear algebra kernels that I've ported so far, this one isn't straightforward to implement as a single kernel, and while it does support lowering without access to a GPU (no more descriptor!), it only supports dynamics shapes in the batch dimensions. There are two main technical challenges:

1. The main `gesvd` kernels in cuSolver/hipSolver only support matrices with shape `(m, n)` with `m >= n`. This means that we need to transpose the inputs and outputs as part of the lowering rule when `m < n`. (Note: we actually just use C layouts instead of Fortran layouts to implement this case.) While this could be handled in the kernel, this seemed like a lot of work for somewhat limited benefit, and it would probably have performance implications.

2. The `gesvd` and `gesvdj` kernels return `V^H` and `V` respectively, and the batched version of `gesvdj` doesn't support `full_matrices=False`. This means that we need logic in the lowering rule to handle transposition and slicing. This makes it hard to have the algorithm selection be a parameter to the kernel.

Another note: cuSolver has a 64-bit implementation of the SVD, and we always use that implementation on the CUDA backend. The 32-bit interface is included for ROCM support, and I have tested it manually. This was a feature request from jax-ml#23413.

PiperOrigin-RevId: 676839182
  • Loading branch information
dfm authored and rajasekharporeddy committed Sep 20, 2024
1 parent 28ecce3 commit e1d1425
Show file tree
Hide file tree
Showing 6 changed files with 551 additions and 21 deletions.
7 changes: 6 additions & 1 deletion jaxlib/gpu/solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ limitations under the License.
#include "nanobind/nanobind.h"
#include "nanobind/stl/pair.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/solver_handle_pool.h"
#include "jaxlib/gpu/solver_kernels.h"
Expand Down Expand Up @@ -481,6 +481,11 @@ nb::dict Registrations() {
dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi);
dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi);
dict[JAX_GPU_PREFIX "solver_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi);
dict[JAX_GPU_PREFIX "solver_gesvd_ffi"] = EncapsulateFfiHandler(GesvdFfi);

#ifdef JAX_GPU_CUDA
dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi);
#endif // JAX_GPU_CUDA

return dict;
}
Expand Down
85 changes: 85 additions & 0 deletions jaxlib/gpu/solver_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,91 @@ JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk);
JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk);
#undef JAX_GPU_DEFINE_SYRK

// Singular Value Decomposition: gesvd

#define JAX_GPU_DEFINE_GESVD(Type, Name) \
template <> \
absl::StatusOr<int> GesvdBufferSize<Type>(gpusolverDnHandle_t handle, \
signed char job, int m, int n) { \
int lwork; \
JAX_RETURN_IF_ERROR( \
JAX_AS_STATUS(Name##_bufferSize(handle, job, job, m, n, &lwork))); \
return lwork; \
} \
\
template <> \
absl::Status Gesvd<Type>(gpusolverDnHandle_t handle, signed char job, int m, \
int n, Type *a, RealType<Type>::value *s, Type *u, \
Type *vt, Type *workspace, int lwork, int *info) { \
return JAX_AS_STATUS(Name(handle, job, job, m, n, a, m, s, u, m, vt, n, \
workspace, lwork, /*rwork=*/nullptr, info)); \
}

JAX_GPU_DEFINE_GESVD(float, gpusolverDnSgesvd);
JAX_GPU_DEFINE_GESVD(double, gpusolverDnDgesvd);
JAX_GPU_DEFINE_GESVD(gpuComplex, gpusolverDnCgesvd);
JAX_GPU_DEFINE_GESVD(gpuDoubleComplex, gpusolverDnZgesvd);
#undef JAX_GPU_DEFINE_GESVD

#ifdef JAX_GPU_CUDA

#define JAX_GPU_DEFINE_GESVDJ(Type, Name) \
template <> \
absl::StatusOr<int> GesvdjBufferSize<Type>( \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, \
int n, gpuGesvdjInfo_t params) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \
handle, job, econ, m, n, /*a=*/nullptr, /*lda=*/m, /*s=*/nullptr, \
/*u=*/nullptr, /*ldu=*/m, /*v=*/nullptr, /*ldv=*/n, &lwork, params))); \
return lwork; \
} \
\
template <> \
absl::Status Gesvdj<Type>( \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, \
int n, Type *a, RealType<Type>::value *s, Type *u, Type *v, \
Type *workspace, int lwork, int *info, gpuGesvdjInfo_t params) { \
return JAX_AS_STATUS(Name(handle, job, econ, m, n, a, m, s, u, m, v, n, \
workspace, lwork, info, params)); \
}

JAX_GPU_DEFINE_GESVDJ(float, gpusolverDnSgesvdj);
JAX_GPU_DEFINE_GESVDJ(double, gpusolverDnDgesvdj);
JAX_GPU_DEFINE_GESVDJ(gpuComplex, gpusolverDnCgesvdj);
JAX_GPU_DEFINE_GESVDJ(gpuDoubleComplex, gpusolverDnZgesvdj);
#undef JAX_GPU_DEFINE_GESVDJ

#define JAX_GPU_DEFINE_GESVDJ_BATCHED(Type, Name) \
template <> \
absl::StatusOr<int> GesvdjBatchedBufferSize<Type>( \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, \
gpuGesvdjInfo_t params, int batch) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
Name##_bufferSize(handle, job, m, n, /*a=*/nullptr, /*lda=*/m, \
/*s=*/nullptr, /*u=*/nullptr, /*ldu=*/m, \
/*v=*/nullptr, /*ldv=*/n, &lwork, params, batch))); \
return lwork; \
} \
\
template <> \
absl::Status GesvdjBatched<Type>( \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, \
Type *a, RealType<Type>::value *s, Type *u, Type *v, Type *workspace, \
int lwork, int *info, gpuGesvdjInfo_t params, int batch) { \
return JAX_AS_STATUS(Name(handle, job, m, n, a, m, s, u, m, v, n, \
workspace, lwork, info, params, batch)); \
}

JAX_GPU_DEFINE_GESVDJ_BATCHED(float, gpusolverDnSgesvdjBatched);
JAX_GPU_DEFINE_GESVDJ_BATCHED(double, gpusolverDnDgesvdjBatched);
JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuComplex, gpusolverDnCgesvdjBatched);
JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched);
#undef JAX_GPU_DEFINE_GESVDJ_BATCHED

#endif // JAX_GPU_CUDA

} // namespace solver
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
43 changes: 43 additions & 0 deletions jaxlib/gpu/solver_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,49 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevd);
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syrk);
#undef JAX_GPU_SOLVER_Syrk_ARGS

// Singular Value Decomposition: gesvd

#define JAX_GPU_SOLVER_GesvdBufferSize_ARGS(Type, ...) \
gpusolverDnHandle_t handle, signed char job, int m, int n
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdBufferSize);
#undef JAX_GPU_SOLVER_GesvdBufferSize_ARGS

#define JAX_GPU_SOLVER_Gesvd_ARGS(Type, Real) \
gpusolverDnHandle_t handle, signed char job, int m, int n, Type *a, Real *s, \
Type *u, Type *vt, Type *workspace, int lwork, int *info
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvd);
#undef JAX_GPU_SOLVER_Gesvd_ARGS

#ifdef JAX_GPU_CUDA

#define JAX_GPU_SOLVER_GesvdjBufferSize_ARGS(Type, ...) \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \
gesvdjInfo_t params
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdjBufferSize);
#undef JAX_GPU_SOLVER_GesvdjBufferSize_ARGS

#define JAX_GPU_SOLVER_Gesvdj_ARGS(Type, Real) \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \
Type *a, Real *s, Type *u, Type *v, Type *workspace, \
int lwork, int *info, gesvdjInfo_t params
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvdj);
#undef JAX_GPU_SOLVER_Gesvdj_ARGS

#define JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS(Type, ...) \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, \
gpuGesvdjInfo_t params, int batch
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdjBatchedBufferSize);
#undef JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS

#define JAX_GPU_SOLVER_GesvdjBatched_ARGS(Type, Real) \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, Type *a, \
Real *s, Type *u, Type *v, Type *workspace, int lwork, \
int *info, gpuGesvdjInfo_t params, int batch
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched);
#undef JAX_GPU_SOLVER_GesvdjBatched_ARGS

#endif // JAX_GPU_CUDA

#undef JAX_GPU_SOLVER_EXPAND_DEFINITION

} // namespace solver
Expand Down
Loading

0 comments on commit e1d1425

Please sign in to comment.