Skip to content

Commit

Permalink
Merge pull request #24 from vaaaaanquish/add_feature_importance
Browse files Browse the repository at this point in the history
Add get feature importance method
  • Loading branch information
vaaaaanquish committed Feb 11, 2021
2 parents 5b1c505 + 7c59c1a commit 6c11c59
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 14 deletions.
8 changes: 7 additions & 1 deletion examples/binary_classification/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ fn main() -> std::io::Result<()> {
}
println!("{}, {}", label, pred)
}
println!("{} / {}", &tp, result[0].len());
println!("feature importance");
let feature_name = booster.feature_name().unwrap();
let feature_importance = booster.feature_importance().unwrap();
for (feature, importance) in zip(&feature_name, &feature_importance) {
println!("{}: {}", feature, importance);
}
println!("result: {} / {}", &tp, result[0].len());
Ok(())
}
111 changes: 98 additions & 13 deletions src/booster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,57 @@ impl Booster {
Ok(reshaped_output)
}

/// Get Feature Num.
pub fn num_feature(&self) -> Result<i32> {
let mut out_len = 0;
lgbm_call!(lightgbm_sys::LGBM_BoosterGetNumFeature(
self.handle,
&mut out_len
))?;
Ok(out_len)
}

/// Get Feature Names.
pub fn feature_name(&self) -> Result<Vec<String>> {
let num_feature = self.num_feature().unwrap();
let feature_name_length = 32;
let mut num_feature_names = 0;
let mut out_buffer_len = 0;
let out_strs = (0..num_feature)
.map(|_| {
CString::new(" ".repeat(feature_name_length))
.unwrap()
.into_raw() as *mut c_char
})
.collect::<Vec<_>>();
lgbm_call!(lightgbm_sys::LGBM_BoosterGetFeatureNames(
self.handle,
feature_name_length as i32,
&mut num_feature_names,
num_feature as u64,
&mut out_buffer_len,
out_strs.as_ptr() as *mut *mut c_char
))?;
let output: Vec<String> = out_strs
.into_iter()
.map(|s| unsafe { CString::from_raw(s).into_string().unwrap() })
.collect();
Ok(output)
}

// Get Feature Importance
pub fn feature_importance(&self) -> Result<Vec<f64>> {
let num_feature = self.num_feature().unwrap();
let out_result: Vec<f64> = vec![Default::default(); num_feature as usize];
lgbm_call!(lightgbm_sys::LGBM_BoosterFeatureImportance(
self.handle,
0_i32,
0_i32,
out_result.as_ptr() as *mut c_double
))?;
Ok(out_result)
}

/// Save model to file.
pub fn save_file(&self, filename: &str) -> Result<()> {
let filename_str = CString::new(filename).unwrap();
Expand Down Expand Up @@ -174,13 +225,30 @@ mod tests {
use std::fs;
use std::path::Path;

fn read_train_file() -> Result<Dataset> {
fn _read_train_file() -> Result<Dataset> {
Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train")
}

fn _train_booster(params: &Value) -> Booster {
let dataset = _read_train_file().unwrap();
let bst = Booster::train(dataset, &params).unwrap();
bst
}

fn _default_params() -> Value {
let params = json! {
{
"num_iterations": 1,
"objective": "binary",
"metric": "auc",
"data_random_seed": 0
}
};
params
}

#[test]
fn predict() {
let dataset = read_train_file().unwrap();
let params = json! {
{
"num_iterations": 10,
Expand All @@ -189,7 +257,7 @@ mod tests {
"data_random_seed": 0
}
};
let bst = Booster::train(dataset, &params).unwrap();
let bst = _train_booster(&params);
let feature = vec![vec![0.5; 28], vec![0.0; 28], vec![0.9; 28]];
let result = bst.predict(feature).unwrap();
let mut normalized_result = Vec::new();
Expand All @@ -199,18 +267,35 @@ mod tests {
assert_eq!(normalized_result, vec![0, 0, 1]);
}

#[test]
fn num_feature() {
let params = _default_params();
let bst = _train_booster(&params);
let num_feature = bst.num_feature().unwrap();
assert_eq!(num_feature, 28);
}

#[test]
fn feature_importance() {
let params = _default_params();
let bst = _train_booster(&params);
let feature_importance = bst.feature_importance().unwrap();
assert_eq!(feature_importance, vec![0.0; 28]);
}

#[test]
fn feature_name() {
let params = _default_params();
let bst = _train_booster(&params);
let feature_name = bst.feature_name().unwrap();
let target = (0..28).map(|i| format!("Column_{}", i)).collect::<Vec<_>>();
assert_eq!(feature_name, target);
}

#[test]
fn save_file() {
let dataset = read_train_file().unwrap();
let params = json! {
{
"num_iterations": 1,
"objective": "binary",
"metric": "auc",
"data_random_seed": 0
}
};
let bst = Booster::train(dataset, &params).unwrap();
let params = _default_params();
let bst = _train_booster(&params);
assert_eq!(bst.save_file(&"./test/test_save_file.output"), Ok(()));
assert!(Path::new("./test/test_save_file.output").exists());
let _ = fs::remove_file("./test/test_save_file.output");
Expand Down

0 comments on commit 6c11c59

Please sign in to comment.