Skip to content

Commit

Permalink
feat: skip zeroes in msm (privacy-scaling-explorations#168)
Browse files Browse the repository at this point in the history
* feat: skip zeroes in msm

* Update src/msm.rs

Co-authored-by: David Nevado <davidnevadoc@users.noreply.github.com>

---------

Co-authored-by: David Nevado <davidnevadoc@users.noreply.github.com>
  • Loading branch information
2 people authored and jonathanpwang committed Aug 13, 2024
1 parent 7613f82 commit dcec17c
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 37 deletions.
12 changes: 9 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
[package]
name = "halo2curves-axiom"
version = "0.6.1"
authors = ["Privacy Scaling Explorations team", "Taiko Labs", "Intrinsic Technologies"]
authors = [
"Privacy Scaling Explorations team",
"Taiko Labs",
"Intrinsic Technologies",
]
license = "MIT/Apache-2.0"
edition = "2021"
repository = "https://github.com/axiom-crypto/halo2curves"
Expand Down Expand Up @@ -39,7 +43,10 @@ num-traits = "0.2"
paste = "1.0.11"
serde = { version = "1.0", default-features = false, optional = true }
serde_arrays = { version = "0.1.0", optional = true }
hex = { version = "0.4", optional = true, default-features = false, features = ["alloc", "serde"] }
hex = { version = "0.4", optional = true, default-features = false, features = [
"alloc",
"serde",
] }
blake2b_simd = "1"
rayon = "1.8"
digest = "0.10.7"
Expand Down Expand Up @@ -87,4 +94,3 @@ harness = false
[[bench]]
name = "msm"
harness = false
required-features = ["multicore"]
110 changes: 77 additions & 33 deletions benches/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
extern crate criterion;

use criterion::{BenchmarkId, Criterion};
use ff::Field;
use ff::{Field, PrimeField};
use group::prime::PrimeCurveAffine;
use halo2curves_axiom::bn256::{Fr as Scalar, G1Affine as Point};
use halo2curves_axiom::msm::{best_multiexp, multiexp_serial};
use maybe_rayon::current_thread_index;
use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
use rand_core::SeedableRng;
use rand_core::{RngCore, SeedableRng};
use rand_xorshift::XorShiftRng;
use rayon::current_thread_index;
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use std::time::SystemTime;

const SAMPLE_SIZE: usize = 10;
Expand All @@ -30,15 +30,15 @@ const SEED: [u8; 16] = [
0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, 0xe5,
];

fn generate_coefficients_and_curvepoints(k: u8) -> (Vec<Scalar>, Vec<Point>) {
fn generate_curvepoints(k: u8) -> Vec<Point> {
let n: u64 = {
assert!(k < 64);
1 << k
};

println!("\n\nGenerating 2^{k} = {n} coefficients and curve points..",);
println!("Generating 2^{k} = {n} curve points..",);
let timer = SystemTime::now();
let coeffs = (0..n)
let bases = (0..n)
.into_par_iter()
.map_init(
|| {
Expand All @@ -51,10 +51,36 @@ fn generate_coefficients_and_curvepoints(k: u8) -> (Vec<Scalar>, Vec<Point>) {
}
XorShiftRng::from_seed(thread_seed)
},
|rng, _| Scalar::random(rng),
|rng, _| Point::random(rng),
)
.collect();
let bases = (0..n)
let end = timer.elapsed().unwrap();
println!(
"Generating 2^{k} = {n} curve points took: {} sec.\n\n",
end.as_secs()
);
bases
}

fn generate_coefficients(k: u8, bits: usize) -> Vec<Scalar> {
let n: u64 = {
assert!(k < 64);
1 << k
};
let max_val: Option<u128> = match bits {
1 => Some(1),
8 => Some(0xff),
16 => Some(0xffff),
32 => Some(0xffff_ffff),
64 => Some(0xffff_ffff_ffff_ffff),
128 => Some(0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff),
256 => None,
_ => panic!("unexpected bit size {}", bits),
};

println!("Generating 2^{k} = {n} coefficients..",);
let timer = SystemTime::now();
let coeffs = (0..n)
.into_par_iter()
.map_init(
|| {
Expand All @@ -67,16 +93,25 @@ fn generate_coefficients_and_curvepoints(k: u8) -> (Vec<Scalar>, Vec<Point>) {
}
XorShiftRng::from_seed(thread_seed)
},
|rng, _| Point::random(rng),
|rng, _| {
if let Some(max_val) = max_val {
let v_lo = rng.next_u64() as u128;
let v_hi = rng.next_u64() as u128;
let mut v = v_lo + (v_hi << 64);
v &= max_val; // Mask the 128bit value to get a lower number of bits
Scalar::from_u128(v)
} else {
Scalar::random(rng)
}
},
)
.collect();
let end = timer.elapsed().unwrap();
println!(
"Generating 2^{k} = {n} coefficients and curve points took: {} sec.\n\n",
"Generating 2^{k} = {n} coefficients took: {} sec.\n\n",
end.as_secs()
);

(coeffs, bases)
coeffs
}

fn msm(c: &mut Criterion) {
Expand All @@ -86,28 +121,37 @@ fn msm(c: &mut Criterion) {
.chain(MULTICORE_RANGE.iter())
.max()
.unwrap_or(&16);
let (coeffs, bases) = generate_coefficients_and_curvepoints(max_k);
let bases = generate_curvepoints(max_k);
let bits = [1, 8, 16, 32, 64, 128, 256];
let coeffs: Vec<_> = bits
.iter()
.map(|b| generate_coefficients(max_k, *b))
.collect();

for k in SINGLECORE_RANGE {
group
.bench_function(BenchmarkId::new("singlecore", k), |b| {
assert!(k < 64);
let n: usize = 1 << k;
let mut acc = Point::identity().into();
b.iter(|| multiexp_serial(&coeffs[..n], &bases[..n], &mut acc));
})
.sample_size(10);
}
for k in MULTICORE_RANGE {
group
.bench_function(BenchmarkId::new("multicore", k), |b| {
assert!(k < 64);
let n: usize = 1 << k;
b.iter(|| {
best_multiexp(&coeffs[..n], &bases[..n]);
for (b_index, b) in bits.iter().enumerate() {
for k in SINGLECORE_RANGE {
let id = format!("{b}b_{k}");
group
.bench_function(BenchmarkId::new("singlecore", id), |b| {
assert!(k < 64);
let n: usize = 1 << k;
let mut acc = Point::identity().into();
b.iter(|| multiexp_serial(&coeffs[b_index][..n], &bases[..n], &mut acc));
})
.sample_size(10);
}
for k in MULTICORE_RANGE {
let id = format!("{b}b_{k}");
group
.bench_function(BenchmarkId::new("multicore", id), |b| {
assert!(k < 64);
let n: usize = 1 << k;
b.iter(|| {
best_multiexp(&coeffs[b_index][..n], &bases[..n]);
})
})
})
.sample_size(SAMPLE_SIZE);
.sample_size(SAMPLE_SIZE);
}
}
group.finish();
}
Expand Down
20 changes: 19 additions & 1 deletion src/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,25 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &
(f64::from(bases.len() as u32)).ln().ceil() as usize
};

let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;
let field_byte_size = C::Scalar::NUM_BITS.div_ceil(8u32) as usize;
// OR all coefficients in order to make a mask to figure out the maximum number of bytes used
// among all coefficients.
let mut acc_or = vec![0; field_byte_size];
for coeff in &coeffs {
for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
*acc_limb = *acc_limb | *limb;
}
}
let max_byte_size = field_byte_size
- acc_or
.iter()
.rev()
.position(|v| *v != 0)
.unwrap_or(field_byte_size);
if max_byte_size == 0 {
return;
}
let number_of_windows = max_byte_size * 8 as usize / c + 1;

for current_window in (0..number_of_windows).rev() {
for _ in 0..c {
Expand Down

0 comments on commit dcec17c

Please sign in to comment.