Skip to content

Commit

Permalink
perf(lc): optimize sequential inserts
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed Sep 7, 2021
1 parent d3d00a2 commit 38ed8bf
Showing 1 changed file with 144 additions and 37 deletions.
181 changes: 144 additions & 37 deletions src/lc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,88 @@ pub enum Index {
/// in the scalar field of a pairing-friendly elliptic curve group.
#[derive(Clone)]
pub struct LinearCombination<E: ScalarEngine> {
inputs: Vec<(usize, E::Fr)>,
aux: Vec<(usize, E::Fr)>,
inputs: Indexer<E::Fr>,
aux: Indexer<E::Fr>,
}

#[derive(Clone)]
struct Indexer<T> {
values: Vec<(usize, T)>,
// (index, key) of the last insertion operation
last_inserted: Option<(usize, usize)>,
}

impl<T> Default for Indexer<T> {
fn default() -> Self {
Indexer {
values: Vec::new(),
last_inserted: None,
}
}
}

impl<T> Indexer<T> {
pub fn from_value(index: usize, value: T) -> Self {
Indexer {
values: vec![(index, value)],
last_inserted: Some((0, index)),
}
}

pub fn iter(&self) -> impl Iterator<Item = (&usize, &T)> + '_ {
self.values.iter().map(|(key, value)| (key, value))
}

pub fn iter_mut(&mut self) -> impl Iterator<Item = (&mut usize, &mut T)> + '_ {
self.values.iter_mut().map(|(key, value)| (key, value))
}

pub fn insert_or_update<F, G>(&mut self, key: usize, insert: F, update: G)
where
F: FnOnce() -> T,
G: FnOnce(&mut T),
{
if let Some((last_index, last_key)) = self.last_inserted {
if last_key == key {
// update the same key again
update(&mut self.values[last_index].1);
return;
} else if last_key + 1 == key {
// optimization for follow on updates
let i = last_index + 1;
if i >= self.values.len() {
// insert at the end
self.values.push((key, insert()));
self.last_inserted = Some((i, key));
} else if self.values[i].0 == key {
// update
update(&mut self.values[i].1);
} else {
// insert
self.values.insert(i, (key, insert()));
self.last_inserted = Some((i, key));
}
return;
}
}
match self.values.binary_search_by_key(&key, |(k, _)| *k) {
Ok(i) => {
update(&mut self.values[i].1);
}
Err(i) => {
self.values.insert(i, (key, insert()));
self.last_inserted = Some((i, key));
}
}
}

pub fn len(&self) -> usize {
self.values.len()
}

pub fn is_empty(&self) -> bool {
self.values.is_empty()
}
}

impl<E: ScalarEngine> Default for LinearCombination<E> {
Expand All @@ -47,11 +127,28 @@ impl<E: ScalarEngine> Default for LinearCombination<E> {
impl<E: ScalarEngine> LinearCombination<E> {
pub fn zero() -> LinearCombination<E> {
LinearCombination {
inputs: Vec::new(),
aux: Vec::new(),
inputs: Default::default(),
aux: Default::default(),
}
}

pub fn from_coeff(var: Variable, coeff: E::Fr) -> Self {
match var {
Variable(Index::Input(i)) => Self {
inputs: Indexer::from_value(i, coeff),
aux: Default::default(),
},
Variable(Index::Aux(i)) => Self {
inputs: Default::default(),
aux: Indexer::from_value(i, coeff),
},
}
}

pub fn from_variable(var: Variable) -> Self {
Self::from_coeff(var, E::Fr::one())
}

pub fn iter(&self) -> impl Iterator<Item = (Variable, &E::Fr)> + '_ {
self.inputs
.iter()
Expand All @@ -60,12 +157,12 @@ impl<E: ScalarEngine> LinearCombination<E> {
}

#[inline]
pub(crate) fn iter_inputs(&self) -> impl Iterator<Item = &(usize, E::Fr)> + '_ {
pub(crate) fn iter_inputs(&self) -> impl Iterator<Item = (&usize, &E::Fr)> + '_ {
self.inputs.iter()
}

#[inline]
pub(crate) fn iter_aux(&self) -> impl Iterator<Item = &(usize, E::Fr)> + '_ {
pub(crate) fn iter_aux(&self) -> impl Iterator<Item = (&usize, &E::Fr)> + '_ {
self.aux.iter()
}

Expand All @@ -80,32 +177,16 @@ impl<E: ScalarEngine> LinearCombination<E> {
)
}

#[inline]
fn add_assign_unsimplified_input(&mut self, new_var: usize, coeff: E::Fr) {
match self
.inputs
.binary_search_by_key(&new_var, |(var, _coeff)| *var)
{
Ok(index) => {
self.inputs[index].1.add_assign(&coeff);
}
Err(index) => {
self.inputs.insert(index, (new_var, coeff));
}
}
self.inputs
.insert_or_update(new_var, || coeff, |val| val.add_assign(&coeff));
}

