Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

generalised scalar extension #496

Merged
merged 7 commits into from
Aug 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions rust/src/client/mobile_client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,8 @@ impl ClientState<Update> {
.ok_or(ClientError::TooEarly("local model"))?
.clone();

debug!("polling for model scalar");
let scalar = self
.proxy
.get_scalar()
.await?
.ok_or(ClientError::TooEarly("scalar"))?;
debug!("setting model scalar");
let scalar = 1_f64; // TODO parametrise this!
finiteprods marked this conversation as resolved.
Show resolved Hide resolved

debug!("polling for sum dict");
let sums = self
Expand Down
26 changes: 17 additions & 9 deletions rust/src/client/mobile_client/participant/sum2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ impl Participant<Sum2> {
mask_len: usize,
) -> Result<MessageOwned, PetError> {
let mask_seeds = self.get_seeds(seed_dict)?;
let mask = self.compute_global_mask(mask_seeds, mask_len)?;
let (model_mask, scalar_mask) = self.compute_global_mask(mask_seeds, mask_len)?;
let payload = Sum2Owned {
mask,
sum_signature: self.inner.sum_signature,
model_mask,
scalar_mask,
};

Ok(MessageOwned::new_sum2(pk, self.state.keys.public, payload))
Expand All @@ -73,20 +74,27 @@ impl Participant<Sum2> {
&self,
mask_seeds: Vec<MaskSeed>,
mask_len: usize,
) -> Result<MaskObject, PetError> {
) -> Result<(MaskObject, MaskObject), PetError> {
if mask_seeds.is_empty() {
return Err(PetError::InvalidMask);
}

let mut aggregation = Aggregation::new(self.state.mask_config, mask_len);
let mut model_mask_agg = Aggregation::new(self.state.mask_config, mask_len);
let mut scalar_mask_agg = Aggregation::new(self.state.mask_config, 1);
for seed in mask_seeds.into_iter() {
let mask = seed.derive_mask(mask_len, self.state.mask_config);
aggregation
.validate_aggregation(&mask)
let (model_mask, scalar_mask) = seed.derive_mask(mask_len, self.state.mask_config);

model_mask_agg
.validate_aggregation(&model_mask)
.map_err(|_| PetError::InvalidMask)?;
scalar_mask_agg
.validate_aggregation(&scalar_mask)
.map_err(|_| PetError::InvalidMask)?;
aggregation.aggregate(mask);

model_mask_agg.aggregate(model_mask);
scalar_mask_agg.aggregate(scalar_mask);
}
Ok(aggregation.into())
Ok((model_mask_agg.into(), scalar_mask_agg.into()))
}
}

Expand Down
5 changes: 3 additions & 2 deletions rust/src/client/mobile_client/participant/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,22 @@ impl Participant<Update> {
scalar: f64,
local_model: Model,
) -> MessageOwned {
let (mask_seed, masked_model) = self.mask_model(scalar, local_model);
let (mask_seed, masked_model, masked_scalar) = self.mask_model(scalar, local_model);
let local_seed_dict = Self::create_local_seed_dict(sum_dict, &mask_seed);

let payload = UpdateOwned {
sum_signature: self.inner.sum_signature,
update_signature: self.inner.update_signature,
masked_model,
masked_scalar,
local_seed_dict,
};

MessageOwned::new_update(pk, self.state.keys.public, payload)
}

/// Generate a mask seed and mask a local model.
fn mask_model(&self, scalar: f64, local_model: Model) -> (MaskSeed, MaskObject) {
fn mask_model(&self, scalar: f64, local_model: Model) -> (MaskSeed, MaskObject, MaskObject) {
Masker::new(self.state.mask_config).mask(scalar, local_model)
}

Expand Down
14 changes: 4 additions & 10 deletions rust/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,10 @@ pub struct Client {
pub(crate) cached_model: Option<CachedModel>,
pub(crate) has_new_global_model_since_last_check: bool,
pub(crate) has_new_global_model_since_last_cache: bool,

// TEMP pub visibility to allow access from test-drive
pub local_model: Option<Model>,
pub scalar: f64,

/// Identifier for this client
id: u32,
Expand All @@ -176,6 +178,7 @@ impl Default for Client {
has_new_global_model_since_last_check: false,
has_new_global_model_since_last_cache: false,
local_model: None,
scalar: 1.0,
id: 0,
proxy: Proxy::new_remote("http://127.0.0.1:3030"),
}
Expand Down Expand Up @@ -365,23 +368,14 @@ impl Client {
self.interval.tick().await;
};

debug!(client_id = %self.id, "polling for model scalar");
let scalar = loop {
if let Some(scalar) = self.proxy.get_scalar().await? {
break scalar;
}
trace!(client_id = %self.id, "model scalar not ready, retrying.");
self.interval.tick().await;
};

debug!(client_id = %self.id, "polling for sum dict");
loop {
if let Some(sums) = self.proxy.get_sums().await? {
debug!(client_id = %self.id, "sum dict received, sending update message.");
let msg = self.participant.compose_update_message(
self.coordinator_pk,
&sums,
scalar,
self.scalar,
model,
);
let sealed_msg = self.participant.seal_message(&self.coordinator_pk, &msg);
Expand Down
32 changes: 21 additions & 11 deletions rust/src/client/participant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,14 @@ impl Participant {
scalar: f64,
local_model: Model,
) -> MessageOwned {
let (mask_seed, masked_model) = Self::mask_model(scalar, local_model);
let (mask_seed, masked_model, masked_scalar) = Self::mask_model(scalar, local_model);
let local_seed_dict = Self::create_local_seed_dict(sum_dict, &mask_seed);

let payload = UpdateOwned {
sum_signature: self.sum_signature,
update_signature: self.update_signature,
masked_model,
masked_scalar,
local_seed_dict,
};

Expand All @@ -169,10 +170,12 @@ impl Participant {
mask_len: usize,
) -> Result<MessageOwned, PetError> {
let mask_seeds = self.get_seeds(seed_dict)?;
let mask = self.compute_global_mask(mask_seeds, mask_len, dummy_config())?;
let (model_mask, scalar_mask) =
self.compute_global_mask(mask_seeds, mask_len, dummy_config())?;
let payload = Sum2Owned {
mask,
sum_signature: self.sum_signature,
model_mask,
scalar_mask,
};

Ok(MessageOwned::new_sum2(pk, self.pk, payload))
Expand All @@ -196,7 +199,7 @@ impl Participant {
}

/// Generate a mask seed and mask a local model.
fn mask_model(scalar: f64, local_model: Model) -> (MaskSeed, MaskObject) {
fn mask_model(scalar: f64, local_model: Model) -> (MaskSeed, MaskObject, MaskObject) {
// TODO: use proper config
Masker::new(dummy_config()).mask(scalar, local_model)
}
Expand All @@ -223,20 +226,27 @@ impl Participant {
mask_seeds: Vec<MaskSeed>,
mask_len: usize,
mask_config: MaskConfig,
) -> Result<MaskObject, PetError> {
) -> Result<(MaskObject, MaskObject), PetError> {
if mask_seeds.is_empty() {
return Err(PetError::InvalidMask);
}

let mut aggregation = Aggregation::new(mask_config, mask_len);
let mut model_mask_agg = Aggregation::new(mask_config, mask_len);
let mut scalar_mask_agg = Aggregation::new(mask_config, 1);
for seed in mask_seeds.into_iter() {
let mask = seed.derive_mask(mask_len, mask_config);
aggregation
.validate_aggregation(&mask)
let (model_mask, scalar_mask) = seed.derive_mask(mask_len, mask_config);

model_mask_agg
.validate_aggregation(&model_mask)
.map_err(|_| PetError::InvalidMask)?;
scalar_mask_agg
.validate_aggregation(&scalar_mask)
.map_err(|_| PetError::InvalidMask)?;
aggregation.aggregate(mask);

model_mask_agg.aggregate(model_mask);
scalar_mask_agg.aggregate(scalar_mask);
}
Ok(aggregation.into())
Ok((model_mask_agg.into(), scalar_mask_agg.into()))
}
}

Expand Down
34 changes: 0 additions & 34 deletions rust/src/client/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,35 +127,6 @@ impl Proxy {
}
}

/// Get the model scalar data from the service proxy.
///
/// Returns `Ok(Some(data))` if the `data` is available on the
/// service, `Ok(None)` if it is not.
///
/// # Errors
/// * Returns `Fetch` if an error occurs fetching from the in-memory proxy.
/// * Returns `NetworkErr` if a network error occurs while getting the data.
/// * Returns `ParseErr` if an error occurs while parsing the response.
pub async fn get_scalar(&mut self) -> Result<Option<f64>, ClientError> {
match self {
InMem(ref mut hdl, _) => hdl.scalar().await.map_err(ClientError::Fetch),
Remote(req) => {
let opt_text = req.get_scalar().await.map_err(|e| {
error!("failed to GET model scalar: {}", e);
ClientError::NetworkErr(e)
})?;
opt_text
.map(|text| {
text.parse().map_err(|e| {
error!("failed to parse model scalar: {}: {:?}", e, text);
ClientError::ParseErr
})
})
.transpose()
}
}
}

/// Get the seed dictionary data from the service proxy.
///
/// Returns `Ok(Some(data))` if the `data` is available on the
Expand Down Expand Up @@ -300,11 +271,6 @@ impl ClientReq {
self.simple_get_bytes(&url).await
}

async fn get_scalar(&self) -> Result<Option<String>, Error> {
let url = format!("{}/scalar", self.address);
self.simple_get_text(&url).await
}

async fn get_seeds(&self, pk: SumParticipantPublicKey) -> Result<Option<Bytes>, Error> {
let url = format!("{}/seeds", self.address);
let response = self
Expand Down
66 changes: 54 additions & 12 deletions rust/src/mask/masking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ impl Aggregation {
.collect()
}

/// Applies a correction to the given unmasked model based on the associated
/// unmasked scalar sum, in order to scale it correctly.
///
/// This should be called after [`unmask()`] is called for both the model
/// and scalar aggregations.
pub(crate) fn correct(overscaled: Model, scalar_sum: Model) -> Model {
finiteprods marked this conversation as resolved.
Show resolved Hide resolved
// FIXME later on, tidy up API so that scalar_sum is encapsulated away
let correction = scalar_sum.into_iter().next().unwrap();
overscaled
.into_iter()
.map(|weight| weight / &correction)
.collect()
}

/// Validates if aggregation of the aggregated mask object with the given `object` may be safely
/// performed.
///
Expand Down Expand Up @@ -213,7 +227,7 @@ impl Aggregation {
return Err(AggregationError::ModelMismatch);
}

if self.nb_models == self.object.config.model_type.max_nb_models() {
if self.nb_models >= self.object.config.model_type.max_nb_models() {
finiteprods marked this conversation as resolved.
Show resolved Hide resolved
return Err(AggregationError::TooManyModels);
}

Expand Down Expand Up @@ -287,32 +301,46 @@ impl Masker {
/// proceeds in reverse order.
///
/// [`unmask()`]: struct.Aggregation.html#method.unmask
pub fn mask(self, scalar: f64, model: Model) -> (MaskSeed, MaskObject) {
let random_ints = self.random_ints();
pub fn mask(self, scalar: f64, model: Model) -> (MaskSeed, MaskObject, MaskObject) {
let mut random_ints = self.random_ints();
let Self { seed, config } = self;

let exp_shift = config.exp_shift();
let add_shift = config.add_shift();
let order = config.order();
let higher_bound = &add_shift;
let lower_bound = -&add_shift;
let scalar = Ratio::<BigInt>::from_float(clamp(scalar, 0_f64, 1_f64)).unwrap();

let scalar_ratio = crate::mask::model::float_to_ratio_bounded(scalar);
// HACK reuse upper bound for scaled weights for now, really should be tighter
// TODO give scalar its own config in later refactoring
let zero = Ratio::<BigInt>::from_float(0_f64).unwrap();
let scalar_clamped = clamp(&scalar_ratio, &zero, higher_bound);

let masked_weights = model
.into_iter()
.zip(random_ints)
.zip(&mut random_ints)
.map(|(weight, rand_int)| {
let scaled = &scalar * clamp(&weight, &lower_bound, higher_bound);
let scaled = scalar_clamped * &weight;
let scaled_clamped = clamp(&scaled, &lower_bound, higher_bound);
// PANIC_SAFE: shifted weight is guaranteed to be non-negative
let shifted = ((scaled + &add_shift) * &exp_shift)
let shifted = ((scaled_clamped + &add_shift) * &exp_shift)
.to_integer()
.to_biguint()
.unwrap();
(shifted + rand_int) % &order
})
.collect();
let masked_model = MaskObject::new(config, masked_weights);
(seed, masked_model)

let rand_int = random_ints.next().unwrap();
let shifted = ((scalar_clamped + &add_shift) * &exp_shift)
.to_integer()
.to_biguint()
.unwrap();
let masked_scalar = MaskObject::new(config, vec![(shifted + rand_int) % &order]);

(seed, masked_model, masked_scalar)
}

/// Creates an iterator that yields randomly generated integers wrt the masking configuration.
Expand Down Expand Up @@ -395,11 +423,14 @@ mod tests {
// a. mask the model
// b. derive the mask corresponding to the seed used
// c. unmask the model and check it against the original one.
let (mask_seed, masked_model) = Masker::new(config.clone()).mask(1_f64, model.clone());
let (mask_seed, masked_model, masked_scalar) =
Masker::new(config.clone()).mask(1_f64, model.clone());
assert_eq!(masked_model.data.len(), model.len());
assert!(masked_model.is_valid());
assert_eq!(masked_scalar.data.len(), 1);
assert!(masked_scalar.is_valid());

let mask = mask_seed.derive_mask(model.len(), config);
let (mask, _scalar_mask) = mask_seed.derive_mask(model.len(), config);
let aggregation = Aggregation::from(masked_model);
let unmasked_model = aggregation.unmask(mask);

Expand Down Expand Up @@ -683,6 +714,8 @@ mod tests {
.unwrap();
let mut aggregated_masked_model = Aggregation::new(config, model_size);
let mut aggregated_mask = Aggregation::new(config, model_size);
let mut aggregated_masked_scalar = Aggregation::new(config, 1);
let mut aggregated_scalar_mask = Aggregation::new(config, 1);
let scalar = 1_f64 / ($count as f64);
let scalar_ratio = Ratio::from_float(scalar).unwrap();
for _ in 0..$count as usize {
Expand All @@ -694,15 +727,23 @@ mod tests {
*averaged_weight += &scalar_ratio * weight;
});

let (mask_seed, masked_model) = Masker::new(config).mask(scalar, model);
let mask = mask_seed.derive_mask($len as usize, config);
let (mask_seed, masked_model, masked_scalar) =
Masker::new(config).mask(scalar, model);
let (mask, scalar_mask) = mask_seed.derive_mask($len as usize, config);

assert!(
aggregated_masked_model.validate_aggregation(&masked_model).is_ok()
);
aggregated_masked_model.aggregate(masked_model);
assert!(aggregated_mask.validate_aggregation(&mask).is_ok());
aggregated_mask.aggregate(mask);

assert!(
aggregated_masked_scalar.validate_aggregation(&masked_scalar).is_ok()
);
aggregated_masked_scalar.aggregate(masked_scalar);
assert!(aggregated_scalar_mask.validate_aggregation(&scalar_mask).is_ok());
aggregated_scalar_mask.aggregate(scalar_mask);
}

let unmasked_model = aggregated_masked_model.unmask(aggregated_mask.into());
Expand All @@ -715,6 +756,7 @@ mod tests {
(averaged_weight - unmasked_weight).abs() <= tolerance
})
);
// TODO check scalar as well, after future refactoring
}
}
};
Expand Down
Loading