Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

first class scalar #712

Merged
merged 7 commits into from
Feb 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 43 additions & 41 deletions rust/xaynet-core/src/mask/masking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use num::{
bigint::{BigInt, BigUint, ToBigInt},
clamp,
rational::Ratio,
traits::clamp_max,
};
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
Expand All @@ -19,8 +20,9 @@ use crate::{
crypto::{prng::generate_integer, ByteObject},
mask::{
config::MaskConfigPair,
model::{float_to_ratio_bounded, Model},
model::Model,
object::{MaskObject, MaskUnit, MaskVect},
scalar::Scalar,
seed::MaskSeed,
},
};
Expand Down Expand Up @@ -353,7 +355,7 @@ impl Masker {
/// proceeds in reverse order.
///
/// [`unmask()`]: Aggregation::unmask
pub fn mask(self, scalar: f64, model: &Model) -> (MaskSeed, MaskObject) {
pub fn mask(self, scalar: Scalar, model: &Model) -> (MaskSeed, MaskObject) {
let (random_int, mut random_ints) = self.random_ints();
let Self { config, seed } = self;
let MaskConfigPair {
Expand All @@ -363,9 +365,8 @@ impl Masker {

// clamp the scalar
let add_shift_1 = config_1.add_shift();
let scalar_ratio = float_to_ratio_bounded(scalar);
let zero = Ratio::<BigInt>::from_float(0_f64).unwrap();
let scalar_clamped = clamp(&scalar_ratio, &zero, &add_shift_1);
let scalar_ratio = scalar.into();
let scalar_clamped = clamp_max(&scalar_ratio, &add_shift_1);

let exp_shift_n = config_n.exp_shift();
let add_shift_n = config_n.add_shift();
Expand Down Expand Up @@ -437,6 +438,7 @@ mod tests {
ModelType::M3,
},
model::FromPrimitives,
scalar::FromPrimitive,
};

/// Generate tests for masking and unmasking of a single model:
Expand Down Expand Up @@ -491,7 +493,7 @@ mod tests {
// b. derive the mask corresponding to the seed used
// c. unmask the model and check it against the original one.
let (mask_seed, masked_model) =
Masker::new(config.into()).mask(1_f64, &model);
Masker::new(config.into()).mask(Scalar::unit(), &model);
assert_eq!(masked_model.vect.data.len(), vect_len);
assert!(masked_model.is_valid());

Expand Down Expand Up @@ -631,7 +633,8 @@ mod tests {
};
let eps = [<$data:lower>]::EPSILON;
let mut prng = ChaCha20Rng::from_seed(MaskSeed::generate().as_array());
let scalar = Uniform::new_inclusive(eps, bound).sample(&mut prng) as f64;
let random_weight = Uniform::new_inclusive(eps, bound).sample(&mut prng);
let scalar = Scalar::from_primitive(random_weight).unwrap();
let model = Model::from_primitives(iter::repeat(1).take(vect_len)).unwrap();
assert_eq!(model.len(), vect_len);

Expand Down Expand Up @@ -865,6 +868,7 @@ mod tests {
model_type: M3,
};
let vect_len = $len as usize;
let model_count = $count as usize;

// Step 2: Generate random models
let bound = if $bound == 0 {
Expand Down Expand Up @@ -894,19 +898,19 @@ mod tests {
.unwrap();
let mut aggregated_masked_model = Aggregation::new(config.into(), vect_len);
let mut aggregated_mask = Aggregation::new(config.into(), vect_len);
let scalar = 1_f64 / ($count as f64);
let scalar_ratio = Ratio::from_float(scalar).unwrap();
for _ in 0..$count as usize {
let scalar = Scalar::new(1, model_count);
let scalar_ratio = &scalar.to_ratio();
for _ in 0..model_count {
let model = models.next().unwrap();
averaged_model
.iter_mut()
.zip(model.iter())
.for_each(|(averaged_weight, weight)| {
*averaged_weight += &scalar_ratio * weight;
*averaged_weight += scalar_ratio * weight;
});

let (mask_seed, masked_model) =
Masker::new(config.into()).mask(scalar, &model);
Masker::new(config.into()).mask(scalar.clone(), &model);
let mask = mask_seed.derive_mask(vect_len, config.into());

assert!(
Expand All @@ -920,7 +924,7 @@ mod tests {
let mask = aggregated_mask.into();
assert!(aggregated_masked_model.validate_unmasking(&mask).is_ok());
let unmasked_model = aggregated_masked_model.unmask(mask);
let tolerance = Ratio::from_integer(BigInt::from($count as usize))
let tolerance = Ratio::from_integer(BigInt::from(model_count))
/ Ratio::from_integer(config.exp_shift());
assert!(
averaged_model.iter()
Expand All @@ -937,45 +941,41 @@ mod tests {
};
}

// FIXME some of the test cases below exceed the closeness checks, namely
// those with data type f64, and f32_bmax. For now, reduce the number of
// models for these test cases to 2 to minimise the error.

test_masking_and_aggregation!(int_f32_b0, Integer, f32, 1, 10, 5);
test_masking_and_aggregation!(int_f32_b2, Integer, f32, 100, 10, 5);
test_masking_and_aggregation!(int_f32_b4, Integer, f32, 10_000, 10, 5);
test_masking_and_aggregation!(int_f32_b6, Integer, f32, 1_000_000, 10, 5);
test_masking_and_aggregation!(int_f32_bmax, Integer, f32, 10, 2);
test_masking_and_aggregation!(int_f32_bmax, Integer, f32, 10, 5);

test_masking_and_aggregation!(prime_f32_b0, Prime, f32, 1, 10, 5);
test_masking_and_aggregation!(prime_f32_b2, Prime, f32, 100, 10, 5);
test_masking_and_aggregation!(prime_f32_b4, Prime, f32, 10_000, 10, 5);
test_masking_and_aggregation!(prime_f32_b6, Prime, f32, 1_000_000, 10, 5);
test_masking_and_aggregation!(prime_f32_bmax, Prime, f32, 10, 2);
test_masking_and_aggregation!(prime_f32_bmax, Prime, f32, 10, 5);

test_masking_and_aggregation!(pow_f32_b0, Power2, f32, 1, 10, 5);
test_masking_and_aggregation!(pow_f32_b2, Power2, f32, 100, 10, 5);
test_masking_and_aggregation!(pow_f32_b4, Power2, f32, 10_000, 10, 5);
test_masking_and_aggregation!(pow_f32_b6, Power2, f32, 1_000_000, 10, 5);
test_masking_and_aggregation!(pow_f32_bmax, Power2, f32, 10, 2);

test_masking_and_aggregation!(int_f64_b0, Integer, f64, 1, 10, 2);
test_masking_and_aggregation!(int_f64_b2, Integer, f64, 100, 10, 2);
test_masking_and_aggregation!(int_f64_b4, Integer, f64, 10_000, 10, 2);
test_masking_and_aggregation!(int_f64_b6, Integer, f64, 1_000_000, 10, 2);
test_masking_and_aggregation!(int_f64_bmax, Integer, f64, 10, 2);

test_masking_and_aggregation!(prime_f64_b0, Prime, f64, 1, 10, 2);
test_masking_and_aggregation!(prime_f64_b2, Prime, f64, 100, 10, 2);
test_masking_and_aggregation!(prime_f64_b4, Prime, f64, 10_000, 10, 2);
test_masking_and_aggregation!(prime_f64_b6, Prime, f64, 1_000_000, 10, 2);
test_masking_and_aggregation!(prime_f64_bmax, Prime, f64, 10, 2);

test_masking_and_aggregation!(pow_f64_b0, Power2, f64, 1, 10, 2);
test_masking_and_aggregation!(pow_f64_b2, Power2, f64, 100, 10, 2);
test_masking_and_aggregation!(pow_f64_b4, Power2, f64, 10_000, 10, 2);
test_masking_and_aggregation!(pow_f64_b6, Power2, f64, 1_000_000, 10, 2);
test_masking_and_aggregation!(pow_f64_bmax, Power2, f64, 10, 2);
test_masking_and_aggregation!(pow_f32_bmax, Power2, f32, 10, 5);

test_masking_and_aggregation!(int_f64_b0, Integer, f64, 1, 10, 5);
test_masking_and_aggregation!(int_f64_b2, Integer, f64, 100, 10, 5);
test_masking_and_aggregation!(int_f64_b4, Integer, f64, 10_000, 10, 5);
test_masking_and_aggregation!(int_f64_b6, Integer, f64, 1_000_000, 10, 5);
test_masking_and_aggregation!(int_f64_bmax, Integer, f64, 10, 5);

test_masking_and_aggregation!(prime_f64_b0, Prime, f64, 1, 10, 5);
test_masking_and_aggregation!(prime_f64_b2, Prime, f64, 100, 10, 5);
test_masking_and_aggregation!(prime_f64_b4, Prime, f64, 10_000, 10, 5);
test_masking_and_aggregation!(prime_f64_b6, Prime, f64, 1_000_000, 10, 5);
test_masking_and_aggregation!(prime_f64_bmax, Prime, f64, 10, 5);

test_masking_and_aggregation!(pow_f64_b0, Power2, f64, 1, 10, 5);
test_masking_and_aggregation!(pow_f64_b2, Power2, f64, 100, 10, 5);
test_masking_and_aggregation!(pow_f64_b4, Power2, f64, 10_000, 10, 5);
test_masking_and_aggregation!(pow_f64_b6, Power2, f64, 1_000_000, 10, 5);
test_masking_and_aggregation!(pow_f64_bmax, Power2, f64, 10, 5);

test_masking_and_aggregation!(int_i32_b0, Integer, i32, 1, 10, 5);
test_masking_and_aggregation!(int_i32_b2, Integer, i32, 100, 10, 5);
Expand Down Expand Up @@ -1048,6 +1048,7 @@ mod tests {
model_type: M3,
};
let vect_len = $len as usize;
let model_count = $count as usize;

// Step 2: Generate random scalars
// take vectors [1, ..., 1] as models to scale
Expand All @@ -1059,7 +1060,8 @@ mod tests {
let eps = [<$data:lower>]::EPSILON;
let mut prng = ChaCha20Rng::from_seed(MaskSeed::generate().as_array());
let mut scalars = iter::repeat_with(move || {
Uniform::new_inclusive(eps, bound).sample(&mut prng) as f64
let random_weight = Uniform::new_inclusive(eps, bound).sample(&mut prng);
Scalar::from_primitive(random_weight).unwrap()
});
let mut models =
iter::repeat(Model::from_primitives(iter::repeat(1).take(vect_len)).unwrap());
Expand All @@ -1071,7 +1073,7 @@ mod tests {
// d. repeat a-c, unmask the model and check it against the expected [1, ..., 1]
let mut aggregated_masked_model = Aggregation::new(config.into(), vect_len);
let mut aggregated_mask = Aggregation::new(config.into(), vect_len);
for _ in 0..$count as usize {
for _ in 0..model_count {
let model = models.next().unwrap();
let scalar = scalars.next().unwrap();

Expand All @@ -1090,7 +1092,7 @@ mod tests {
let mask = aggregated_mask.into();
assert!(aggregated_masked_model.validate_unmasking(&mask).is_ok());
let unmasked_model = aggregated_masked_model.unmask(mask);
let tolerance = Ratio::from_integer(BigInt::from($count as usize))
let tolerance = Ratio::from_integer(BigInt::from(model_count))
/ Ratio::from_integer(config.exp_shift());
let expected_weight = Ratio::from_integer(BigInt::from(1));
assert!(
Expand Down
20 changes: 11 additions & 9 deletions rust/xaynet-core/src/mask/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@
//! can be generated via the additionally returned [`MaskSeed`].
//!
//! ```
//! # use xaynet_core::mask::{BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, Model, ModelType};
//! # use xaynet_core::mask::{BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, Model, ModelType, Scalar};
//! // create local models and a fitting masking configuration
//! let number_weights = 10;
//! let scalar = 0.5;
//! let scalar = Scalar::new(1, 2_u8);
//! let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter());
//! let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter());
//! let config = MaskConfig {
Expand All @@ -103,7 +103,7 @@
//! };
//!
//! // mask the local models
//! let (local_mask_seed_1, masked_local_model_1) = Masker::new(config.into()).mask(scalar, &local_model_1);
//! let (local_mask_seed_1, masked_local_model_1) = Masker::new(config.into()).mask(scalar.clone(), &local_model_1);
//! let (local_mask_seed_2, masked_local_model_2) = Masker::new(config.into()).mask(scalar, &local_model_2);
//!
//! // derive the masks of the local masked models
Expand All @@ -118,13 +118,13 @@
//! safely performed wrt the chosen masking configuration without possible loss of information.
//!
//! ```
//! # use xaynet_core::mask::{Aggregation, BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, MaskObject, Model, ModelType};
//! # use xaynet_core::mask::{Aggregation, BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, MaskObject, Model, ModelType, Scalar};
//! # let number_weights = 10;
//! # let scalar = 0.5;
//! # let scalar = Scalar::new(1, 2_u8);
//! # let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter());
//! # let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter());
//! # let config = MaskConfig { group_type: GroupType::Prime, data_type: DataType::F32, bound_type: BoundType::B0, model_type: ModelType::M3};
//! # let (local_mask_seed_1, masked_local_model_1) = Masker::new(config.into()).mask(scalar, &local_model_1);
//! # let (local_mask_seed_1, masked_local_model_1) = Masker::new(config.into()).mask(scalar.clone(), &local_model_1);
//! # let (local_mask_seed_2, masked_local_model_2) = Masker::new(config.into()).mask(scalar, &local_model_2);
//! # let local_model_mask_1 = local_mask_seed_1.derive_mask(number_weights, config.into());
//! # let local_model_mask_2 = local_mask_seed_2.derive_mask(number_weights, config.into());
Expand Down Expand Up @@ -154,13 +154,13 @@
//! configuration without possible loss of information.
//!
//! ```no_run
//! # use xaynet_core::mask::{Aggregation, BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, MaskObject, Model, ModelType};
//! # use xaynet_core::mask::{Aggregation, BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, MaskObject, Model, ModelType, Scalar};
//! # let number_weights = 10;
//! # let scalar = 0.5;
//! # let scalar = Scalar::new(1, 2_u8);
//! # let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter());
//! # let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter());
//! # let config = MaskConfig { group_type: GroupType::Prime, data_type: DataType::F32, bound_type: BoundType::B0, model_type: ModelType::M3};
//! # let (local_mask_seed_1, masked_local_model_1) = Masker::new(config.into()).mask(scalar, &local_model_1);
//! # let (local_mask_seed_1, masked_local_model_1) = Masker::new(config.into()).mask(scalar.clone(), &local_model_1);
//! # let (local_mask_seed_2, masked_local_model_2) = Masker::new(config.into()).mask(scalar, &local_model_2);
//! # let local_model_mask_1 = local_mask_seed_1.derive_mask(number_weights, config.into());
//! # let local_model_mask_2 = local_mask_seed_2.derive_mask(number_weights, config.into());
Expand All @@ -185,6 +185,7 @@ pub(crate) mod config;
pub(crate) mod masking;
pub(crate) mod model;
pub(crate) mod object;
pub(crate) mod scalar;
pub(crate) mod seed;

pub use self::{
Expand All @@ -207,5 +208,6 @@ pub use self::{
MaskUnit,
MaskVect,
},
scalar::{FromPrimitive, IntoPrimitive, Scalar, ScalarCastError},
seed::{EncryptedMaskSeed, MaskSeed},
};
10 changes: 5 additions & 5 deletions rust/xaynet-core/src/mask/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl IntoIterator for Model {

#[derive(Debug, Display)]
/// A primitive data type as a target for model conversion.
enum PrimitiveType {
pub(crate) enum PrimitiveType {
F32,
F64,
I32,
Expand All @@ -81,10 +81,10 @@ pub struct ModelCastError {
target: PrimitiveType,
}

#[derive(Error, Debug)]
#[error("Could not convert primitive type {0:?} to model weight")]
/// Errors related to model conversion from primitives.
pub struct PrimitiveCastError<P: Debug>(P);
#[derive(Clone, Error, Debug)]
#[error("Could not convert primitive type {0:?} to weight")]
/// Errors related to weight conversion from primitives.
pub struct PrimitiveCastError<P: Debug>(pub(crate) P);

/// An interface to convert a collection of numerical values into an iterator of primitive values.
///
Expand Down
Loading