diff --git a/Cargo.toml b/Cargo.toml index 9a269c2..793b5ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,5 +18,5 @@ num-traits = { version = "0.2", default-features = false } [features] default = ["std"] -std = ["arrayvec/default", "num-complex/default", "num-traits/default"] -libm = ["num-traits/libm"] +std = ["num-traits/default", "num-complex/default"] +libm = ["num-traits/libm", "num-complex/libm"] diff --git a/README.md b/README.md index ff62382..83bb5b2 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,9 @@ negative charges and the true zeros as positive charges. This enables finding all complex roots simultaneously, converging cubically (worst-case it converges linearly for zeros of multiplicity). -This crate is `#![no_std]` and tries to have minimal dependencies: -- [num-complex](https://crates.io/crates/num-complex) for Complex number types -- [arrayvec](https://crates.io/crates/arrayvec) to avoid allocations -- [num-traits](https://crates.io/crates/num-traits) to be generic over floating -point types. +This crate is `#![no_std]` and tries to have minimal dependencies. It uses +[arrayvec](https://crates.io/crates/arrayvec) to avoid allocations, which will +be removed in the future when rust stabilises support for const-generics. Usage @@ -30,11 +28,12 @@ and then call the `aberth` method on your polynomial. ```rust use aberth::aberth; const EPSILON: f32 = 0.001; +const MAX_ITERATIONS: u32 = 10; // 0 = -1 + 2x + 4x^4 + 11x^9 let polynomial = [-1., 2., 0., 0., 4., 0., 0., 0., 0., 11.]; -let roots = aberth(&polynomial, EPSILON).unwrap(); +let roots = aberth(&polynomial, MAX_ITERATIONS, EPSILON); // [ // Complex { re: 0.4293261, im: 1.084202e-19 }, // Complex { re: 0.7263235, im: 0.4555030 }, @@ -48,18 +47,41 @@ let roots = aberth(&polynomial, EPSILON).unwrap(); // ] ``` -Note that the returned values are not sorted in any particular order. +The above method does not require any allocation, instead doing all the +computation on the stack. It is generic over any size of polynomial, but the +size of the polynomial must be known at compile time. + +If `std` is available then there is also an `AberthSolver` struct which +allocates some memory to support dynamically sized polynomials at run time. +This may also be good to use when you are dealing with polynomials with many +terms, as it uses the heap instead of blowing up the stack. + +```rust +use aberth::AberthSolver; + +let mut solver = AberthSolver::new(); +solver.epsilon = 0.001; +solver.max_iterations = 10; -The function returns an `Err` if it fails to converge after 100 cycles. This -is excessive. Even polynomials of degree 20 will converge after <20 iterations. -Using a larger epsilon can usually avoid these errors. +// 0 = -1 + 2x + 4x^3 + 11x^4 +let a = [-1., 2., 0., 4., 11.]; +// 0 = -28 + 39x^2 - 12x^3 + x^4 +let b = [-28., 0., 39., -12., 1.]; + +for polynomial in [a, b] { + let roots = solver.find_roots(&polynomial); + // ... +} +``` + +Note that the returned values are not sorted in any particular order. The coefficient of the highest degree term should not be zero or you will get nonsense extra roots (probably at 0 + 0j). `#![no_std]` --------- +------------ To use in a `no_std` environment you must disable `default-features` and enable the `libm` feature: diff --git a/src/internal.rs b/src/internal.rs new file mode 100644 index 0000000..e395642 --- /dev/null +++ b/src/internal.rs @@ -0,0 +1,233 @@ +use super::StopReason; +use core::{fmt::Debug, iter::zip}; +use num_complex::{Complex, ComplexFloat}; +use num_traits::{ + cast, + float::{Float, FloatConst}, + identities::{One, Zero}, + MulAdd, +}; + +/// Find all of the roots of a polynomial using Aberth's method. +pub fn aberth_raw>( + polynomial: &[Complex], + dydx: &[Complex], + initial_guesses: &mut [Complex], + out: &mut [Complex], + max_iterations: u32, + epsilon: F, +) -> StopReason { + out.copy_from_slice(initial_guesses); + let mut zs = initial_guesses; + let mut new_zs = out; + + for iteration in 1..=max_iterations { + let mut converged = true; + + for i in 0..zs.len() { + let p_of_z = sample_polynomial(polynomial, zs[i]); + let dydx_of_z = sample_polynomial(dydx, zs[i]); + let sum = (0..zs.len()) + .filter(|&k| k != i) + .fold(Complex::::zero(), |acc, k| { + acc + Complex::::one() / (zs[i] - zs[k]) + }); + + let new_z = zs[i] + p_of_z / (p_of_z * sum - dydx_of_z); + new_zs[i] = new_z; + + if new_z.re.is_nan() + || new_z.im.is_nan() + || new_z.re.is_infinite() + || new_z.im.is_infinite() + { + return StopReason::Failed(iteration); + } + + if !new_z.approx_eq(zs[i], epsilon) { + converged = false; + } + } + if converged { + return StopReason::Converged(iteration); + } + core::mem::swap(&mut zs, &mut new_zs); + } + StopReason::MaxIteration(max_iterations) +} + +/// Sample the polynomial at some value of `x` using Horner's method. +/// +/// Polynomial of the form f(x) = a + b*x + c*x^2 + d*x^3 + ... +/// +/// `coefficients` is a slice containing the coefficients [a, b, c, d, ...] +pub(crate) fn sample_polynomial>( + coefficients: &[Complex], + x: Complex, +) -> Complex { + #![allow(clippy::len_zero)] + debug_assert!(coefficients.len() != 0); + let mut r = Complex::zero(); + for &c in coefficients.iter().rev() { + r = r.mul_add(x, c) + } + r +} + +/// Compute the derivative of a polynomial. +/// +/// Polynomial of the form f(x) = a + b*x + c*x^2 + d*x^3 + ... +/// +/// `coefficients` is a slice containing the coefficients [a, b, c, d, ...] +/// starting from the coefficient of the x with degree 0. +pub(crate) fn derivative( + polynomial: &[Complex], + out: &mut [Complex], +) { + polynomial + .iter() + .enumerate() + .skip(1) + .for_each(|(index, coefficient)| { + // SAFETY: it's possible to cast any usize to a float + let p = unsafe { F::from(index).unwrap_unchecked() }; + out[index - 1] = coefficient * p; + }) +} + +// Initial guesses using the method from "Iteration Methods for Finding all +// Zeros of a Polynomial Simultaneously" by Oliver Aberth. +pub(crate) fn initial_guesses< + F: Float + FloatConst + MulAdd + Debug, +>( + polynomial: &[Complex], + out: &mut [Complex], +) { + // the degree of the polynomial + let n = polynomial.len() - 1; + // SAFETY: it's possible to cast any usize to a float + let n_f: F = unsafe { cast(n).unwrap_unchecked() }; + // convert polynomial to monic form + let monic = out; + for (i, c) in polynomial.iter().enumerate() { + monic[i] = c / polynomial[n]; // TODO: check this divide by zero + } + // let a = - c_1 / n + let a: Complex = -monic[n - 1] / n_f; + // let z = w + a, + let p_of_w = { + // we can recycle monic on the fly. + for coefficient_index in 0..=n { + let c = monic[coefficient_index]; + monic[coefficient_index] = Complex::zero(); + for ((index, power), pascal) in zip( + zip(0..=coefficient_index, (0..=coefficient_index).rev()), + PascalRowIter::new(coefficient_index as u32), + ) { + // SAFETY: it's possible to cast any u32 to a float + let pascal: Complex = unsafe { cast(pascal).unwrap_unchecked() }; + monic[index] = + MulAdd::mul_add(c, pascal * a.powi(power as i32), monic[index]); + } + } + monic + }; + // convert P(w) into S(w) + let s_of_w = { + // skip the last coefficient + p_of_w.iter_mut().take(n).for_each(|coefficient| { + *coefficient = Complex::from(-coefficient.abs()) + }); + p_of_w + }; + // find r_0 + let mut int = F::one(); + let r_0 = loop { + let s_at_r0 = sample_polynomial(s_of_w, int.into()); + if s_at_r0.re > F::zero() { + break int; + } + int = int + F::one(); + }; + + { + let guesses = s_of_w; // output + + let frac_2pi_n = F::TAU() / n_f; + let frac_pi_2n = F::FRAC_PI_2() / n_f; + + for (k, guess) in guesses.iter_mut().enumerate().take(n) { + // SAFETY: it's possible to cast any usize to a float + let k_f = unsafe { cast(k).unwrap_unchecked() }; + let theta = MulAdd::mul_add(frac_2pi_n, k_f, frac_pi_2n); + + let real = r_0 * theta.cos(); + let imaginary = r_0 * theta.sin(); + + let val = Complex::new(real, imaginary) + a; + *guess = val; + } + } +} + +/// An iterator over the numbers in a row of Pascal's Triangle. +pub(crate) struct PascalRowIter { + n: u32, + k: u32, + previous: u32, +} + +impl PascalRowIter { + /// Create an iterator yielding the numbers in the nth row of Pascal's + /// triangle. + pub fn new(n: u32) -> Self { + Self { + n, + k: 0, + previous: 1, + } + } +} + +impl Iterator for PascalRowIter { + type Item = u32; + + fn next(&mut self) -> Option { + if self.k == 0 { + self.k = 1; + self.previous = 1; + return Some(1); + } + if self.k > self.n { + return None; + } + let new = self.previous * (self.n + 1 - self.k) / self.k; + self.k += 1; + self.previous = new; + Some(new) + } +} + +/// Some extra methods for Complex numbers +pub(crate) trait ComplexExt { + fn approx_eq(self, w: Self, epsilon: F) -> bool; +} + +impl ComplexExt for Complex { + /// Cheap comparison of complex numbers to within some margin, epsilon. + #[inline] + fn approx_eq(self, w: Complex, epsilon: F) -> bool { + (self.re - w.re).abs() < epsilon && (self.im - w.im).abs() < epsilon + } +} + +pub(crate) use private::ComplexCoefficient; +mod private { + use super::*; + /// A trait to group real & complex float types into a single generic type + pub trait ComplexCoefficient: Copy + Into> {} + impl ComplexCoefficient for f32 {} + impl ComplexCoefficient for f64 {} + impl ComplexCoefficient for Complex {} + impl ComplexCoefficient for Complex {} +} diff --git a/src/lib.rs b/src/lib.rs index de7ff51..b914343 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,564 +1,260 @@ -#![no_std] #![doc = include_str!("../README.md")] -#![allow(clippy::len_zero)] +#![cfg_attr(not(any(feature = "std", test, doctest)), no_std)] + +pub mod internal; +use internal::*; +pub use num_complex::Complex; + +#[cfg(any(test, doctest))] +mod tests; use arrayvec::ArrayVec; -use core::iter::zip; -use num_complex::Complex; +use core::{fmt::Debug, ops::Deref}; use num_traits::{ - cast, float::{Float, FloatConst}, - identities::{One, Zero}, + identities::Zero, MulAdd, }; -/// Find all of the roots of a polynomial using Aberth's method. +/// Find all of the roots of a polynomial using Aberth's method /// -/// Polynomial of the form f(x) = a + b*x + c*x^2 + d*x^3 + ... +/// Polynomial of the form `f(x) = a + b*x + c*x^2 + d*x^3 + ...` /// -/// `polynomial` is a slice containing the coefficients [a, b, c, d, ...] +/// `polynomial` is a slice containing the coefficients `[a, b, c, d, ...]` /// /// When two successive iterations produce roots with less than `epsilon` /// delta, the roots are returned. pub fn aberth< const TERMS: usize, - F: Float + FloatConst + MulAdd, + F: Float + FloatConst + MulAdd + Debug, + C: ComplexCoefficient + Into>, >( - polynomial: &[F; TERMS], + polynomial: &[C; TERMS], + max_iterations: u32, epsilon: F, -) -> Result, TERMS>, &'static str> { - let dydx = &derivative(polynomial); - let mut zs: ArrayVec, TERMS> = initial_guesses(polynomial); - let mut new_zs = zs.clone(); - - 'iteration: for _ in 0..100 { - let mut converged = true; +) -> Roots, TERMS>> { + let degree = TERMS - 1; - for i in 0..zs.len() { - let p_of_z = sample_polynomial(polynomial, zs[i]); - let dydx_of_z = sample_polynomial(dydx, zs[i]); - let sum = (0..zs.len()) - .filter(|&k| k != i) - .fold(Complex::::zero(), |acc, k| { - acc + Complex::::one() / (zs[i] - zs[k]) - }); - - let new_z = zs[i] + p_of_z / (p_of_z * sum - dydx_of_z); - new_zs[i] = new_z; - - if new_z.re.is_nan() - || new_z.im.is_nan() - || new_z.re.is_infinite() - || new_z.im.is_infinite() - { - break 'iteration; - } + let polynomial: &[Complex; TERMS] = &polynomial.map(|v| v.into()); - if !new_z.approx_eq(zs[i], epsilon) { - converged = false; - } - } - if converged { - return Ok(new_zs); - } - core::mem::swap(&mut zs, &mut new_zs); + let mut dydx: ArrayVec<_, TERMS> = ArrayVec::new_const(); + // SAFETY: we immediately populate every entry in dydx. + unsafe { + dydx.set_len(degree); + derivative::(polynomial, dydx.as_mut()); } - Err("Failed to converge.") -} -// Initial guesses using the method from "Iteration Methods for Finding all -// Zeros of a Polynomial Simultaneously" by Oliver Aberth. -fn initial_guesses< - const TERMS: usize, - F: Float + FloatConst + MulAdd, ->( - polynomial: &[F; TERMS], -) -> ArrayVec, TERMS> { - // the degree of the polynomial - let n = polynomial.len() - 1; - let n_f = unsafe { cast(n).unwrap_unchecked() }; - // convert polynomial to monic form - let mut monic: ArrayVec = ArrayVec::new(); - for c in polynomial { - // SAFETY: we push only as many values as there are terms. - unsafe { monic.push_unchecked(*c / polynomial[n]) }; + let mut guesses: ArrayVec<_, TERMS> = ArrayVec::new_const(); + // SAFETY: we immediately populate every entry in guesses. + unsafe { + guesses.set_len(TERMS); + initial_guesses(polynomial, guesses.as_mut()); + guesses.set_len(degree); } - // let a = - c_1 / n - let a = -monic[n - 1] / n_f; - // let z = w + a, - let p_of_w = { - // we can recycle monic on the fly. - for coefficient_index in 0..=n { - let c = monic[coefficient_index]; - monic[coefficient_index] = F::zero(); - for ((index, power), pascal) in zip( - zip(0..=coefficient_index, (0..=coefficient_index).rev()), - PascalRowIter::new(coefficient_index as u32), - ) { - let pascal: F = unsafe { cast(pascal).unwrap_unchecked() }; - monic[index] = - MulAdd::mul_add(c, pascal * a.powi(power as i32), monic[index]); - } - } - monic - }; - // convert P(w) into S(w) - let s_of_w = { - let mut p = p_of_w; - // skip the last coefficient - for i in 0..n { - p[i] = -p[i].abs() - } - p - }; - // find r_0 - let mut int = F::one(); - let r_0 = loop { - let s_at_r0 = sample_polynomial(&s_of_w, int.into()); - if s_at_r0.re > F::zero() { - break int; - } - int = int + F::one(); - }; - drop(s_of_w); - { - let mut guesses: ArrayVec, TERMS> = ArrayVec::new(); - - let frac_2pi_n = F::TAU() / n_f; - let frac_pi_2n = F::FRAC_PI_2() / n_f; - - for k in 0..n { - let k_f = unsafe { cast(k).unwrap_unchecked() }; - let theta = MulAdd::mul_add(frac_2pi_n, k_f, frac_pi_2n); - - let real = MulAdd::mul_add(r_0, theta.cos(), a); - let imaginary = r_0 * theta.sin(); - - let val = Complex::new(real, imaginary); - // SAFETY: we push 1 less values than there are terms. - unsafe { guesses.push_unchecked(val) }; + let mut output: ArrayVec<_, TERMS> = ArrayVec::new_const(); + // SAFETY: we push 1 less elements than there are terms. + unsafe { + for _ in 0..degree { + output.push_unchecked(Complex::zero()); } - - guesses } -} -/// An iterator over the numbers in a row of Pascal's Triangle. -pub struct PascalRowIter { - n: u32, - k: u32, - previous: u32, -} - -impl PascalRowIter { - /// Create an iterator yielding the numbers in the nth row of Pascal's - /// triangle. - pub fn new(n: u32) -> Self { - Self { - n, - k: 0, - previous: 1, - } - } -} - -impl Iterator for PascalRowIter { - type Item = u32; - - fn next(&mut self) -> Option { - if self.k == 0 { - self.k = 1; - self.previous = 1; - return Some(1); - } - if self.k > self.n { - return None; - } - let new = self.previous * (self.n + 1 - self.k) / self.k; - self.k += 1; - self.previous = new; - Some(new) - } -} - -/// Sample the polynomial at some value of `x` using Horner's method. -/// -/// Polynomial of the form f(x) = a + b*x + c*x^2 + d*x^3 + ... -/// -/// `coefficients` is a slice containing the coefficients [a, b, c, d, ...] -pub fn sample_polynomial>( - coefficients: &[F], - x: Complex, -) -> Complex { - debug_assert!(coefficients.len() != 0); - let mut r = Complex::zero(); - for c in coefficients.iter().rev() { - r = r.mul_add(x, c.into()) + let stop_reason = aberth_raw( + polynomial, + dydx.as_ref(), + guesses.as_mut(), + output.as_mut(), + max_iterations, + epsilon, + ); + + Roots { + roots: output, + stop_reason, } - r } -/// Compute the derivative of a polynomial. +/// The roots of a polynomial /// -/// Polynomial of the form f(x) = a + b*x + c*x^2 + d*x^3 + ... +/// Dereferences to an array-slice containing `roots`. /// -/// `coefficients` is a slice containing the coefficients [a, b, c, d, ...] -/// starting from the coefficient of the x with degree 0. -pub fn derivative( - coefficients: &[F; TERMS], -) -> ArrayVec { - debug_assert!(coefficients.len() != 0); - coefficients - .iter() - .enumerate() - .skip(1) - .map(|(power, &coefficient)| { - let p = unsafe { F::from(power).unwrap_unchecked() }; - p * coefficient - }) - .collect() +/// `stop_reason` contains information for how the solver terminated and how +/// many iterations it took. +#[derive(Clone, Debug, PartialEq)] +pub struct Roots { + pub roots: Arr, + pub stop_reason: StopReason, } -/// Some extra methods for Complex numbers -trait ComplexExt { - fn approx_eq(self, w: Self, epsilon: F) -> bool; -} +impl Deref for Roots { + type Target = Arr; -impl ComplexExt for Complex { - /// Cheap comparison of complex numbers to within some margin, epsilon. - #[inline] - fn approx_eq(self, w: Complex, epsilon: F) -> bool { - (self.re - w.re).abs() < epsilon && (self.im - w.im).abs() < epsilon + fn deref(&self) -> &Arr { + &self.roots } } -#[cfg(any(test, doctest))] -mod tests { - use super::*; - const EPSILON: f32 = 0.000_05; - const EPSILON_64: f64 = 0.000_000_000_005; +/// The reason the solver terminated and the number of iterations it took. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum StopReason { + /// converged to within the required precision + Converged(/* iterations */ u32), + /// reached the iteration limit + MaxIteration(/* iterations */ u32), + /// detected a NaN or Inf while iterating + Failed(/* iterations */ u32), +} - fn unsorted_compare( - zs: &[Complex], - ws: &[Complex], - epsilon: F, - ) -> bool { - zs.iter().fold(true, |acc, &z| { - let w = ws.iter().find(|&&w| z.approx_eq(w, epsilon)); - acc && w.is_some() - }) - } +#[cfg(feature = "std")] +pub use feature_std::AberthSolver; + +#[cfg(feature = "std")] +mod feature_std { + use super::Roots; + use crate::internal::*; + use core::fmt::Debug; + use num_complex::Complex; + use num_traits::{ + cast, + float::{Float, FloatConst}, + identities::Zero, + MulAdd, + }; - fn array_approx_eq(lhs: &[F], rhs: &[F], epsilon: F) -> bool { - if lhs.len() != rhs.len() { - return false; - } - for i in 0..lhs.len() { - if (lhs[i] - rhs[i]).abs() > epsilon { - return false; + impl Roots<&[Complex]> { + /// Create an owned duplicate of `Roots` by allocating a `Vec` to hold the + /// values + pub fn to_owned(&self) -> Roots>> { + Roots { + roots: self.roots.to_vec(), + stop_reason: self.stop_reason, } } - true } - /// ```should_panic - /// use aberth::derivative; + /// A solver for polynomials with Float or ComplexFloat coefficients. Will + /// find all complex-roots, using the Aberth-Ehrlich method. + /// + /// The solver allocates some memory, and will reuse this allocation for + /// subsequent calls. This is good to use for polynomials of varying lengths, + /// polynomials with many terms, and for use in hot-loops where you want to + /// avoid repeated allocations. + /// + /// Note the returned solutions are not sorted in any particular order. + /// + /// usage example: + /// + /// ```rust + /// use aberth::AberthSolver; + /// + /// let mut solver = AberthSolver::new(); + /// solver.epsilon = 0.001; + /// solver.max_iterations = 10; + /// + /// // 11x^4 + 4x^3 + 2x - 1 = 0 + /// let polynomial_a = [-1., 2., 0., 4., 11.]; + /// // x^4 -12x^3 + 39x^2 - 28 = 0 + /// let polynomial_b = [-28., 0., 39., -12., 1.]; /// - /// let y: [f32; 0] = []; - /// let dydx = derivative(&y); + /// for polynomial in [polynomial_a, polynomial_b] { + /// let roots = solver.find_roots(&polynomial); + /// // ... + /// } /// ``` - fn _derivative_empty_array() {} - - #[test] - fn derivative() { - use super::derivative; - - { - let y = [0., 1., 2., 3., 4.]; - let dydx = derivative(&y); - - let expected = [1., 4., 9., 16.]; - assert!(array_approx_eq(&dydx, &expected, EPSILON)); - } - - { - let y = [19., 2.3, 0., 8.3, 69.420]; - let dydx = derivative(&y); - - let expected = [2.3, 0., 24.9, 277.68]; - assert!(array_approx_eq(&dydx, &expected, EPSILON)); - } - } - - /// ```should_panic - /// use aberth::sample_polynomial; /// - /// let y = []; - /// let y_0 = sample_polynomial(&y, 0.0.into()); + /// If you want to hold onto the roots you previously found while reusing the + /// solver, then you can create an owned version: + /// ```rust + /// use aberth::AberthSolver; + /// + /// let mut solver = AberthSolver::new(); + /// let roots_a = solver.find_roots(&[-1., 2., 0., 4., 11.]).to_owned(); + /// let roots_b = solver.find_roots(&[-28., 39., -12., 1.]); + /// roots_a[0]; /// ``` - fn _sample_polynomial_empty_array() {} - - #[test] - fn sample_polynomial() { - use super::sample_polynomial; - - { - let y = [0., 1., 2., 3., 4.]; - - let x_0 = 0.0.into(); - let y_0 = sample_polynomial(&y, x_0); - let expected_0 = 0.0.into(); - assert!(y_0.approx_eq(expected_0, EPSILON)); - - let x_1 = 1.0.into(); - let y_1 = sample_polynomial(&y, x_1); - let expected_1 = 10.0.into(); - assert!(y_1.approx_eq(expected_1, EPSILON)); - - let x_2 = (-1.0).into(); - let y_2 = sample_polynomial(&y, x_2); - let expected_2 = 2.0.into(); - assert!(y_2.approx_eq(expected_2, EPSILON)); - - let x_3 = 2.5.into(); - let y_3 = sample_polynomial(&y, x_3); - let expected_3 = 218.125.into(); - assert!(y_3.approx_eq(expected_3, EPSILON)); - } - - { - let y = [19., 2.3, 0., 8.3, 69.420]; - - let x_0 = 0.0.into(); - let y_0 = sample_polynomial(&y, x_0); - let expected_0 = 19.0.into(); - assert!(y_0.approx_eq(expected_0, EPSILON)); - - let x_1 = 1.0.into(); - let y_1 = sample_polynomial(&y, x_1); - let expected_1 = 99.02.into(); - assert!(y_1.approx_eq(expected_1, EPSILON)); - - let x_2 = (-1.0).into(); - let y_2 = sample_polynomial(&y, x_2); - let expected_2 = 77.82.into(); - assert!(y_2.approx_eq(expected_2, EPSILON)); - } - } - - /// ```should_panic - /// use aberth::aberth; + /// or alternatively just copy the `.roots` field into a vec + /// ```rust + /// use aberth::{AberthSolver, Complex}; /// - /// let y = []; - /// let dydx = aberth(&y, 0.1); + /// let mut solver = AberthSolver::new(); + /// let roots_a: Vec> = + /// solver.find_roots(&[-1., 2., 0., 4., 11.]).to_vec(); + /// let roots_b = solver.find_roots(&[-28., 39., -12., 1.]); + /// roots_a[0]; /// ``` - fn _aberth_empty_array() {} - - #[test] - fn aberth() { - use super::*; - - { - let polynomial = [0., 1.]; - let roots = aberth(&polynomial, EPSILON).unwrap(); - assert!(roots[0].approx_eq(Complex::zero(), EPSILON)); - } - - { - let polynomial = [1., 0., -1.]; - let roots = aberth(&polynomial, EPSILON).unwrap(); - let expected = [1.0.into(), (-1.0).into()]; - assert!(unsorted_compare(&roots, &expected, EPSILON)); - } - - { - // x^3 -12x^2 + 39x - 28 = 0 - let polynomial = [-28., 39., -12., 1.]; - - let roots = aberth(&polynomial, EPSILON).unwrap(); - let expected = [7.0.into(), 4.0.into(), 1.0.into()]; - assert!(unsorted_compare(&roots, &expected, EPSILON)); - } - { - // 2x^3 - 38x^2 + 228x - 432 = 0 - let polynomial = [-432., 228., -38., 2.]; - - let roots = aberth(&polynomial, EPSILON).unwrap(); - let expected = [9.0.into(), 6.0.into(), 4.0.into()]; - assert!(unsorted_compare(&roots, &expected, EPSILON)); - } - { - // x^3 + 8 = 0 - let polynomial = [8., 0., 0., 1.]; - - let roots = aberth(&polynomial, EPSILON).unwrap(); - let expected = [ - (-2.).into(), - Complex::new(1., -3f32.sqrt()), - Complex::new(1., 3f32.sqrt()), - ]; - assert!(unsorted_compare(&roots, &expected, EPSILON)); - } - { - // 11x^9 + 4x^4 + 2x - 1 = 0 - let polynomial = [-1., 2., 0., 0., 4., 0., 0., 0., 0., 11.]; - - let roots = aberth(&polynomial, EPSILON).unwrap(); - let expected = [ - (0.429326).into(), - Complex::new(-0.802811, -0.229634), - Complex::new(-0.802811, 0.229634), - Complex::new(-0.344895, -0.842594), - Complex::new(-0.344895, 0.842594), - Complex::new(0.206720, -0.675070), - Complex::new(0.206720, 0.675070), - Complex::new(0.726324, -0.455503), - Complex::new(0.726324, 0.455503), - ]; - assert!(unsorted_compare(&roots, &expected, EPSILON)); - } - { - // 0 = - 20x^19 + 19x^18 - 18x^17 + 17x^16 - 16x^15 - // + 15x^14 - 14x^13 + 13x^12 - 12x^11 + 11x^10 - // - 10x^9 + 9x^8 - 8x^7 + 7x^6 - 6x^5 - // + 5x^4 - 4x^3 + 3x^2 - 2x + 1 - let polynomial = [ - 1., -2., 3., -4., 5., -6., 7., -8., 9., -10., 11., -12., 13., -14., - 15., -16., 17., -18., 19., -20., - ]; - - let roots = aberth(&polynomial, EPSILON).unwrap(); - // found using wolfram alpha - let expected = [ - 0.834053.into(), - Complex::new(-0.844_061, -0.321_794), - Complex::new(-0.844_061, 0.321_794), - Complex::new(-0.684_734, -0.550_992), - Complex::new(-0.684_734, 0.550_992), - Complex::new(-0.476_151, -0.721_437), - Complex::new(-0.476_151, 0.721_437), - Complex::new(-0.231_844, -0.822_470), - Complex::new(-0.231_844, 0.822_470), - Complex::new(0.028_207, -0.846_944), - Complex::new(0.028_207, 0.846_944), - Complex::new(0.281_692, -0.793_720), - Complex::new(0.281_692, 0.793_720), - Complex::new(0.506_511, -0.668_231), - Complex::new(0.506_511, 0.668_231), - Complex::new(0.682_933, -0.482_160), - Complex::new(0.682_933, 0.482_160), - Complex::new(0.795_421, -0.252_482), - Complex::new(0.795_421, 0.252_482), - ]; - assert!(unsorted_compare(&roots, &expected, EPSILON)); - } + #[derive(Debug, Clone)] + pub struct AberthSolver + where + F: Float, + { + pub max_iterations: u32, + pub epsilon: F, + data: Vec>, } - #[test] - fn aberth_f64() { - use super::aberth; - { - // 0 = - 20x^19 + 19x^18 - 18x^17 + 17x^16 - 16x^15 - // + 15x^14 - 14x^13 + 13x^12 - 12x^11 + 11x^10 - // - 10x^9 + 9x^8 - 8x^7 + 7x^6 - 6x^5 - // + 5x^4 - 4x^3 + 3x^2 - 2x + 1 - let polynomial: [f64; 20] = [ - 1., -2., 3., -4., 5., -6., 7., -8., 9., -10., 11., -12., 13., -14., - 15., -16., 17., -18., 19., -20., - ]; - - let roots = aberth(&polynomial, EPSILON_64).unwrap(); - let expected = [ - 0.834_053_367_550.into(), - Complex::new(-0.844_060_952_037, -0.321_793_977_746), - Complex::new(-0.844_060_952_037, 0.321_793_977_746), - Complex::new(-0.684_734_480_334, -0.550_992_054_369), - Complex::new(-0.684_734_480_334, 0.550_992_054_369), - Complex::new(-0.476_151_406_058, -0.721_436_901_065), - Complex::new(-0.476_151_406_058, 0.721_436_901_065), - Complex::new(-0.231_843_928_891, -0.822_470_497_825), - Complex::new(-0.231_843_928_891, 0.822_470_497_825), - Complex::new(0.028_207_047_127, -0.846_944_061_134), - Complex::new(0.028_207_047_127, 0.846_944_061_134), - Complex::new(0.281_691_706_643, -0.793_720_289_127), - Complex::new(0.281_691_706_643, 0.793_720_289_127), - Complex::new(0.506_511_447_570, -0.668_230_679_428), - Complex::new(0.506_511_447_570, 0.668_230_679_428), - Complex::new(0.682_933_030_868, -0.482_159_501_324), - Complex::new(0.682_933_030_868, 0.482_159_501_324), - Complex::new(0.795_420_851_336, -0.252_482_354_484), - Complex::new(0.795_420_851_336, 0.252_482_354_484), - ]; - assert!(unsorted_compare(&roots, &expected, EPSILON_64)); + impl + Default + Debug> Default + for AberthSolver + { + fn default() -> Self { + AberthSolver::new() } } - #[test] - fn pascal_triangle() { - { - let row = PascalRowIter::new(0) - .collect::>() - .into_inner() - .unwrap(); - let expected = [1]; - assert_eq!(row, expected); - } - { - let row = PascalRowIter::new(1) - .collect::>() - .into_inner() - .unwrap(); - let expected = [1, 1]; - assert_eq!(row, expected); - } - { - let row = PascalRowIter::new(2) - .collect::>() - .into_inner() - .unwrap(); - let expected = [1, 2, 1]; - assert_eq!(row, expected); - } - { - let row = PascalRowIter::new(3) - .collect::>() - .into_inner() - .unwrap(); - let expected = [1, 3, 3, 1]; - assert_eq!(row, expected); - } - { - let row = PascalRowIter::new(4) - .collect::>() - .into_inner() - .unwrap(); - let expected = [1, 4, 6, 4, 1]; - assert_eq!(row, expected); - } - { - let row = PascalRowIter::new(5) - .collect::>() - .into_inner() - .unwrap(); - let expected = [1, 5, 10, 10, 5, 1]; - assert_eq!(row, expected); - } - { - let row = PascalRowIter::new(6) - .collect::>() - .into_inner() - .unwrap(); - let expected = [1, 6, 15, 20, 15, 6, 1]; - assert_eq!(row, expected); + impl + Default + Debug> + AberthSolver + { + pub fn new() -> Self { + AberthSolver { + max_iterations: 100, + data: Vec::new(), + epsilon: cast(0.001).unwrap(), + } } - { - let row = PascalRowIter::new(9) - .collect::>() - .into_inner() - .unwrap(); - let expected = [1, 9, 36, 84, 126, 126, 84, 36, 9, 1]; - assert_eq!(row, expected); + + /// Find all the complex roots of the polynomial + /// + /// Polynomial is given in the form `f(x) = a + b*x + c*x^2 + d*x^3 + ...` + /// + /// `polynomial` is a slice containing the coefficients `[a, b, c, d, ...]` + pub fn find_roots>( + &mut self, + polynomial: &[C], + ) -> Roots<&[Complex]> { + let len = polynomial.len(); + let degree = len - 1; + // ensure we have enough space allocated + self + .data + .resize_with(len + degree + len + degree, Complex::zero); + // get mutable slices to our data + let (complex_poly, tail) = self.data.split_at_mut(len); + let (dydx, tail) = tail.split_at_mut(degree); + let (guesses, output) = tail.split_at_mut(len); + + // convert the polynomial to a complex type + polynomial + .iter() + .enumerate() + .for_each(|(i, &coefficient)| complex_poly[i] = coefficient.into()); + + initial_guesses(complex_poly, guesses); + let guesses = &mut guesses[0..degree]; + derivative(complex_poly, dydx); + + let stop_reason = aberth_raw( + complex_poly, + dydx, + guesses, + output, + self.max_iterations, + self.epsilon, + ); + + Roots { + roots: output, + stop_reason, + } } } } diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..c3f2198 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,423 @@ +use crate::*; +use num_complex::ComplexFloat; + +const EPSILON: f32 = 0.000_05; +const EPSILON_64: f64 = 0.000_000_000_005; + +fn array_approx_eq( + lhs: &[Complex], + rhs: &[Complex], + epsilon: F, +) -> bool { + if lhs.len() != rhs.len() { + return false; + } + for i in 0..lhs.len() { + if (lhs[i] - rhs[i]).abs() > epsilon { + return false; + } + } + true +} + +fn unsorted_compare( + zs: &[Complex], + ws: &[Complex], + epsilon: F, +) -> bool { + zs.iter().fold(true, |acc, &z| { + let w = ws.iter().find(|&&w| z.approx_eq(w, epsilon)); + acc && w.is_some() + }) +} + +#[test] +fn derivative() { + use super::derivative; + + { + let y = [0., 1., 2., 3., 4.]; + let y = y.map(|v| Complex::from(v)); + let mut dydx = [Complex::zero(); 4]; + + derivative(&y, &mut dydx); + + let expected = [1., 4., 9., 16.]; + let expected = expected.map(|v| Complex::from(v)); + assert!(array_approx_eq(&dydx, &expected, EPSILON)); + } + + { + let y = [19., 2.3, 0., 8.3, 69.420]; + let y = y.map(|v| Complex::from(v)); + let mut dydx = [Complex::zero(); 4]; + + derivative(&y, &mut dydx); + + let expected = [2.3, 0., 24.9, 277.68]; + let expected = expected.map(|v| Complex::from(v)); + assert!(array_approx_eq(&dydx, &expected, EPSILON)); + } +} + +#[test] +fn sample_polynomial() { + use super::sample_polynomial; + + { + let y = [0., 1., 2., 3., 4.]; + let y = y.map(|v| Complex::from(v)); + + let x_0 = 0.0.into(); + let y_0 = sample_polynomial(&y, x_0); + let expected_0 = 0.0.into(); + assert!(y_0.approx_eq(expected_0, EPSILON)); + + let x_1 = 1.0.into(); + let y_1 = sample_polynomial(&y, x_1); + let expected_1 = 10.0.into(); + assert!(y_1.approx_eq(expected_1, EPSILON)); + + let x_2 = (-1.0).into(); + let y_2 = sample_polynomial(&y, x_2); + let expected_2 = 2.0.into(); + assert!(y_2.approx_eq(expected_2, EPSILON)); + + let x_3 = 2.5.into(); + let y_3 = sample_polynomial(&y, x_3); + let expected_3 = 218.125.into(); + assert!(y_3.approx_eq(expected_3, EPSILON)); + } + + { + let y = [19., 2.3, 0., 8.3, 69.420]; + let y = y.map(|v| Complex::from(v)); + + let x_0 = 0.0.into(); + let y_0 = sample_polynomial(&y, x_0); + let expected_0 = 19.0.into(); + assert!(y_0.approx_eq(expected_0, EPSILON)); + + let x_1 = 1.0.into(); + let y_1 = sample_polynomial(&y, x_1); + let expected_1 = 99.02.into(); + assert!(y_1.approx_eq(expected_1, EPSILON)); + + let x_2 = (-1.0).into(); + let y_2 = sample_polynomial(&y, x_2); + let expected_2 = 77.82.into(); + assert!(y_2.approx_eq(expected_2, EPSILON)); + } +} + +#[test] +fn pascal_triangle() { + { + let row: Vec<_> = PascalRowIter::new(0).collect(); + let expected = [1]; + assert_eq!(row, expected); + } + { + let row: Vec<_> = PascalRowIter::new(1).collect(); + let expected = [1, 1]; + assert_eq!(row, expected); + } + { + let row: Vec<_> = PascalRowIter::new(2).collect(); + let expected = [1, 2, 1]; + assert_eq!(row, expected); + } + { + let row: Vec<_> = PascalRowIter::new(3).collect(); + let expected = [1, 3, 3, 1]; + assert_eq!(row, expected); + } + { + let row: Vec<_> = PascalRowIter::new(4).collect(); + let expected = [1, 4, 6, 4, 1]; + assert_eq!(row, expected); + } + { + let row: Vec<_> = PascalRowIter::new(5).collect(); + let expected = [1, 5, 10, 10, 5, 1]; + assert_eq!(row, expected); + } + { + let row: Vec<_> = PascalRowIter::new(6).collect(); + let expected = [1, 6, 15, 20, 15, 6, 1]; + assert_eq!(row, expected); + } + { + let row: Vec<_> = PascalRowIter::new(9).collect(); + let expected = [1, 9, 36, 84, 126, 126, 84, 36, 9, 1]; + assert_eq!(row, expected); + } +} + +/// ```should_panic +/// use aberth::aberth; +/// +/// let y: [f32; 0] = []; +/// let dydx = aberth(&y, 100, 0.1); +/// ``` +fn _aberth_empty_array() {} + +#[test] +fn aberth() { + use super::*; + + { + let polynomial = [0., 1.]; + + let roots = aberth(&polynomial, 100, EPSILON); + assert!(roots[0].approx_eq(Complex::zero(), EPSILON)); + } + + { + let polynomial = [1., 0., -1.]; + + let roots = aberth(&polynomial, 100, EPSILON); + let expected = [1.0.into(), (-1.0).into()]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } + + { + // x^3 -12x^2 + 39x - 28 = 0 + let polynomial = [-28., 39., -12., 1.]; + + let roots = aberth(&polynomial, 100, EPSILON); + let expected = [7.0.into(), 4.0.into(), 1.0.into()]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } + { + // 2x^3 - 38x^2 + 228x - 432 = 0 + let polynomial = [-432., 228., -38., 2.]; + + let roots = aberth(&polynomial, 100, EPSILON); + let expected = [9.0.into(), 6.0.into(), 4.0.into()]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } + { + // x^3 + 8 = 0 + let polynomial = [8., 0., 0., 1.]; + + let roots = aberth(&polynomial, 100, EPSILON); + let expected = [ + (-2.).into(), + Complex::new(1., -3f32.sqrt()), + Complex::new(1., 3f32.sqrt()), + ]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } + { + // 11x^9 + 4x^4 + 2x - 1 = 0 + let polynomial = [-1., 2., 0., 0., 4., 0., 0., 0., 0., 11.]; + + let roots = aberth(&polynomial, 100, EPSILON); + let expected = [ + (0.429326).into(), + Complex::new(-0.802811, -0.229634), + Complex::new(-0.802811, 0.229634), + Complex::new(-0.344895, -0.842594), + Complex::new(-0.344895, 0.842594), + Complex::new(0.206720, -0.675070), + Complex::new(0.206720, 0.675070), + Complex::new(0.726324, -0.455503), + Complex::new(0.726324, 0.455503), + ]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } + { + // 0 = - 20x^19 + 19x^18 - 18x^17 + 17x^16 - 16x^15 + // + 15x^14 - 14x^13 + 13x^12 - 12x^11 + 11x^10 + // - 10x^9 + 9x^8 - 8x^7 + 7x^6 - 6x^5 + // + 5x^4 - 4x^3 + 3x^2 - 2x + 1 + let polynomial = [ + 1., -2., 3., -4., 5., -6., 7., -8., 9., -10., 11., -12., 13., -14., 15., + -16., 17., -18., 19., -20., + ]; + + let roots = aberth(&polynomial, 100, EPSILON); + // found using wolfram alpha + let expected = [ + 0.834053.into(), + Complex::new(-0.844_061, -0.321_794), + Complex::new(-0.844_061, 0.321_794), + Complex::new(-0.684_734, -0.550_992), + Complex::new(-0.684_734, 0.550_992), + Complex::new(-0.476_151, -0.721_437), + Complex::new(-0.476_151, 0.721_437), + Complex::new(-0.231_844, -0.822_470), + Complex::new(-0.231_844, 0.822_470), + Complex::new(0.028_207, -0.846_944), + Complex::new(0.028_207, 0.846_944), + Complex::new(0.281_692, -0.793_720), + Complex::new(0.281_692, 0.793_720), + Complex::new(0.506_511, -0.668_231), + Complex::new(0.506_511, 0.668_231), + Complex::new(0.682_933, -0.482_160), + Complex::new(0.682_933, 0.482_160), + Complex::new(0.795_421, -0.252_482), + Complex::new(0.795_421, 0.252_482), + ]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } +} + +#[test] +fn aberth_f64() { + use super::aberth; + { + // 0 = - 20x^19 + 19x^18 - 18x^17 + 17x^16 - 16x^15 + // + 15x^14 - 14x^13 + 13x^12 - 12x^11 + 11x^10 + // - 10x^9 + 9x^8 - 8x^7 + 7x^6 - 6x^5 + // + 5x^4 - 4x^3 + 3x^2 - 2x + 1 + let polynomial: [f64; 20] = [ + 1., -2., 3., -4., 5., -6., 7., -8., 9., -10., 11., -12., 13., -14., 15., + -16., 17., -18., 19., -20., + ]; + + let roots = aberth(&polynomial, 100, EPSILON_64); + let expected = [ + 0.834_053_367_550.into(), + Complex::new(-0.844_060_952_037, -0.321_793_977_746), + Complex::new(-0.844_060_952_037, 0.321_793_977_746), + Complex::new(-0.684_734_480_334, -0.550_992_054_369), + Complex::new(-0.684_734_480_334, 0.550_992_054_369), + Complex::new(-0.476_151_406_058, -0.721_436_901_065), + Complex::new(-0.476_151_406_058, 0.721_436_901_065), + Complex::new(-0.231_843_928_891, -0.822_470_497_825), + Complex::new(-0.231_843_928_891, 0.822_470_497_825), + Complex::new(0.028_207_047_127, -0.846_944_061_134), + Complex::new(0.028_207_047_127, 0.846_944_061_134), + Complex::new(0.281_691_706_643, -0.793_720_289_127), + Complex::new(0.281_691_706_643, 0.793_720_289_127), + Complex::new(0.506_511_447_570, -0.668_230_679_428), + Complex::new(0.506_511_447_570, 0.668_230_679_428), + Complex::new(0.682_933_030_868, -0.482_159_501_324), + Complex::new(0.682_933_030_868, 0.482_159_501_324), + Complex::new(0.795_420_851_336, -0.252_482_354_484), + Complex::new(0.795_420_851_336, 0.252_482_354_484), + ]; + assert!(unsorted_compare(&roots, &expected, EPSILON_64)); + } +} + +#[cfg(feature = "std")] +mod feature_std { + use crate::*; + + #[test] + fn aberth_solver() { + use super::*; + let mut solver = AberthSolver::new(); + solver.epsilon = EPSILON; + + { + let polynomial = [0., 1.]; + + let roots = solver.find_roots(&polynomial); + let expected = [0.0.into()]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } + + { + let polynomial = [1., 0., -1.]; + + let roots = solver.find_roots(&polynomial); + let expected = [1.0.into(), (-1.0).into()]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } + + { + // x^3 -12x^2 + 39x - 28 = 0 + let polynomial = [-28., 39., -12., 1.]; + + let roots = solver.find_roots(&polynomial); + let expected = [7.0.into(), 4.0.into(), 1.0.into()]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } + { + // 2x^3 - 38x^2 + 228x - 432 = 0 + let polynomial = [-432., 228., -38., 2.]; + + let roots = solver.find_roots(&polynomial); + let expected = [9.0.into(), 6.0.into(), 4.0.into()]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } + { + // x^3 + 8 = 0 + let polynomial = [8., 0., 0., 1.]; + + let roots = solver.find_roots(&polynomial); + let expected = [ + (-2.).into(), + Complex::new(1., -3f32.sqrt()), + Complex::new(1., 3f32.sqrt()), + ]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } + { + // 11x^9 + 4x^4 + 2x - 1 = 0 + let polynomial = [-1., 2., 0., 0., 4., 0., 0., 0., 0., 11.]; + + let roots = solver.find_roots(&polynomial); + let expected = [ + (0.429326).into(), + Complex::new(-0.802811, -0.229634), + Complex::new(-0.802811, 0.229634), + Complex::new(-0.344895, -0.842594), + Complex::new(-0.344895, 0.842594), + Complex::new(0.206720, -0.675070), + Complex::new(0.206720, 0.675070), + Complex::new(0.726324, -0.455503), + Complex::new(0.726324, 0.455503), + ]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } + { + // 0 = - 20x^19 + 19x^18 - 18x^17 + 17x^16 - 16x^15 + // + 15x^14 - 14x^13 + 13x^12 - 12x^11 + 11x^10 + // - 10x^9 + 9x^8 - 8x^7 + 7x^6 - 6x^5 + // + 5x^4 - 4x^3 + 3x^2 - 2x + 1 + let polynomial = [ + 1., -2., 3., -4., 5., -6., 7., -8., 9., -10., 11., -12., 13., -14., + 15., -16., 17., -18., 19., -20., + ]; + + let roots = solver.find_roots(&polynomial); + // found using wolfram alpha + let expected = [ + 0.834053.into(), + Complex::new(-0.844_061, -0.321_794), + Complex::new(-0.844_061, 0.321_794), + Complex::new(-0.684_734, -0.550_992), + Complex::new(-0.684_734, 0.550_992), + Complex::new(-0.476_151, -0.721_437), + Complex::new(-0.476_151, 0.721_437), + Complex::new(-0.231_844, -0.822_470), + Complex::new(-0.231_844, 0.822_470), + Complex::new(0.028_207, -0.846_944), + Complex::new(0.028_207, 0.846_944), + Complex::new(0.281_692, -0.793_720), + Complex::new(0.281_692, 0.793_720), + Complex::new(0.506_511, -0.668_231), + Complex::new(0.506_511, 0.668_231), + Complex::new(0.682_933, -0.482_160), + Complex::new(0.682_933, 0.482_160), + Complex::new(0.795_421, -0.252_482), + Complex::new(0.795_421, 0.252_482), + ]; + assert!(unsorted_compare(&roots, &expected, EPSILON)); + } + } + + #[test] + fn reuse_solver() { + let mut solver = AberthSolver::new(); + let roots_a = solver.find_roots(&[-1., 2., 0., 4., 11.]).to_owned(); + let roots_b = solver.find_roots(&[-28., 39., -12., 1.]); + roots_a.roots; + roots_b.roots; + } +}