Skip to content

Commit

Permalink
feat: implement LagrangeBasis polynomials and PrecomputedWeights
Browse files Browse the repository at this point in the history
  • Loading branch information
morph-dev committed Jul 20, 2024
1 parent 031b3a7 commit 34ef598
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 21 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions portal-verkle-primitives/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ derive_more = "0.99"
ethereum_ssz = "0.5"
ethereum_ssz_derive = "0.5"
once_cell = "1"
overload = "0.1"
serde = { version = "1", features = ["derive"] }
sha2 = "0.10"
ssz_types = "0.6"
Expand Down
33 changes: 16 additions & 17 deletions portal-verkle-primitives/src/ec/point.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
use std::{
fmt::Debug,
iter::Sum,
ops::{Add, Sub},
};
use std::{fmt::Debug, iter::Sum, ops};

use alloy_primitives::B256;
use banderwagon::{CanonicalDeserialize, CanonicalSerialize, Element};
use derive_more::{Add, AddAssign, Constructor, Deref, From, Into, Sum};
use derive_more::{Constructor, Deref, From, Into};
use overload::overload;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use ssz::{Decode, Encode};

use crate::ScalarField;

#[derive(Clone, PartialEq, Eq, Constructor, From, Into, Deref, Add, AddAssign, Sum)]
#[derive(Clone, PartialEq, Eq, Constructor, From, Into, Deref)]
pub struct Point(Element);

impl Point {
Expand Down Expand Up @@ -110,19 +107,21 @@ impl Debug for Point {
}
}

impl Add<&Self> for Point {
type Output = Self;
overload!(- (me: ?Point) -> Point { Point(-me.0) });

fn add(self, rhs: &Self) -> Self::Output {
Self(self.0 + rhs.0)
}
}
overload!((lhs: &mut Point) += (rhs: ?Point) { lhs.0 += rhs.0 });
overload!((lhs: Point) + (rhs: ?Point) -> Point {
let mut lhs = lhs; lhs += rhs; lhs
});

impl Sub<&Self> for Point {
type Output = Self;
overload!((lhs: &mut Point) -= (rhs: ?Point) { lhs.0 = lhs.0 - rhs.0 });
overload!((lhs: Point) - (rhs: ?Point) -> Point {
let mut lhs = lhs; lhs -= rhs; lhs
});

fn sub(self, rhs: &Self) -> Self::Output {
Self(self.0 - rhs.0)
impl Sum for Point {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Self::zero(), |sum, item| sum + item)
}
}

Expand Down
118 changes: 114 additions & 4 deletions portal-verkle-primitives/src/ec/scalar_field.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
use std::fmt::Debug;
use std::{
fmt::Debug,
iter::{Product, Sum},
ops,
};

use alloy_primitives::B256;
use banderwagon::{CanonicalDeserialize, CanonicalSerialize, Fr, PrimeField, Zero};
use derive_more::{Add, Constructor, Deref, From, Into, Neg, Sub};
use ark_ff::batch_inversion_and_mul;
use banderwagon::{CanonicalDeserialize, CanonicalSerialize, Field, Fr, One, PrimeField, Zero};
use derive_more::{Constructor, Deref, From, Into};
use overload::overload;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use ssz::{Decode, Encode};

use crate::Stem;

#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Constructor, From, Into, Deref, Neg, Add, Sub)]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Constructor, From, Into, Deref)]
pub struct ScalarField(pub(crate) Fr);

