diff --git a/benches/msm.rs b/benches/msm.rs index 2a98c525..db9949e2 100644 --- a/benches/msm.rs +++ b/benches/msm.rs @@ -16,7 +16,7 @@ use criterion::{BenchmarkId, Criterion}; use ff::{Field, PrimeField}; use group::prime::PrimeCurveAffine; use halo2curves_axiom::bn256::{Fr as Scalar, G1Affine as Point}; -use halo2curves_axiom::msm::{best_multiexp, multiexp_serial}; +use halo2curves_axiom::msm::{msm_best, msm_serial}; use rand_core::{RngCore, SeedableRng}; use rand_xorshift::XorShiftRng; use rayon::current_thread_index; @@ -136,7 +136,7 @@ fn msm(c: &mut Criterion) { assert!(k < 64); let n: usize = 1 << k; let mut acc = Point::identity().into(); - b.iter(|| multiexp_serial(&coeffs[b_index][..n], &bases[..n], &mut acc)); + b.iter(|| msm_serial(&coeffs[b_index][..n], &bases[..n], &mut acc)); }) .sample_size(10); } @@ -147,7 +147,7 @@ fn msm(c: &mut Criterion) { assert!(k < 64); let n: usize = 1 << k; b.iter(|| { - best_multiexp(&coeffs[b_index][..n], &bases[..n]); + msm_best(&coeffs[b_index][..n], &bases[..n]); }) }) .sample_size(SAMPLE_SIZE); diff --git a/src/msm.rs b/src/msm.rs index 1aab35ef..9179ab94 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -14,8 +14,7 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 { // Booth encoding: // * step by `window` size // * slice by size of `window + 1`` - // * each window overlap by 1 bit - // * append a zero bit to the least significant end + // * each window overlap by 1 bit * append a zero bit to the least significant end // Indexing rule for example window size 3 where we slice by 4 bits: // `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3 -2, -2, -1, -1, 0]`` // So we can reduce the bucket size without preprocessing scalars @@ -54,14 +53,15 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 { } } +/// Batch addition. fn batch_add( size: usize, buckets: &mut [BucketAffine], points: &[SchedulePoint], bases: &[Affine], ) { - let mut t = vec![C::Base::ZERO; size]; - let mut z = vec![C::Base::ZERO; size]; + let mut t = vec![C::Base::ZERO; size]; // Stores x2 - x1 + let mut z = vec![C::Base::ZERO; size]; // Stores y2 - y1 let mut acc = C::Base::ONE; for ( @@ -76,16 +76,42 @@ fn batch_add( z, ) in points.iter().zip(t.iter_mut()).zip(z.iter_mut()) { - *z = buckets[*buck_idx].x() - bases[*base_idx].x; + if buckets[*buck_idx].is_inf() { + // We assume bases[*base_idx] != infinity always. + continue; + } + + if buckets[*buck_idx].x() == bases[*base_idx].x { + // y-coordinate matches: + // 1. y1 == y2 and sign = false or + // 2. y1 != y2 and sign = true + // => ( y1 == y2) xor !sign + // (This uses the fact that x1 == x2 and both points satisfy the curve eq.) + if (buckets[*buck_idx].y() == bases[*base_idx].y) ^ !*sign { + // Doubling + let x_squared = bases[*base_idx].x.square(); + *z = buckets[*buck_idx].y() + buckets[*buck_idx].y(); // 2y + *t = acc * (x_squared + x_squared + x_squared); // acc * 3x^2 + acc *= *z; + continue; + } + // P + (-P) + buckets[*buck_idx].set_inf(); + continue; + } + // Addition + *z = buckets[*buck_idx].x() - bases[*base_idx].x; // x2 - x1 if *sign { *t = acc * (buckets[*buck_idx].y() - bases[*base_idx].y); } else { *t = acc * (buckets[*buck_idx].y() + bases[*base_idx].y); - } + } // y2 - y1 acc *= *z; } - acc = acc.invert().unwrap(); + acc = acc + .invert() + .expect("Some edge case has not been handled properly"); for ( ( @@ -99,15 +125,18 @@ fn batch_add( z, ) in points.iter().zip(t.iter()).zip(z.iter()).rev() { + if buckets[*buck_idx].is_inf() { + // We assume bases[*base_idx] != infinity always. + continue; + } let lambda = acc * t; - acc *= z; - - let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x); + acc *= z; // update acc + let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x); // x_result if *sign { buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) - bases[*base_idx].y)); } else { buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) + bases[*base_idx].y)); - } + } // y_result = lambda * (x1 - x_result) - y1 buckets[*buck_idx].set_x(&x); } } @@ -207,6 +236,13 @@ impl BucketAffine { } } + fn is_inf(&self) -> bool { + match self { + Self::None => true, + Self::Point(_) => false, + } + } + fn set_x(&mut self, x: &C::Base) { match self { Self::None => panic!("::set_x None"), @@ -220,6 +256,13 @@ impl BucketAffine { Self::Point(ref mut a) => a.y = *y, } } + + fn set_inf(&mut self) { + match self { + Self::None => {} + Self::Point(_) => *self = Self::None, + } + } } struct Schedule { @@ -286,7 +329,10 @@ impl Schedule { } } -pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { +/// Performs a multi-scalar multiplication operation. +/// +/// This function will panic if coeffs and bases have a different length. +pub fn msm_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); let c = if bases.len() < 4 { @@ -303,7 +349,7 @@ pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: & let mut acc_or = vec![0; field_byte_size]; for coeff in &coeffs { for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) { - *acc_limb = *acc_limb | *limb; + *acc_limb |= *limb; } } let max_byte_size = field_byte_size @@ -315,7 +361,7 @@ pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: & if max_byte_size == 0 { return; } - let number_of_windows = max_byte_size * 8 as usize / c + 1; + let number_of_windows = max_byte_size * 8_usize / c + 1; for current_window in (0..number_of_windows).rev() { for _ in 0..c { @@ -377,12 +423,12 @@ pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: & } } -/// Performs a multi-exponentiation operation. +/// Performs a multi-scalar multiplication operation. /// /// This function will panic if coeffs and bases have a different length. /// /// This will use multithreading if beneficial. -pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { +pub fn msm_parallel(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { assert_eq!(coeffs.len(), bases.len()); let num_threads = rayon::current_num_threads(); @@ -399,25 +445,22 @@ pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu .zip(results.iter_mut()) { scope.spawn(move |_| { - multiexp_serial(coeffs, bases, acc); + msm_serial(coeffs, bases, acc); }); } }); results.iter().fold(C::Curve::identity(), |a, b| a + b) } else { let mut acc = C::Curve::identity(); - multiexp_serial(coeffs, bases, &mut acc); + msm_serial(coeffs, bases, &mut acc); acc } } -/// + /// This function will panic if coeffs and bases have a different length. /// /// This will use multithreading if beneficial. -pub fn best_multiexp_independent_points( - coeffs: &[C::Scalar], - bases: &[C], -) -> C::Curve { +pub fn msm_best(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { assert_eq!(coeffs.len(), bases.len()); // TODO: consider adjusting it with emprical data? @@ -430,7 +473,7 @@ pub fn best_multiexp_independent_points( }; if c < 10 { - return best_multiexp(coeffs, bases); + return msm_parallel(coeffs, bases); } // coeffs to byte representation @@ -491,7 +534,6 @@ pub fn best_multiexp_independent_points( #[cfg(test)] mod test { - use std::ops::Neg; use crate::bn256::{Fr, G1Affine, G1}; @@ -548,7 +590,10 @@ mod test { } fn run_msm_cross(min_k: usize, max_k: usize) { + use rayon::iter::{IntoParallelIterator, ParallelIterator}; + let points = (0..1 << max_k) + .into_par_iter() .map(|_| C::Curve::random(OsRng)) .collect::>(); let mut affine_points = vec![C::identity(); 1 << max_k]; @@ -556,6 +601,7 @@ mod test { let points = affine_points; let scalars = (0..1 << max_k) + .into_par_iter() .map(|_| C::Scalar::random(OsRng)) .collect::>(); @@ -563,12 +609,12 @@ mod test { let points = &points[..1 << k]; let scalars = &scalars[..1 << k]; - let t0 = start_timer!(|| format!("cyclone k={}", k)); - let e0 = super::best_multiexp_independent_points(scalars, points); + let t0 = start_timer!(|| format!("cyclone indep k={}", k)); + let e0 = super::msm_best(scalars, points); end_timer!(t0); let t1 = start_timer!(|| format!("older k={}", k)); - let e1 = super::best_multiexp(scalars, points); + let e1 = super::msm_parallel(scalars, points); end_timer!(t1); assert_eq!(e0, e1); }