Skip to content

Commit

Permalink
Add nb::arg to nanobind definitions to generate better python annot…
Browse files Browse the repository at this point in the history
…ations.

PiperOrigin-RevId: 586721759
  • Loading branch information
viswanadha9 authored and jax authors committed Nov 30, 2023
1 parent 11d7a2b commit bd46e5c
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 54 deletions.
54 changes: 26 additions & 28 deletions jaxlib/cpu/_lapack.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

def cgesdd_rwork_size(*args, **kwargs) -> Any: ...
def cgesdd_work_size(*args, **kwargs) -> Any: ...
def dgesdd_work_size(*args, **kwargs) -> Any: ...
def gesdd_iwork_size(*args, **kwargs) -> Any: ...
def heevd_rwork_size(*args, **kwargs) -> Any: ...
def heevd_work_size(*args, **kwargs) -> Any: ...
def cgesdd_rwork_size(m: int, n: int, compute_uv: int) -> int: ...
def cgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ...
def dgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ...
def gesdd_iwork_size(m: int, n: int) -> int: ...
def heevd_rwork_size(n: int) -> int: ...
def heevd_work_size(n: int) -> int: ...
def initialize() -> None: ...
def lapack_cgehrd_workspace(*args, **kwargs) -> Any: ...
def lapack_cgeqrf_workspace(*args, **kwargs) -> Any: ...
def lapack_chetrd_workspace(*args, **kwargs) -> Any: ...
def lapack_cungqr_workspace(*args, **kwargs) -> Any: ...
def lapack_dgehrd_workspace(*args, **kwargs) -> Any: ...
def lapack_dgeqrf_workspace(*args, **kwargs) -> Any: ...
def lapack_dorgqr_workspace(*args, **kwargs) -> Any: ...
def lapack_dsytrd_workspace(*args, **kwargs) -> Any: ...
def lapack_sgehrd_workspace(*args, **kwargs) -> Any: ...
def lapack_sgeqrf_workspace(*args, **kwargs) -> Any: ...
def lapack_sorgqr_workspace(*args, **kwargs) -> Any: ...
def lapack_ssytrd_workspace(*args, **kwargs) -> Any: ...
def lapack_zgehrd_workspace(*args, **kwargs) -> Any: ...
def lapack_zgeqrf_workspace(*args, **kwargs) -> Any: ...
def lapack_zhetrd_workspace(*args, **kwargs) -> Any: ...
def lapack_zungqr_workspace(*args, **kwargs) -> Any: ...
def lapack_cgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ...
def lapack_cgeqrf_workspace(m: int, n: int) -> int: ...
def lapack_chetrd_workspace(lda: int, n: int) -> int: ...
def lapack_cungqr_workspace(m: int, n: int, k: int) -> int: ...
def lapack_dgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ...
def lapack_dgeqrf_workspace(m: int, n: int) -> int: ...
def lapack_dorgqr_workspace(m: int, n: int, k: int) -> int: ...
def lapack_dsytrd_workspace(lda: int, n: int) -> int: ...
def lapack_sgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ...
def lapack_sgeqrf_workspace(m: int, n: int) -> int: ...
def lapack_sorgqr_workspace(m: int, n: int, k: int) -> int: ...
def lapack_ssytrd_workspace(lda: int, n: int) -> int: ...
def lapack_zgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ...
def lapack_zgeqrf_workspace(m: int, n: int) -> int: ...
def lapack_zhetrd_workspace(lda: int, n: int) -> int: ...
def lapack_zungqr_workspace(m: int, n: int, k: int) -> int: ...
def registrations() -> dict: ...
def sgesdd_work_size(*args, **kwargs) -> Any: ...
def syevd_iwork_size(*args, **kwargs) -> Any: ...
def syevd_work_size(*args, **kwargs) -> Any: ...
def zgesdd_work_size(*args, **kwargs) -> Any: ...
def sgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ...
def syevd_iwork_size(n: int) -> int: ...
def syevd_work_size(n: int) -> int: ...
def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ...
78 changes: 52 additions & 26 deletions jaxlib/cpu/lapack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,32 +230,58 @@ NB_MODULE(_lapack, m) {
m.def("initialize", GetLapackKernelsFromScipy);

m.def("registrations", &Registrations);
m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace);
m.def("lapack_dgeqrf_workspace", &Geqrf<double>::Workspace);
m.def("lapack_cgeqrf_workspace", &Geqrf<std::complex<float>>::Workspace);
m.def("lapack_zgeqrf_workspace", &Geqrf<std::complex<double>>::Workspace);
m.def("lapack_sorgqr_workspace", &Orgqr<float>::Workspace);
m.def("lapack_dorgqr_workspace", &Orgqr<double>::Workspace);
m.def("lapack_cungqr_workspace", &Orgqr<std::complex<float>>::Workspace);
m.def("lapack_zungqr_workspace", &Orgqr<std::complex<double>>::Workspace);
m.def("gesdd_iwork_size", &GesddIworkSize);
m.def("sgesdd_work_size", &RealGesdd<float>::Workspace);
m.def("dgesdd_work_size", &RealGesdd<double>::Workspace);
m.def("cgesdd_rwork_size", &ComplexGesddRworkSize);
m.def("cgesdd_work_size", &ComplexGesdd<std::complex<float>>::Workspace);
m.def("zgesdd_work_size", &ComplexGesdd<std::complex<double>>::Workspace);
m.def("syevd_work_size", &SyevdWorkSize);
m.def("syevd_iwork_size", &SyevdIworkSize);
m.def("heevd_work_size", &HeevdWorkSize);
m.def("heevd_rwork_size", &HeevdRworkSize);
m.def("lapack_sgehrd_workspace", &Gehrd<float>::Workspace);
m.def("lapack_dgehrd_workspace", &Gehrd<double>::Workspace);
m.def("lapack_cgehrd_workspace", &Gehrd<std::complex<float>>::Workspace);
m.def("lapack_zgehrd_workspace", &Gehrd<std::complex<double>>::Workspace);
m.def("lapack_ssytrd_workspace", &Sytrd<float>::Workspace);
m.def("lapack_dsytrd_workspace", &Sytrd<double>::Workspace);
m.def("lapack_chetrd_workspace", &Sytrd<std::complex<float>>::Workspace);
m.def("lapack_zhetrd_workspace", &Sytrd<std::complex<double>>::Workspace);
m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),
nb::arg("n"));
m.def("lapack_dgeqrf_workspace", &Geqrf<double>::Workspace, nb::arg("m"),
nb::arg("n"));
m.def("lapack_cgeqrf_workspace", &Geqrf<std::complex<float>>::Workspace,
nb::arg("m"), nb::arg("n"));
m.def("lapack_zgeqrf_workspace", &Geqrf<std::complex<double>>::Workspace,
nb::arg("m"), nb::arg("n"));
m.def("lapack_sorgqr_workspace", &Orgqr<float>::Workspace, nb::arg("m"),
nb::arg("n"), nb::arg("k"));
m.def("lapack_dorgqr_workspace", &Orgqr<double>::Workspace, nb::arg("m"),
nb::arg("n"), nb::arg("k"));
m.def("lapack_cungqr_workspace", &Orgqr<std::complex<float>>::Workspace,
nb::arg("m"), nb::arg("n"), nb::arg("k"));
m.def("lapack_zungqr_workspace", &Orgqr<std::complex<double>>::Workspace,
nb::arg("m"), nb::arg("n"), nb::arg("k"));
m.def("gesdd_iwork_size", &GesddIworkSize, nb::arg("m"), nb::arg("n"));
m.def("sgesdd_work_size", &RealGesdd<float>::Workspace, nb::arg("m"),
nb::arg("n"), nb::arg("job_opt_compute_uv"),
nb::arg("job_opt_full_matrices"));
m.def("dgesdd_work_size", &RealGesdd<double>::Workspace, nb::arg("m"),
nb::arg("n"), nb::arg("job_opt_compute_uv"),
nb::arg("job_opt_full_matrices"));
m.def("cgesdd_rwork_size", &ComplexGesddRworkSize, nb::arg("m"), nb::arg("n"),
nb::arg("compute_uv"));
m.def("cgesdd_work_size", &ComplexGesdd<std::complex<float>>::Workspace,
nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"),
nb::arg("job_opt_full_matrices"));
m.def("zgesdd_work_size", &ComplexGesdd<std::complex<double>>::Workspace,
nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"),
nb::arg("job_opt_full_matrices"));
m.def("syevd_work_size", &SyevdWorkSize, nb::arg("n"));
m.def("syevd_iwork_size", &SyevdIworkSize, nb::arg("n"));
m.def("heevd_work_size", &HeevdWorkSize, nb::arg("n"));
m.def("heevd_rwork_size", &HeevdRworkSize, nb::arg("n"));

m.def("lapack_sgehrd_workspace", &Gehrd<float>::Workspace, nb::arg("lda"),
nb::arg("n"), nb::arg("ilo"), nb::arg("ihi"));
m.def("lapack_dgehrd_workspace", &Gehrd<double>::Workspace, nb::arg("lda"),
nb::arg("n"), nb::arg("ilo"), nb::arg("ihi"));
m.def("lapack_cgehrd_workspace", &Gehrd<std::complex<float>>::Workspace,
nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi"));
m.def("lapack_zgehrd_workspace", &Gehrd<std::complex<double>>::Workspace,
nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi"));
m.def("lapack_ssytrd_workspace", &Sytrd<float>::Workspace, nb::arg("lda"),
nb::arg("n"));
m.def("lapack_dsytrd_workspace", &Sytrd<double>::Workspace, nb::arg("lda"),
nb::arg("n"));
m.def("lapack_chetrd_workspace", &Sytrd<std::complex<float>>::Workspace,
nb::arg("lda"), nb::arg("n"));
m.def("lapack_zhetrd_workspace", &Sytrd<std::complex<double>>::Workspace,
nb::arg("lda"), nb::arg("n"));
}

} // namespace
Expand Down

0 comments on commit bd46e5c

Please sign in to comment.