Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Oracle::variability #165

Merged
merged 3 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 134 additions & 182 deletions cli/Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions lace/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion lace/lace_consts/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ repository = "https://github.com/promised-ai/lace"
description = "Default constants for Lace"

[dependencies]
rv = { version = "0.16.2", features = ["serde1", "arraydist"] }
rv = { version = "0.16.3", features = ["serde1", "arraydist"] }
2 changes: 1 addition & 1 deletion lace/src/interface/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub use oracle::utils;

pub use oracle::{
ConditionalEntropyType, DatalessOracle, MiComponents, MiType, Oracle,
OracleT, RowSimilarityVariant,
OracleT, RowSimilarityVariant, Variability,
};

pub use given::Given;
Expand Down
15 changes: 13 additions & 2 deletions lace/src/interface/oracle/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,17 @@ pub enum PredictError {
GivenError(#[from] GivenError),
}

/// Describes errors that can occur from bad inputs to `Oracle::variability`
#[derive(Debug, Clone, PartialEq, Error)]
pub enum VariabilityError {
/// The target column index is out of bounds
#[error("Target index error in predict query: {0}")]
IndexError(#[from] IndexError),
/// The Given is invalid
#[error("Invalid predict 'given' argument: {0}")]
GivenError(#[from] GivenError),
}

/// Describes errors that arise from invalid predict uncertainty arguments
#[derive(Debug, Clone, PartialEq, Error)]
pub enum PredictUncertaintyError {
Expand All @@ -192,7 +203,7 @@ pub enum PredictUncertaintyError {

/// Describes errors from incompatible `col_max_logp` caches
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum ColumnMaxiumLogPError {
pub enum ColumnMaximumLogPError {
/// The state indices used to compute the cache do not match those passed to the function.
#[error("The state indices used to compute the cache do not match those passed to the function.")]
InvalidStateIndices,
Expand Down Expand Up @@ -247,7 +258,7 @@ pub enum LogpError {
#[error("Invalid logp 'given' argument: {0}")]
GivenError(#[from] GivenError),
#[error("Invalid `col_max_logps` argument: {0}")]
ColumnMaxiumLogPError(#[from] ColumnMaxiumLogPError),
ColumnMaximumLogPError(#[from] ColumnMaximumLogPError),
}

/// Describes errors from bad inputs to Oracle::simulate
Expand Down
2 changes: 1 addition & 1 deletion lace/src/interface/oracle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub mod utils;
mod validation;

pub use dataless::DatalessOracle;
pub use traits::OracleT;
pub use traits::{OracleT, Variability};

use std::path::Path;

Expand Down
139 changes: 139 additions & 0 deletions lace/src/interface/oracle/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use lace_stats::rv::traits::Rv;
use lace_stats::SampleError;
use rand::Rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;

macro_rules! col_indices_ok {
Expand All @@ -41,6 +42,25 @@ macro_rules! state_indices_ok {
}}
}

/// Represents different formalizations of variability in distributions
///
/// This is the success type returned by `OracleT::variability`. Each variant
/// carries the computed value as an `f64` and records which measure it is;
/// use `f64::from` to extract the number when the measure does not matter.
///
/// With `rename_all = "snake_case"`, serde uses the variant names `variance`
/// and `entropy` when (de)serializing.
// NOTE: the derived `PartialOrd` depends on the declaration order of the
// variants, so do not reorder them casually.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Variability {
/// The variance of a univariate distribution
Variance(f64),
/// The entropy of a distribution
Entropy(f64),
}

impl From<Variability> for f64 {
    /// Extracts the wrapped number, discarding which measure it represents.
    fn from(value: Variability) -> Self {
        // Every variant holds exactly one `f64`, so a single or-pattern arm
        // covers them all while keeping the match exhaustive.
        match value {
            Variability::Variance(inner) | Variability::Entropy(inner) => inner,
        }
    }
}

pub trait OracleT: CanOracle {
/// Returns the diagnostics for each state
fn state_diagnostics(&self) -> Vec<StateDiagnostics> {
Expand Down Expand Up @@ -2046,6 +2066,125 @@ pub trait OracleT: CanOracle {
}
}

/// Compute the variability of a conditional distribution
///
/// # Notes
/// - Returns variance for Continuous and Count columns
/// - Returns Entropy for Categorical columns
///
/// # Arguments
/// - col_ix: the index of the column for which to compute the variability
/// - given: optional observations by which to constrain the prediction
/// - state_ixs_opt: Optional vector of state indices from which to compute,
///   if None, use all states.
///
/// # Errors
/// - Returns an error if `col_ix` cannot be converted to a column index
///   using the codebook.
/// - Returns `VariabilityError::GivenError` if `given` cannot be put into
///   canonical form using the codebook.
///
/// # Panics
/// - If an index in `state_ixs_opt` is out of bounds (states are indexed
///   directly), or if the selection of states is empty.
/// - If the column's mixture type is not Gaussian, Poisson, or Categorical.
/// - If the selected states do not all produce the same mixture type for
///   the column.
fn variability<Ix: ColumnIndex, GIx: ColumnIndex>(
    &self,
    col_ix: Ix,
    given: &Given<GIx>,
    state_ixs_opt: Option<&[usize]>,
) -> Result<Variability, error::VariabilityError> {
    use crate::stats::rv::traits::{Entropy, Variance};
    use crate::stats::MixtureType;

    // Select the requested states, or every state when no subset is given.
    let states: Vec<&State> = match state_ixs_opt {
        Some(state_ixs) => {
            state_ixs.iter().map(|&ix| &self.states()[ix]).collect()
        }
        None => self.states().iter().collect(),
    };

    // `canonical` consumes its receiver, hence the clone of the borrowed
    // `given`.
    let given =
        given.clone().canonical(self.codebook()).map_err(|err| {
            error::VariabilityError::GivenError(
                error::GivenError::IndexError(err),
            )
        })?;

    let col_ix = col_ix.col_ix(self.codebook())?;

    // Build one mixture per state for the target column
    let mixture_types: Vec<MixtureType> = states
        .iter()
        .map(|state| {
            let view_ix = state.asgn.asgn[col_ix];
            let weights =
                &utils::given_weights(&[state], &[col_ix], &given)[0];

            // combine the state weights with the given weights
            let mut mm_weights: Vec<f64> = state.views[view_ix]
                .weights
                .iter()
                .zip(weights[&view_ix].iter())
                .map(|(&w1, &w2)| w1 + w2)
                .collect();

            // Subtract the logsumexp and exponentiate.
            // NOTE(review): this treats the summed values as log-weights;
            // confirm that both the view weights and `given_weights` are on
            // the log scale.
            let z: f64 = logsumexp(&mm_weights);
            mm_weights.iter_mut().for_each(|w| *w = (*w - z).exp());

            state.views[view_ix].ftrs[&col_ix].to_mixture(mm_weights)
        })
        .collect();

    // The first state's mixture decides which measure is computed; every
    // other state must produce the same mixture type.
    match mixture_types[0] {
        MixtureType::Gaussian(_) => {
            let mms: Vec<_> = mixture_types
                .into_iter()
                .map(|mt| match mt {
                    MixtureType::Gaussian(mm) => mm,
                    _ => panic!("Expected Gaussian Mixture Type"),
                })
                .collect();
            let mm = Mixture::combine(mms);
            Ok(Variability::Variance(mm.variance().unwrap()))
        }
        MixtureType::Poisson(_) => {
            let mms: Vec<_> = mixture_types
                .into_iter()
                .map(|mt| match mt {
                    MixtureType::Poisson(mm) => mm,
                    _ => panic!("Expected Poisson Mixture Type"),
                })
                .collect();
            let mm = Mixture::combine(mms);
            Ok(Variability::Variance(mm.variance().unwrap()))
        }
        MixtureType::Categorical(_) => {
            let mms: Vec<_> = mixture_types
                .into_iter()
                .map(|mt| match mt {
                    MixtureType::Categorical(mm) => mm,
                    _ => panic!("Expected Categorical Mixture Type"),
                })
                .collect();
            let mm = Mixture::combine(mms);
            Ok(Variability::Entropy(mm.entropy()))
        }
        _ => panic!("Unsupported MType"),
    }
}

/// Compute the error between the observed data in a feature and the feature
/// model.
///
Expand Down
2 changes: 2 additions & 0 deletions lace/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pub use crate::{
RowSimilarityVariant, SupportExtension, Value, WriteMode,
};

pub use crate::interface::Variability;

pub use crate::data::DataSource;

pub use lace_cc::{
Expand Down
4 changes: 2 additions & 2 deletions pylace/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions pylace/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,23 @@ engine.update(10_000)
engine.predict('Class_of_Orbit', given={'Period_minutes': 1436.0})
# ('GEO', 0.13583714831550336)
```

## Tests

To run tests, use `pytest`

```console
$ pytest -x
```

To run doctests:

```console
$ python tests/test_docs.py
Swandog marked this conversation as resolved.
Show resolved Hide resolved
```

To prevent plotly from displaying plots while the doctests run:

```console
$ LACE_DOCTEST_NOPLOT=1 python tests/test_docs.py
```
30 changes: 12 additions & 18 deletions pylace/lace/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def held_out_neglogp(
│ ["Apogee_km"] ┆ 5.106627 ┆ 1 │
│ ["Apogee_km", "Eccentricity"] ┆ 2.951662 ┆ 2 │
│ ["Apogee_km", "Country_of_Operat… ┆ 2.951254 ┆ 3 │
… ┆ … ┆ …
["Apogee_km", "Country_of_Operat… ┆ 2.952801 ┆ 4
│ ["Apogee_km", "Country_of_Contra… ┆ 2.956224 ┆ 5 │
│ ["Apogee_km", "Country_of_Contra… ┆ 2.96479 ┆ 6 │
│ ["Apogee_km", "Country_of_Contra… ┆ 2.992173 ┆ 7 │
Expand Down Expand Up @@ -415,7 +415,7 @@ def held_out_inconsistency(
│ ["Apogee_km"] ┆ 1.290609 ┆ 1 │
│ ["Apogee_km", "Eccentricity"] ┆ 0.74598 ┆ 2 │
│ ["Apogee_km", "Country_of_Operat… ┆ 0.745877 ┆ 3 │
… ┆ … ┆ …
["Apogee_km", "Country_of_Operat… ┆ 0.746268 ┆ 4
│ ["Apogee_km", "Country_of_Contra… ┆ 0.747133 ┆ 5 │
│ ["Apogee_km", "Country_of_Contra… ┆ 0.749297 ┆ 6 │
│ ["Apogee_km", "Country_of_Contra… ┆ 0.756218 ┆ 7 │
Expand Down Expand Up @@ -525,7 +525,7 @@ def held_out_uncertainty(
│ ["Expected_Lifetime"] ┆ 0.437647 ┆ 1 │
│ ["Apogee_km", "Eccentricity"] ┆ 0.05561 ┆ 2 │
│ ["Apogee_km", "Country_of_Operat… ┆ 0.055283 ┆ 3 │
… ┆ … ┆ …
["Apogee_km", "Country_of_Operat… ┆ 0.056185 ┆ 4
│ ["Apogee_km", "Country_of_Operat… ┆ 0.057624 ┆ 5 │
│ ["Apogee_km", "Country_of_Contra… ┆ 0.0595 ┆ 6 │
│ ["Apogee_km", "Country_of_Contra… ┆ 0.077359 ┆ 7 │
Expand Down Expand Up @@ -945,15 +945,15 @@ def explain_prediction(
│ --- ┆ --- │
│ str ┆ f64 │
╞══════════════════════════════╪═════════════╡
│ Country_of_Operator ┆ 3.5216e-16 │
│ Users ┆ -3.1668e-14
│ Purpose ┆ -9.5636e-14
│ Class_of_Orbit ┆ -1.8263e-15 │
│ Country_of_Operator ┆ 2.4617e-16 │
│ Users ┆ -2.1412e-15
│ Purpose ┆ -8.0193e-15
│ Class_of_Orbit ┆ -2.2727e-15 │
│ … ┆ … │
│ Launch_Site ┆ -2.8416e-15
│ Launch_Vehicle ┆ 1.0704e-14
│ Source_Used_for_Orbital_Data ┆ -3.9301e-15 │
│ Inclination_radians ┆ -9.6259e-15 │
│ Launch_Site ┆ -5.8214e-16
│ Launch_Vehicle ┆ -9.6101e-16
│ Source_Used_for_Orbital_Data ┆ -9.1997e-15 │
│ Inclination_radians ┆ -1.5407e-15 │
└──────────────────────────────┴─────────────┘

Get the importances using the 'ablative-dist' method, which measures how
Expand All @@ -975,7 +975,7 @@ def explain_prediction(
│ Country_of_Operator ┆ -0.000109 │
│ Users ┆ 0.081289 │
│ Purpose ┆ 0.18938 │
│ Class_of_Orbit ┆ 0.000133
│ Class_of_Orbit ┆ 0.000119
│ … ┆ … │
│ Launch_Site ┆ 0.003411 │
│ Launch_Vehicle ┆ -0.018817 │
Expand All @@ -994,9 +994,3 @@ def explain_prediction(
raise ValueError(
f"Invalid method `{method}`, valid methods are {PRED_EXPLAIN_METHODS}"
)


if __name__ == "__main__":
import doctest

doctest.testmod()
Loading
Loading