#[inline]
fn add_assign_unsimplified_aux(&mut self, new_var: usize, coeff: E::Fr) {
match self
.aux
.binary_search_by_key(&new_var, |(var, _coeff)| *var)
{
Ok(index) => {
self.aux[index].1.add_assign(&coeff);
}
Err(index) => {
self.aux.insert(index, (new_var, coeff));
}
}
self.aux
.insert_or_update(new_var, || coeff, |val| val.add_assign(&coeff));
}

pub fn add_unsimplified(mut self, (coeff, var): (E::Fr, Variable)) -> LinearCombination<E> {
Expand All @@ -121,11 +202,13 @@ impl<E: ScalarEngine> LinearCombination<E> {
self
}

#[inline]
fn sub_assign_unsimplified_input(&mut self, new_var: usize, mut coeff: E::Fr) {
coeff.negate();
self.add_assign_unsimplified_input(new_var, coeff);
}

#[inline]
fn sub_assign_unsimplified_aux(&mut self, new_var: usize, mut coeff: E::Fr) {
coeff.negate();
self.add_assign_unsimplified_aux(new_var, coeff);
Expand All @@ -149,7 +232,7 @@ impl<E: ScalarEngine> LinearCombination<E> {
}

pub fn is_empty(&self) -> bool {
self.len() == 0
self.inputs.is_empty() && self.aux.is_empty()
}

pub(crate) fn eval(
Expand Down Expand Up @@ -232,11 +315,11 @@ impl<'a, E: ScalarEngine> Add<&'a LinearCombination<E>> for LinearCombination<E>
type Output = LinearCombination<E>;

fn add(mut self, other: &'a LinearCombination<E>) -> LinearCombination<E> {
for (var, val) in &other.inputs {
for (var, val) in other.inputs.iter() {
self.add_assign_unsimplified_input(*var, *val);
}

for (var, val) in &other.aux {
for (var, val) in other.aux.iter() {
self.add_assign_unsimplified_aux(*var, *val);
}

Expand All @@ -248,11 +331,11 @@ impl<'a, E: ScalarEngine> Sub<&'a LinearCombination<E>> for LinearCombination<E>
type Output = LinearCombination<E>;

fn sub(mut self, other: &'a LinearCombination<E>) -> LinearCombination<E> {
for (var, val) in &other.inputs {
for (var, val) in other.inputs.iter() {
self.sub_assign_unsimplified_input(*var, *val);
}

for (var, val) in &other.aux {
for (var, val) in other.aux.iter() {
self.sub_assign_unsimplified_aux(*var, *val);
}

Expand All @@ -264,13 +347,13 @@ impl<'a, E: ScalarEngine> Add<(E::Fr, &'a LinearCombination<E>)> for LinearCombi
type Output = LinearCombination<E>;

fn add(mut self, (coeff, other): (E::Fr, &'a LinearCombination<E>)) -> LinearCombination<E> {
for (var, val) in &other.inputs {
for (var, val) in other.inputs.iter() {
let mut tmp = *val;
tmp.mul_assign(&coeff);
self.add_assign_unsimplified_input(*var, tmp);
}

for (var, val) in &other.aux {
for (var, val) in other.aux.iter() {
let mut tmp = *val;
tmp.mul_assign(&coeff);
self.add_assign_unsimplified_aux(*var, tmp);
Expand All @@ -284,13 +367,13 @@ impl<'a, E: ScalarEngine> Sub<(E::Fr, &'a LinearCombination<E>)> for LinearCombi
type Output = LinearCombination<E>;

fn sub(mut self, (coeff, other): (E::Fr, &'a LinearCombination<E>)) -> LinearCombination<E> {
for (var, val) in &other.inputs {
for (var, val) in other.inputs.iter() {
let mut tmp = *val;
tmp.mul_assign(&coeff);
self.sub_assign_unsimplified_input(*var, tmp);
}

for (var, val) in &other.aux {
for (var, val) in other.aux.iter() {
let mut tmp = *val;
tmp.mul_assign(&coeff);
self.sub_assign_unsimplified_aux(*var, tmp);
Expand Down Expand Up @@ -332,4 +415,28 @@ mod tests {
_ => panic!("unexpected variable type"),
});
}

#[test]
fn test_insert_or_update() {
let mut indexer = Indexer::default();
let one = <Bls12 as ScalarEngine>::Fr::one();
let mut two = one;
two.add_assign(&one);

indexer.insert_or_update(2, || one, |v| v.add_assign(&one));
assert_eq!(&indexer.values, &[(2, one)]);
assert_eq!(&indexer.last_inserted, &Some((0, 2)));

indexer.insert_or_update(3, || one, |v| v.add_assign(&one));
assert_eq!(&indexer.values, &[(2, one), (3, one)]);
assert_eq!(&indexer.last_inserted, &Some((1, 3)));

indexer.insert_or_update(1, || one, |v| v.add_assign(&one));
assert_eq!(&indexer.values, &[(1, one), (2, one), (3, one)]);
assert_eq!(&indexer.last_inserted, &Some((0, 1)));

indexer.insert_or_update(2, || one, |v| v.add_assign(&one));
assert_eq!(&indexer.values, &[(1, one), (2, two), (3, one)]);
assert_eq!(&indexer.last_inserted, &Some((0, 1)));
}
}

0 comments on commit 38ed8bf

Please sign in to comment.