diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index 3605c2da..3207f7ce 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -5,17 +5,18 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -pub trait Eigh: Sized { - type Elem: Scalar; - +pub(crate) trait Eigh: Scalar { /// Allocate working memory for eigenvalue problem - fn eigh_work(calc_eigenvec: bool, layout: MatrixLayout, uplo: UPLO) -> Result; + fn eigh_work(calc_eigenvec: bool, layout: MatrixLayout, uplo: UPLO) -> Result>; /// Solve eigenvalue problem - fn eigh_calc(&mut self, a: &mut [Self::Elem]) -> Result<&[::Real]>; + fn eigh_calc<'work>( + work: &'work mut EighWork, + a: &mut [Self], + ) -> Result<&'work [Self::Real]>; } -/// Working memory for symmetric/Hermitian eigenvalue problem. See [Eigh trait](trait.Eigh.html) +/// Working memory for symmetric/Hermitian eigenvalue problem. See [LapackStrict trait](trait.LapackStrict.html) pub struct EighWork { jobz: u8, uplo: UPLO, @@ -29,17 +30,15 @@ pub struct EighWork { macro_rules! impl_eigh_work_real { ($scalar:ty, $ev:path) => { - impl Eigh for EighWork<$scalar> { - type Elem = $scalar; - - fn eigh_work(calc_v: bool, layout: MatrixLayout, uplo: UPLO) -> Result { + impl Eigh for $scalar { + fn eigh_work(calc_v: bool, layout: MatrixLayout, uplo: UPLO) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); let jobz = if calc_v { b'V' } else { b'N' }; let mut eigs = unsafe { vec_uninit(n as usize) }; let mut info = 0; - let mut work_size = [Self::Elem::zero()]; + let mut work_size = [Self::zero()]; unsafe { $ev( jobz, @@ -66,28 +65,28 @@ macro_rules! impl_eigh_work_real { }) } - fn eigh_calc( - &mut self, - a: &mut [Self::Elem], - ) -> Result<&[::Real]> { - assert_eq!(a.len(), (self.n * self.n) as usize); + fn eigh_calc<'work>( + work: &'work mut EighWork, + a: &mut [Self], + ) -> Result<&'work [Self::Real]> { + assert_eq!(a.len(), (work.n * work.n) as usize); let mut info = 0; - let lwork = self.work.len() as i32; + let lwork = work.work.len() as i32; unsafe { $ev( - self.jobz, - self.uplo as u8, - self.n, + work.jobz, + work.uplo as u8, + work.n, a, - self.n, - &mut self.eigs, - &mut self.work, + work.n, + &mut work.eigs, + &mut work.work, lwork, &mut info, ); } info.as_lapack_result()?; - Ok(&self.eigs) + Ok(&work.eigs) } } }; @@ -98,10 +97,8 @@ impl_eigh_work_real!(f64, lapack::dsyev); macro_rules! impl_eigh_work_complex { ($scalar:ty, $ev:path) => { - impl Eigh for EighWork<$scalar> { - type Elem = $scalar; - - fn eigh_work(calc_v: bool, layout: MatrixLayout, uplo: UPLO) -> Result { + impl Eigh for $scalar { + fn eigh_work(calc_v: bool, layout: MatrixLayout, uplo: UPLO) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); let jobz = if calc_v { b'V' } else { b'N' }; @@ -109,7 +106,7 @@ macro_rules! impl_eigh_work_complex { let mut a = []; let mut info = 0; - let mut work_size = [Self::Elem::zero()]; + let mut work_size = [Self::zero()]; let mut rwork = Vec::with_capacity(3 * n as usize - 2); unsafe { $ev( @@ -138,29 +135,29 @@ macro_rules! impl_eigh_work_complex { }) } - fn eigh_calc( - &mut self, - a: &mut [Self::Elem], - ) -> Result<&[::Real]> { - assert_eq!(a.len(), (self.n * self.n) as usize); + fn eigh_calc<'work>( + work: &'work mut EighWork, + a: &mut [Self], + ) -> Result<&'work [Self::Real]> { + assert_eq!(a.len(), (work.n * work.n) as usize); let mut info = 0; - let lwork = self.work.len() as i32; + let lwork = work.work.len() as i32; unsafe { $ev( - self.jobz, - self.uplo as u8, - self.n, + work.jobz, + work.uplo as u8, + work.n, a, - self.n, - &mut self.eigs, - &mut self.work, + work.n, + &mut work.eigs, + &mut work.work, lwork, - self.rwork.as_mut().unwrap(), + work.rwork.as_mut().unwrap(), &mut info, ); } info.as_lapack_result()?; - Ok(&self.eigs) + Ok(&work.eigs) } } }; diff --git a/lax/src/eigh_generalized.rs b/lax/src/eigh_generalized.rs index 61469c95..fa34ee1a 100644 --- a/lax/src/eigh_generalized.rs +++ b/lax/src/eigh_generalized.rs @@ -3,36 +3,25 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -/// Types of generalized eigenvalue problem -#[allow(dead_code)] // FIXME create interface to use ABxlx and BAxlx -#[repr(i32)] -pub enum ITYPE { - /// Solve $ A x = \lambda B x $ - AxlBx = 1, - /// Solve $ A B x = \lambda x $ - ABxlx = 2, - /// Solve $ B A x = \lambda x $ - BAxlx = 3, -} - /// Generalized eigenvalue problem for Symmetric/Hermite matrices -pub trait EighGeneralized: Sized { - type Elem: Scalar; - +pub(crate) trait EighGeneralized: Scalar { /// Allocate working memory - fn eigh_generalized_work(calc_eigenvec: bool, layout: MatrixLayout, uplo: UPLO) - -> Result; + fn eigh_generalized_work( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + ) -> Result>; /// Solve generalized eigenvalue problem - fn eigh_generalized_calc( - &mut self, - a: &mut [Self::Elem], - b: &mut [Self::Elem], - ) -> Result<&[::Real]>; + fn eigh_generalized_calc<'work>( + work: &'work mut EighGeneralizedWork, + a: &mut [Self], + b: &mut [Self], + ) -> Result<&'work [Self::Real]>; } /// Working memory for symmetric/Hermitian generalized eigenvalue problem. -/// See [EighGeneralized trait](trait.EighGeneralized.html) +/// See [LapackStrict trait](trait.LapackStrict.html) pub struct EighGeneralizedWork { jobz: u8, uplo: UPLO, @@ -46,21 +35,19 @@ pub struct EighGeneralizedWork { macro_rules! impl_eigh_work_real { ($scalar:ty, $ev:path) => { - impl EighGeneralized for EighGeneralizedWork<$scalar> { - type Elem = $scalar; - + impl EighGeneralized for $scalar { fn eigh_generalized_work( calc_v: bool, layout: MatrixLayout, uplo: UPLO, - ) -> Result { + ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); let jobz = if calc_v { b'V' } else { b'N' }; let mut eigs = unsafe { vec_uninit(n as usize) }; let mut info = 0; - let mut work_size = [Self::Elem::zero()]; + let mut work_size = [Self::zero()]; unsafe { $ev( &[ITYPE::AxlBx as i32], @@ -90,32 +77,32 @@ macro_rules! impl_eigh_work_real { }) } - fn eigh_generalized_calc( - &mut self, - a: &mut [Self::Elem], - b: &mut [Self::Elem], - ) -> Result<&[::Real]> { - assert_eq!(a.len(), (self.n * self.n) as usize); + fn eigh_generalized_calc<'work>( + work: &'work mut EighGeneralizedWork, + a: &mut [Self], + b: &mut [Self], + ) -> Result<&'work [Self::Real]> { + assert_eq!(a.len(), (work.n * work.n) as usize); let mut info = 0; - let lwork = self.work.len() as i32; + let lwork = work.work.len() as i32; unsafe { $ev( &[ITYPE::AxlBx as i32], - self.jobz, - self.uplo as u8, - self.n, + work.jobz, + work.uplo as u8, + work.n, a, - self.n, + work.n, b, - self.n, - &mut self.eigs, - &mut self.work, + work.n, + &mut work.eigs, + &mut work.work, lwork, &mut info, ); } info.as_lapack_result()?; - Ok(&self.eigs) + Ok(&work.eigs) } } }; @@ -126,14 +113,12 @@ impl_eigh_work_real!(f64, lapack::dsygv); macro_rules! impl_eigh_work_complex { ($scalar:ty, $ev:path) => { - impl EighGeneralized for EighGeneralizedWork<$scalar> { - type Elem = $scalar; - + impl EighGeneralized for $scalar { fn eigh_generalized_work( calc_v: bool, layout: MatrixLayout, uplo: UPLO, - ) -> Result { + ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); let jobz = if calc_v { b'V' } else { b'N' }; @@ -142,7 +127,7 @@ macro_rules! impl_eigh_work_complex { let mut eigs = unsafe { vec_uninit(n as usize) }; let mut info = 0; - let mut work_size = [Self::Elem::zero()]; + let mut work_size = [Self::zero()]; let mut rwork = unsafe { vec_uninit(3 * n as usize - 2) }; unsafe { $ev( @@ -174,33 +159,33 @@ macro_rules! impl_eigh_work_complex { }) } - fn eigh_generalized_calc( - &mut self, - a: &mut [Self::Elem], - b: &mut [Self::Elem], - ) -> Result<&[::Real]> { - assert_eq!(a.len(), (self.n * self.n) as usize); + fn eigh_generalized_calc<'work>( + work: &'work mut EighGeneralizedWork, + a: &mut [Self], + b: &mut [Self], + ) -> Result<&'work [Self::Real]> { + assert_eq!(a.len(), (work.n * work.n) as usize); let mut info = 0; - let lwork = self.work.len() as i32; + let lwork = work.work.len() as i32; unsafe { $ev( &[ITYPE::AxlBx as i32], - self.jobz, - self.uplo as u8, - self.n, + work.jobz, + work.uplo as u8, + work.n, a, - self.n, + work.n, b, - self.n, - &mut self.eigs, - &mut self.work, + work.n, + &mut work.eigs, + &mut work.work, lwork, - self.rwork.as_mut().unwrap(), + work.rwork.as_mut().unwrap(), &mut info, ); } info.as_lapack_result()?; - Ok(&self.eigs) + Ok(&work.eigs) } } }; diff --git a/lax/src/lib.rs b/lax/src/lib.rs index bea5d2dc..99939913 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -81,6 +81,7 @@ mod qr; mod rcond; mod solve; mod solveh; +mod strict; mod svd; mod svddc; mod traits; @@ -96,6 +97,7 @@ pub use self::qr::*; pub use self::rcond::*; pub use self::solve::*; pub use self::solveh::*; +pub use self::strict::*; pub use self::svd::*; pub use self::svddc::*; pub use self::traits::*; @@ -147,6 +149,18 @@ impl NormType { } } +/// Types of generalized eigenvalue problem +#[allow(dead_code)] // FIXME create interface to use ABxlx and BAxlx +#[repr(i32)] +pub enum ITYPE { + /// Solve $ A x = \lambda B x $ + AxlBx = 1, + /// Solve $ A B x = \lambda x $ + ABxlx = 2, + /// Solve $ B A x = \lambda x $ + BAxlx = 3, +} + /// Create a vector without initialization /// /// Safety diff --git a/lax/src/strict.rs b/lax/src/strict.rs new file mode 100644 index 00000000..30a9deb2 --- /dev/null +++ b/lax/src/strict.rs @@ -0,0 +1,80 @@ +use crate::{error::*, layout::*, *}; +use cauchy::*; + +pub trait LapackStrict: Scalar { + /// Allocate working memory for eigenvalue problem $A x = \lambda x$ + fn eigh_work(calc_eigenvec: bool, layout: MatrixLayout, uplo: UPLO) -> Result>; + + /// Solve eigenvalue problem $A x = \lambda x$ using allocated working memory + fn eigh_calc<'work>( + work: &'work mut EighWork, + a: &mut [Self], + ) -> Result<&'work [Self::Real]>; + + /// Allocate working memory for generalized eigenvalue problem $Ax = \lambda Bx$ + fn eigh_generalized_work( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + ) -> Result>; + + /// Solve generalized eigenvalue problem $Ax = \lambda Bx$ using allocated working memory + fn eigh_generalized_calc<'work>( + work: &'work mut EighGeneralizedWork, + a: &mut [Self], + b: &mut [Self], + ) -> Result<&'work [Self::Real]>; +} + +macro_rules! impl_lapack_strict_component { + ($impl_trait:path; fn $name:ident $(<$lt:lifetime>)* ( $( $arg_name:ident : $arg_type:ty, )*) -> $result:ty ;) => { + fn $name $(<$lt>)* ($($arg_name:$arg_type,)*) -> $result { + ::$name($($arg_name),*) + } + }; +} + +macro_rules! impl_lapack_strict { + ($scalar:ty) => { + impl LapackStrict for $scalar { + impl_lapack_strict_component!( + Eigh; + fn eigh_work( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + ) -> Result>; + ); + impl_lapack_strict_component!( + Eigh; + fn eigh_calc<'work>( + work: &'work mut EighWork, + a: &mut [Self], + ) -> Result<&'work [Self::Real]>; + ); + + impl_lapack_strict_component! ( + EighGeneralized; + fn eigh_generalized_work( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + ) -> Result>; + ); + + impl_lapack_strict_component! ( + EighGeneralized; + fn eigh_generalized_calc<'work>( + work: &'work mut EighGeneralizedWork, + a: &mut [Self], + b: &mut [Self], + ) -> Result<&'work [Self::Real]>; + ); + } + }; +} + +impl_lapack_strict!(f32); +impl_lapack_strict!(f64); +impl_lapack_strict!(c32); +impl_lapack_strict!(c64);