impl ScalarField {
Expand Down Expand Up @@ -37,6 +43,32 @@ impl ScalarField {
pub fn is_zero(&self) -> bool {
self.0.is_zero()
}

pub fn one() -> Self {
Self(Fr::one())
}

pub fn inverse(&self) -> Option<Self> {
self.0.inverse().map(Self)
}

/// Calculates inverse of all provided scalars, ignoring the ones with value zero.
pub fn batch_inversion(scalars: &mut [Self]) {
Self::batch_inverse_and_multiply(scalars, &Self::one())
}

/// Calculates inverses and multiplies them.
///
/// Updates variable `values`: `v_i` => `m / v_i`.
///
/// Ignores the zero values.
pub fn batch_inverse_and_multiply(values: &mut [Self], m: &Self) {
let mut frs = values.iter().map(|value| value.0).collect::<Vec<_>>();
batch_inversion_and_mul(&mut frs, m);
for (value, fr) in values.iter_mut().zip(frs.into_iter()) {
value.0 = fr;
}
}
}

impl From<B256> for ScalarField {
Expand Down Expand Up @@ -117,3 +149,81 @@ impl Debug for ScalarField {
B256::from(self).fmt(f)
}
}

overload!(- (me: ?ScalarField) -> ScalarField { ScalarField(-me.0) });

overload!((lhs: &mut ScalarField) += (rhs: ?ScalarField) { lhs.0 += &rhs.0; });
overload!((lhs: ScalarField) + (rhs: ?ScalarField) -> ScalarField {
let mut lhs = lhs; lhs += rhs; lhs
});
overload!((lhs: &ScalarField) + (rhs: ?ScalarField) -> ScalarField { ScalarField(lhs.0) + rhs });

overload!((lhs: &mut ScalarField) -= (rhs: ?ScalarField) { lhs.0 -= &rhs.0; });
overload!((lhs: ScalarField) - (rhs: ?ScalarField) -> ScalarField {
let mut lhs = lhs; lhs -= rhs; lhs
});
overload!((lhs: &ScalarField) - (rhs: ?ScalarField) -> ScalarField { ScalarField(lhs.0) - rhs });

overload!((lhs: &mut ScalarField) *= (rhs: ?ScalarField) { lhs.0 *= &rhs.0; });
overload!((lhs: ScalarField) * (rhs: ?ScalarField) -> ScalarField {
let mut lhs = lhs; lhs *= rhs; lhs
});
overload!((lhs: &ScalarField) * (rhs: ?ScalarField) -> ScalarField { ScalarField(lhs.0) * rhs });

impl<'a> Sum<&'a Self> for ScalarField {
fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
iter.fold(ScalarField::zero(), |sum, item| sum + item)
}
}

impl Sum for ScalarField {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(ScalarField::zero(), |sum, item| sum + item)
}
}

impl<'a> Product<&'a Self> for ScalarField {
fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
iter.fold(ScalarField::one(), |prod, item| prod * item)
}
}

impl Product for ScalarField {
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(ScalarField::one(), |prod, item| prod * item)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn batch_inversion_and_multiplication() {
let mut values = vec![
ScalarField::from(1),
ScalarField::from(10),
ScalarField::from(123),
ScalarField::from(0),
ScalarField::from(1_000_000),
ScalarField::from(1 << 30),
ScalarField::from(1 << 60),
];
let m = ScalarField::from(42);

let expected = values
.iter()
.map(|v| {
if v.is_zero() {
v.clone()
} else {
&m * v.inverse().unwrap()
}
})
.collect::<Vec<_>>();

ScalarField::batch_inverse_and_multiply(&mut values, &m);

assert_eq!(expected, values);
}
}
132 changes: 132 additions & 0 deletions portal-verkle-primitives/src/proof/lagrange_basis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
use std::{array, ops};

use overload::overload;

use crate::{
constants::VERKLE_NODE_WIDTH,
msm::{DefaultMsm, MultiScalarMultiplicator},
Point, ScalarField,
};

use super::precomputed_weights::PrecomputedWeights;

/// The polynomial expressed using Lagrange's form.
///
/// Some operations are non-optimal when polynomials are stored in coeficient (monomial) basis:
/// ```text
/// P(x) = v0 + v_1 * X + ... + x_n * X^n
/// ```
///
/// Another way to represent the polynomial is using Langrange basis, in which we store the value
/// that polynomial has on a given domain (in our case domain is `[0, 255]`). More precisely:
///
/// ```text
/// P(x) = y_0 * L_0(x) + y_1 * L_1(x) + y_n * L_n(x)
/// ```
///
/// Where `y_i` is the evaluation of the polynomial at `i`: `y_i = P(i)` and `L_i(x)` is Lagrange
/// polynomial:
///
/// ```text
/// x-j
/// L_i(x) = ∏ -----
/// j≠i i-j
/// ```
#[derive(Clone, Debug)]
pub struct LagrangeBasis {
y: [ScalarField; VERKLE_NODE_WIDTH],
}

