From 34ef598b431989f98f20770cdabff5d61c1ed535 Mon Sep 17 00:00:00 2001 From: morph <82043364+morph-dev@users.noreply.github.com> Date: Sat, 20 Jul 2024 16:30:53 +0300 Subject: [PATCH] feat: implement LagrangeBasis polynomials and PrecomputedWeights --- Cargo.lock | 7 + portal-verkle-primitives/Cargo.toml | 1 + portal-verkle-primitives/src/ec/point.rs | 33 +++-- .../src/ec/scalar_field.rs | 118 +++++++++++++++- .../src/proof/lagrange_basis.rs | 132 ++++++++++++++++++ portal-verkle-primitives/src/proof/mod.rs | 2 + .../src/proof/precomputed_weights.rs | 115 +++++++++++++++ 7 files changed, 387 insertions(+), 21 deletions(-) create mode 100644 portal-verkle-primitives/src/proof/lagrange_basis.rs create mode 100644 portal-verkle-primitives/src/proof/precomputed_weights.rs diff --git a/Cargo.lock b/Cargo.lock index f5c1a8e..13328d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1065,6 +1065,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parity-scale-codec" version = "3.6.12" @@ -1145,6 +1151,7 @@ dependencies = [ "ethereum_ssz", "ethereum_ssz_derive", "once_cell", + "overload", "rstest", "serde", "sha2", diff --git a/portal-verkle-primitives/Cargo.toml b/portal-verkle-primitives/Cargo.toml index b48dfc0..cd551b9 100644 --- a/portal-verkle-primitives/Cargo.toml +++ b/portal-verkle-primitives/Cargo.toml @@ -17,6 +17,7 @@ derive_more = "0.99" ethereum_ssz = "0.5" ethereum_ssz_derive = "0.5" once_cell = "1" +overload = "0.1" serde = { version = "1", features = ["derive"] } sha2 = "0.10" ssz_types = "0.6" diff --git a/portal-verkle-primitives/src/ec/point.rs b/portal-verkle-primitives/src/ec/point.rs index 9266d07..a8844f0 100644 --- a/portal-verkle-primitives/src/ec/point.rs +++ b/portal-verkle-primitives/src/ec/point.rs @@ -1,18 +1,15 @@ -use std::{ - fmt::Debug, - iter::Sum, - ops::{Add, Sub}, -}; +use std::{fmt::Debug, iter::Sum, ops}; use alloy_primitives::B256; use banderwagon::{CanonicalDeserialize, CanonicalSerialize, Element}; -use derive_more::{Add, AddAssign, Constructor, Deref, From, Into, Sum}; +use derive_more::{Constructor, Deref, From, Into}; +use overload::overload; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use ssz::{Decode, Encode}; use crate::ScalarField; -#[derive(Clone, PartialEq, Eq, Constructor, From, Into, Deref, Add, AddAssign, Sum)] +#[derive(Clone, PartialEq, Eq, Constructor, From, Into, Deref)] pub struct Point(Element); impl Point { @@ -110,19 +107,21 @@ impl Debug for Point { } } -impl Add<&Self> for Point { - type Output = Self; +overload!(- (me: ?Point) -> Point { Point(-me.0) }); - fn add(self, rhs: &Self) -> Self::Output { - Self(self.0 + rhs.0) - } -} +overload!((lhs: &mut Point) += (rhs: ?Point) { lhs.0 += rhs.0 }); +overload!((lhs: Point) + (rhs: ?Point) -> Point { + let mut lhs = lhs; lhs += rhs; lhs +}); -impl Sub<&Self> for Point { - type Output = Self; +overload!((lhs: &mut Point) -= (rhs: ?Point) { lhs.0 = lhs.0 - rhs.0 }); +overload!((lhs: Point) - (rhs: ?Point) -> Point { + let mut lhs = lhs; lhs -= rhs; lhs +}); - fn sub(self, rhs: &Self) -> Self::Output { - Self(self.0 - rhs.0) +impl Sum for Point { + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), |sum, item| sum + item) } } diff --git a/portal-verkle-primitives/src/ec/scalar_field.rs b/portal-verkle-primitives/src/ec/scalar_field.rs index de9bc21..d24b2d2 100644 --- a/portal-verkle-primitives/src/ec/scalar_field.rs +++ b/portal-verkle-primitives/src/ec/scalar_field.rs @@ -1,14 +1,20 @@ -use std::fmt::Debug; +use std::{ + fmt::Debug, + iter::{Product, Sum}, + ops, +}; use alloy_primitives::B256; -use banderwagon::{CanonicalDeserialize, CanonicalSerialize, Fr, PrimeField, Zero}; -use derive_more::{Add, Constructor, Deref, From, Into, Neg, Sub}; +use ark_ff::batch_inversion_and_mul; +use banderwagon::{CanonicalDeserialize, CanonicalSerialize, Field, Fr, One, PrimeField, Zero}; +use derive_more::{Constructor, Deref, From, Into}; +use overload::overload; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use ssz::{Decode, Encode}; use crate::Stem; -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Constructor, From, Into, Deref, Neg, Add, Sub)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Constructor, From, Into, Deref)] pub struct ScalarField(pub(crate) Fr); impl ScalarField { @@ -37,6 +43,32 @@ impl ScalarField { pub fn is_zero(&self) -> bool { self.0.is_zero() } + + pub fn one() -> Self { + Self(Fr::one()) + } + + pub fn inverse(&self) -> Option { + self.0.inverse().map(Self) + } + + /// Calculates inverse of all provided scalars, ignoring the ones with value zero. + pub fn batch_inversion(scalars: &mut [Self]) { + Self::batch_inverse_and_multiply(scalars, &Self::one()) + } + + /// Calculates inverses and multiplies them. + /// + /// Updates variable `values`: `v_i` => `m / v_i`. + /// + /// Ignores the zero values. + pub fn batch_inverse_and_multiply(values: &mut [Self], m: &Self) { + let mut frs = values.iter().map(|value| value.0).collect::>(); + batch_inversion_and_mul(&mut frs, m); + for (value, fr) in values.iter_mut().zip(frs.into_iter()) { + value.0 = fr; + } + } } impl From for ScalarField { @@ -117,3 +149,81 @@ impl Debug for ScalarField { B256::from(self).fmt(f) } } + +overload!(- (me: ?ScalarField) -> ScalarField { ScalarField(-me.0) }); + +overload!((lhs: &mut ScalarField) += (rhs: ?ScalarField) { lhs.0 += &rhs.0; }); +overload!((lhs: ScalarField) + (rhs: ?ScalarField) -> ScalarField { + let mut lhs = lhs; lhs += rhs; lhs +}); +overload!((lhs: &ScalarField) + (rhs: ?ScalarField) -> ScalarField { ScalarField(lhs.0) + rhs }); + +overload!((lhs: &mut ScalarField) -= (rhs: ?ScalarField) { lhs.0 -= &rhs.0; }); +overload!((lhs: ScalarField) - (rhs: ?ScalarField) -> ScalarField { + let mut lhs = lhs; lhs -= rhs; lhs +}); +overload!((lhs: &ScalarField) - (rhs: ?ScalarField) -> ScalarField { ScalarField(lhs.0) - rhs }); + +overload!((lhs: &mut ScalarField) *= (rhs: ?ScalarField) { lhs.0 *= &rhs.0; }); +overload!((lhs: ScalarField) * (rhs: ?ScalarField) -> ScalarField { + let mut lhs = lhs; lhs *= rhs; lhs +}); +overload!((lhs: &ScalarField) * (rhs: ?ScalarField) -> ScalarField { ScalarField(lhs.0) * rhs }); + +impl<'a> Sum<&'a Self> for ScalarField { + fn sum>(iter: I) -> Self { + iter.fold(ScalarField::zero(), |sum, item| sum + item) + } +} + +impl Sum for ScalarField { + fn sum>(iter: I) -> Self { + iter.fold(ScalarField::zero(), |sum, item| sum + item) + } +} + +impl<'a> Product<&'a Self> for ScalarField { + fn product>(iter: I) -> Self { + iter.fold(ScalarField::one(), |prod, item| prod * item) + } +} + +impl Product for ScalarField { + fn product>(iter: I) -> Self { + iter.fold(ScalarField::one(), |prod, item| prod * item) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn batch_inversion_and_multiplication() { + let mut values = vec![ + ScalarField::from(1), + ScalarField::from(10), + ScalarField::from(123), + ScalarField::from(0), + ScalarField::from(1_000_000), + ScalarField::from(1 << 30), + ScalarField::from(1 << 60), + ]; + let m = ScalarField::from(42); + + let expected = values + .iter() + .map(|v| { + if v.is_zero() { + v.clone() + } else { + &m * v.inverse().unwrap() + } + }) + .collect::>(); + + ScalarField::batch_inverse_and_multiply(&mut values, &m); + + assert_eq!(expected, values); + } +} diff --git a/portal-verkle-primitives/src/proof/lagrange_basis.rs b/portal-verkle-primitives/src/proof/lagrange_basis.rs new file mode 100644 index 0000000..5cfbf1e --- /dev/null +++ b/portal-verkle-primitives/src/proof/lagrange_basis.rs @@ -0,0 +1,132 @@ +use std::{array, ops}; + +use overload::overload; + +use crate::{ + constants::VERKLE_NODE_WIDTH, + msm::{DefaultMsm, MultiScalarMultiplicator}, + Point, ScalarField, +}; + +use super::precomputed_weights::PrecomputedWeights; + +/// The polynomial expressed using Lagrange's form. +/// +/// Some operations are non-optimal when polynomials are stored in coeficient (monomial) basis: +/// ```text +/// P(x) = v0 + v_1 * X + ... + x_n * X^n +/// ``` +/// +/// Another way to represent the polynomial is using Langrange basis, in which we store the value +/// that polynomial has on a given domain (in our case domain is `[0, 255]`). More precisely: +/// +/// ```text +/// P(x) = y_0 * L_0(x) + y_1 * L_1(x) + y_n * L_n(x) +/// ``` +/// +/// Where `y_i` is the evaluation of the polynomial at `i`: `y_i = P(i)` and `L_i(x)` is Lagrange +/// polynomial: +/// +/// ```text +/// x-j +/// L_i(x) = ∏ ----- +/// j≠i i-j +/// ``` +#[derive(Clone, Debug)] +pub struct LagrangeBasis { + y: [ScalarField; VERKLE_NODE_WIDTH], +} + +impl LagrangeBasis { + pub fn new(values: [ScalarField; VERKLE_NODE_WIDTH]) -> Self { + Self { y: values } + } + + pub fn zero() -> Self { + Self { + y: array::from_fn(|_| ScalarField::zero()), + } + } + + pub fn commit(&self) -> Point { + DefaultMsm.commit_lagrange(&self.y) + } + + /// Divides the polynomial `P(x)-P(k)` with `x-k`, where `k` is in domain. + /// + /// Let's call new polynomial `Q(x)`. We evaluate it on domain manually: + /// + /// - `i ≠ k` + /// ```text + /// Q(i) = (y_i - y_k) / (i - k) + /// ``` + /// - `i = k` - This case can be transofrmed using non obvious math tricks into: + /// ```text + /// Q(k) = ∑ -Q(j) * A'(k) / A'(j) + /// j≠k + /// ``` + pub fn divide_on_domain(&self, k: usize) -> Self { + let mut q = array::from_fn(|_| ScalarField::zero()); + for i in 0..VERKLE_NODE_WIDTH { + // 1/(i-k) + let inverse = match i { + i if i < k => -PrecomputedWeights::domain_inv(k - i), + i if i == k => continue, + i if i > k => PrecomputedWeights::domain_inv(i - k).clone(), + _ => unreachable!(), + }; + + // Q(i) = (y_i-y_k) / (i-k) + q[i] = (&self.y[i] - &self.y[k]) * inverse; + + q[k] -= &q[i] * PrecomputedWeights::a_prime(k) * PrecomputedWeights::a_prime_inv(i); + } + + Self::new(q) + } + + /// Calculates `P(k)` for `k` in domain + pub fn evaluate_in_domain(&self, k: usize) -> &ScalarField { + &self.y[k] + } + + /// Calculates `P(z)` for `z` outside domain + pub fn evaluate_outside_domain(&self, z: &ScalarField) -> ScalarField { + // Lagrange polinomials: L_i(z) + let l = PrecomputedWeights::evaluate_lagrange_polynomials(z); + l.into_iter().zip(&self.y).map(|(l_i, y_i)| l_i * y_i).sum() + } +} + +impl From<&[ScalarField]> for LagrangeBasis { + fn from(other: &[ScalarField]) -> Self { + assert!(other.len() == VERKLE_NODE_WIDTH); + Self { + y: array::from_fn(|i| other[i].clone()), + } + } +} + +overload!((lhs: &mut LagrangeBasis) += (rhs: ?LagrangeBasis) { + lhs.y.iter_mut().zip(&rhs.y).for_each(|(lhs, rhs)| *lhs += rhs) +}); +overload!((lhs: LagrangeBasis) + (rhs: ?LagrangeBasis) -> LagrangeBasis { + let mut lhs = lhs; lhs += rhs; lhs +}); + +overload!((lhs: &mut LagrangeBasis) -= (rhs: ?LagrangeBasis) { + lhs.y.iter_mut().zip(&rhs.y).for_each(|(lhs, rhs)| *lhs -= rhs) +}); +overload!((lhs: LagrangeBasis) - (rhs: ?LagrangeBasis) -> LagrangeBasis { + let mut lhs = lhs; lhs -= rhs; lhs +}); + +overload!((lhs: &mut LagrangeBasis) *= (rhs: ScalarField) { + lhs.y.iter_mut().for_each(|lhs| *lhs *= &rhs) +}); +overload!((lhs: &mut LagrangeBasis) *= (rhs: &ScalarField) { + lhs.y.iter_mut().for_each(|lhs| *lhs *= rhs) +}); +overload!((lhs: LagrangeBasis) * (rhs: ?ScalarField) -> LagrangeBasis { + let mut lhs = lhs; lhs *= rhs; lhs +}); diff --git a/portal-verkle-primitives/src/proof/mod.rs b/portal-verkle-primitives/src/proof/mod.rs index 7f2aee5..d824d7f 100644 --- a/portal-verkle-primitives/src/proof/mod.rs +++ b/portal-verkle-primitives/src/proof/mod.rs @@ -5,6 +5,8 @@ use ssz_types::{typenum, FixedVector}; use crate::{Point, ScalarField}; +pub mod lagrange_basis; +pub mod precomputed_weights; pub mod transcript; #[derive( diff --git a/portal-verkle-primitives/src/proof/precomputed_weights.rs b/portal-verkle-primitives/src/proof/precomputed_weights.rs new file mode 100644 index 0000000..02c318c --- /dev/null +++ b/portal-verkle-primitives/src/proof/precomputed_weights.rs @@ -0,0 +1,115 @@ +use std::array; + +use once_cell::sync::Lazy; + +use crate::{constants::VERKLE_NODE_WIDTH, ScalarField}; + +/// Precomputed weights for Lagrange polynomial (`L_i`) related calculations. +/// +/// Domain `D` is `[0, 255]` => `d = 256``. +/// ```text +/// Lagrange polynomials: +/// x-j +/// L_i(x) = ∏ ----- +/// j≠i i-j +/// +/// A(x) = ∏ (x-i) = (x-0)(x-1)...(x-(d-1)) +/// i +/// +/// A'(x) = ∑ ( ∏ (x-j) ) +/// i j≠i +/// = (x-1)(x-2)...(x-(d-1)) + +/// (x-0)(x-2)...(x-(d-1)) + +/// ... +/// (x-0)(x-1)...(x-(d-2)) +/// +/// Lagrange polynomials in barycentric form +/// x-j A(x) +/// L_i(x) = ∏ ----- = --------------- +/// j≠i i-j A'(i) * (x-i) +/// ``` +pub struct PrecomputedWeights { + /// The `A'(i)`, for i in domain + /// + /// ```text + /// A'(i) = ∏ (i-j) + /// j≠i + /// ``` + a_prime: [ScalarField; VERKLE_NODE_WIDTH], + /// The `1/A'(i)` , for i in domain + a_prime_inv: [ScalarField; VERKLE_NODE_WIDTH], + /// The `1/i` , for i in domain (except when i is zero, in which case value is zero) + domain_inv: [ScalarField; VERKLE_NODE_WIDTH], +} + +static INSTANCE: Lazy = Lazy::new(PrecomputedWeights::new); + +impl PrecomputedWeights { + fn new() -> Self { + let a_prime = array::from_fn(|i| { + // ∏ (i-j) + // j≠i + (0..VERKLE_NODE_WIDTH) + .filter(|j| i != *j) + .map(|j| ScalarField::from(i as u64) - ScalarField::from(j as u64)) + .product() + }); + + let mut a_prime_inv = a_prime.clone(); + ScalarField::batch_inversion(&mut a_prime_inv); + + let mut domain_inv = array::from_fn(|i| ScalarField::from(i as u64)); + ScalarField::batch_inversion(&mut domain_inv); + + Self { + a_prime, + a_prime_inv, + domain_inv, + } + } + + /// Evaluates polynomial `A` at a given point `z` + /// + /// `A(z) = ∏ (z - i) = (z-0)(z-1)...(z-d)` + pub fn evaluate_a(z: &ScalarField) -> ScalarField { + (0..VERKLE_NODE_WIDTH as u64) + .map(|i| z - ScalarField::from(i)) + .product() + } + + /// Returns `A'(i)` for i in domain + pub fn a_prime(i: usize) -> &'static ScalarField { + &INSTANCE.a_prime[i] + } + + /// Returns `1/A'(i)` for i in domain + pub fn a_prime_inv(i: usize) -> &'static ScalarField { + &INSTANCE.a_prime_inv[i] + } + + pub fn domain_inv(i: usize) -> &'static ScalarField { + assert_ne!(i, 0); + &INSTANCE.domain_inv[i] + } + + /// Evaluates Lagrange polynomials `L_i` at a given point `z`, using barycentric formula. + /// + /// ```text + /// z - j A(z) + /// L_i(z) = ∏ ------- = ----------------- + /// j≠i i - j A'(i) * (z - i) + /// ``` + pub fn evaluate_lagrange_polynomials(z: &ScalarField) -> [ScalarField; VERKLE_NODE_WIDTH] { + // A(z) = (z-0)(z-1)(z-2)...(z-d) + let a_z = Self::evaluate_a(z); + + // A'(i) * (z-i) + let mut lagrange_evaluations: [ScalarField; VERKLE_NODE_WIDTH] = + array::from_fn(|i| (z - ScalarField::from(i as u64)) * Self::a_prime(i)); + + // A(z) / (A'(i) * (z-i)) + ScalarField::batch_inverse_and_multiply(&mut lagrange_evaluations, &a_z); + + lagrange_evaluations + } +}