impl LagrangeBasis {
pub fn new(values: [ScalarField; VERKLE_NODE_WIDTH]) -> Self {
Self { y: values }
}

pub fn zero() -> Self {
Self {
y: array::from_fn(|_| ScalarField::zero()),
}
}

pub fn commit(&self) -> Point {
DefaultMsm.commit_lagrange(&self.y)
}

/// Divides the polynomial `P(x)-P(k)` with `x-k`, where `k` is in domain.
///
/// Let's call new polynomial `Q(x)`. We evaluate it on domain manually:
///
/// - `i ≠ k`
/// ```text
/// Q(i) = (y_i - y_k) / (i - k)
/// ```
/// - `i = k` - This case can be transofrmed using non obvious math tricks into:
/// ```text
/// Q(k) = ∑ -Q(j) * A'(k) / A'(j)
/// j≠k
/// ```
pub fn divide_on_domain(&self, k: usize) -> Self {
let mut q = array::from_fn(|_| ScalarField::zero());
for i in 0..VERKLE_NODE_WIDTH {
// 1/(i-k)
let inverse = match i {
i if i < k => -PrecomputedWeights::domain_inv(k - i),
i if i == k => continue,
i if i > k => PrecomputedWeights::domain_inv(i - k).clone(),
_ => unreachable!(),
};

// Q(i) = (y_i-y_k) / (i-k)
q[i] = (&self.y[i] - &self.y[k]) * inverse;

q[k] -= &q[i] * PrecomputedWeights::a_prime(k) * PrecomputedWeights::a_prime_inv(i);
}

Self::new(q)
}

/// Calculates `P(k)` for `k` in domain
pub fn evaluate_in_domain(&self, k: usize) -> &ScalarField {
&self.y[k]
}

/// Calculates `P(z)` for `z` outside domain
pub fn evaluate_outside_domain(&self, z: &ScalarField) -> ScalarField {
// Lagrange polinomials: L_i(z)
let l = PrecomputedWeights::evaluate_lagrange_polynomials(z);
l.into_iter().zip(&self.y).map(|(l_i, y_i)| l_i * y_i).sum()
}
}

impl From<&[ScalarField]> for LagrangeBasis {
fn from(other: &[ScalarField]) -> Self {
assert!(other.len() == VERKLE_NODE_WIDTH);
Self {
y: array::from_fn(|i| other[i].clone()),
}
}
}

overload!((lhs: &mut LagrangeBasis) += (rhs: ?LagrangeBasis) {
lhs.y.iter_mut().zip(&rhs.y).for_each(|(lhs, rhs)| *lhs += rhs)
});
overload!((lhs: LagrangeBasis) + (rhs: ?LagrangeBasis) -> LagrangeBasis {
let mut lhs = lhs; lhs += rhs; lhs
});

overload!((lhs: &mut LagrangeBasis) -= (rhs: ?LagrangeBasis) {
lhs.y.iter_mut().zip(&rhs.y).for_each(|(lhs, rhs)| *lhs -= rhs)
});
overload!((lhs: LagrangeBasis) - (rhs: ?LagrangeBasis) -> LagrangeBasis {
let mut lhs = lhs; lhs -= rhs; lhs
});

overload!((lhs: &mut LagrangeBasis) *= (rhs: ScalarField) {
lhs.y.iter_mut().for_each(|lhs| *lhs *= &rhs)
});
overload!((lhs: &mut LagrangeBasis) *= (rhs: &ScalarField) {
lhs.y.iter_mut().for_each(|lhs| *lhs *= rhs)
});
overload!((lhs: LagrangeBasis) * (rhs: ?ScalarField) -> LagrangeBasis {
let mut lhs = lhs; lhs *= rhs; lhs
});
2 changes: 2 additions & 0 deletions portal-verkle-primitives/src/proof/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use ssz_types::{typenum, FixedVector};

use crate::{Point, ScalarField};

pub mod lagrange_basis;
pub mod precomputed_weights;
pub mod transcript;

#[derive(
Expand Down
Loading

0 comments on commit 34ef598

Please sign in to comment.