From 8600b574c91a01b65396b6771d623d1d9a81b757 Mon Sep 17 00:00:00 2001 From: vaaaaanquish <6syun9@gmail.com> Date: Fri, 12 Feb 2021 01:57:45 +0900 Subject: [PATCH 1/4] add get feature importance --- examples/binary_classification/src/main.rs | 8 +++- src/booster.rs | 46 ++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/examples/binary_classification/src/main.rs b/examples/binary_classification/src/main.rs index f08dfa3..82260ca 100644 --- a/examples/binary_classification/src/main.rs +++ b/examples/binary_classification/src/main.rs @@ -53,6 +53,12 @@ fn main() -> std::io::Result<()> { } println!("{}, {}", label, pred) } - println!("{} / {}", &tp, result[0].len()); + println!("feature importance"); + let feature_name = booster.feature_name().unwrap(); + let feature_importance = booster.feature_importance().unwrap(); + for (feature, importance) in zip(&feature_name, &feature_importance) { + println!("{}: {}", feature, importance); + } + println!("result: {} / {}", &tp, result[0].len()); Ok(()) } diff --git a/src/booster.rs b/src/booster.rs index 06e80ce..80a0756 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -147,6 +147,52 @@ impl Booster { Ok(reshaped_output) } + /// Get Feature Num. + pub fn num_feature(&self) -> Result<i32> { + let mut out_len = 0; + lgbm_call!(lightgbm_sys::LGBM_BoosterGetNumFeature( + self.handle, + &mut out_len + ))?; + Ok(out_len) + } + + /// Get Feature Names. 
+ pub fn feature_name(&self) -> Result<Vec<String>> { + let num_feature = self.num_feature().unwrap(); + let feature_name_length = 32; + let mut num_feature_names = 0; + let mut out_buffer_len = 0; + let out_strs = (0..num_feature) + .map(|_| CString::new(" ".repeat(feature_name_length)).unwrap().into_raw() as *mut c_char) + .collect::<Vec<_>>(); + lgbm_call!(lightgbm_sys::LGBM_BoosterGetFeatureNames( + self.handle, + feature_name_length as i32, + &mut num_feature_names, + num_feature as u64, + &mut out_buffer_len, + out_strs.as_ptr() as *mut *mut c_char + ))?; + let output: Vec<String> = out_strs.into_iter() + .map(|s| unsafe{ CString::from_raw(s).into_string().unwrap() }) + .collect(); + Ok(output) + } + + // Get Feature Importance + pub fn feature_importance(&self) -> Result<Vec<f64>> { + let num_feature = self.num_feature().unwrap(); + let out_result: Vec<f64> = vec![Default::default(); num_feature as usize]; + lgbm_call!(lightgbm_sys::LGBM_BoosterFeatureImportance( + self.handle, + 0_i32, + 0_i32, + out_result.as_ptr() as *mut c_double + ))?; + Ok(out_result) + } + /// Save model to file. 
pub fn save_file(&self, filename: &str) -> Result<()> { let filename_str = CString::new(filename).unwrap(); From 52825a38f827dc5a9d80f754517bf43e4536bbf6 Mon Sep 17 00:00:00 2001 From: vaaaaanquish <6syun9@gmail.com> Date: Fri, 12 Feb 2021 02:02:32 +0900 Subject: [PATCH 2/4] cargo fmt --- src/booster.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/booster.rs b/src/booster.rs index 80a0756..595c717 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -164,7 +164,11 @@ impl Booster { let mut num_feature_names = 0; let mut out_buffer_len = 0; let out_strs = (0..num_feature) - .map(|_| CString::new(" ".repeat(feature_name_length)).unwrap().into_raw() as *mut c_char) + .map(|_| { + CString::new(" ".repeat(feature_name_length)) + .unwrap() + .into_raw() as *mut c_char + }) .collect::<Vec<_>>(); lgbm_call!(lightgbm_sys::LGBM_BoosterGetFeatureNames( self.handle, @@ -174,8 +178,9 @@ impl Booster { &mut out_buffer_len, out_strs.as_ptr() as *mut *mut c_char ))?; - let output: Vec<String> = out_strs.into_iter() - .map(|s| unsafe{ CString::from_raw(s).into_string().unwrap() }) + let output: Vec<String> = out_strs + .into_iter() + .map(|s| unsafe { CString::from_raw(s).into_string().unwrap() }) .collect(); Ok(output) } From cd4343d8f862110277bada0e42b80adb10d705ba Mon Sep 17 00:00:00 2001 From: vaaaaanquish <6syun9@gmail.com> Date: Fri, 12 Feb 2021 02:45:14 +0900 Subject: [PATCH 3/4] add test --- src/booster.rs | 60 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/src/booster.rs b/src/booster.rs index 595c717..6979138 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -225,13 +225,30 @@ mod tests { use std::fs; use std::path::Path; - fn read_train_file() -> Result<Dataset> { + fn _read_train_file() -> Result<Dataset> { Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train") } + fn _train_booster(params: &Value) -> Booster { + let dataset = _read_train_file().unwrap(); + let bst = 
Booster::train(dataset, &params).unwrap(); + bst + } + + fn _default_params() -> Value { + let params = json! { + { + "num_iterations": 1, + "objective": "binary", + "metric": "auc", + "data_random_seed": 0 + } + }; + params + } + #[test] fn predict() { - let dataset = read_train_file().unwrap(); let params = json! { { "num_iterations": 10, @@ -240,7 +257,7 @@ mod tests { "data_random_seed": 0 } }; - let bst = Booster::train(dataset, &params).unwrap(); + let bst = _train_booster(&params); let feature = vec![vec![0.5; 28], vec![0.0; 28], vec![0.9; 28]]; let result = bst.predict(feature).unwrap(); let mut normalized_result = Vec::new(); @@ -250,18 +267,35 @@ mod tests { assert_eq!(normalized_result, vec![0, 0, 1]); } + #[test] + fn num_feature(){ + let params = _default_params(); + let bst = _train_booster(&params); + let num_feature = bst.num_feature().unwrap(); + assert_eq!(num_feature, 28); + } + + #[test] + fn feature_importance() { + let params = _default_params(); + let bst = _train_booster(&params); + let feature_importance = bst.feature_importance().unwrap(); + assert_eq!(feature_importance, vec![0.0; 28]); + } + + #[test] + fn feature_name() { + let params = _default_params(); + let bst = _train_booster(&params); + let feature_name = bst.feature_name().unwrap(); + let target = (0..28).map(|i| format!("Column_{}", i)).collect::<Vec<_>>(); + assert_eq!(feature_name, target); + } + #[test] fn save_file() { - let dataset = read_train_file().unwrap(); - let params = json! 
{ - { - "num_iterations": 1, - "objective": "binary", - "metric": "auc", - "data_random_seed": 0 - } - }; - let bst = Booster::train(dataset, &params).unwrap(); + let params = _default_params(); + let bst = _train_booster(&params); assert_eq!(bst.save_file(&"./test/test_save_file.output"), Ok(())); assert!(Path::new("./test/test_save_file.output").exists()); let _ = fs::remove_file("./test/test_save_file.output"); From 7c59c1af30b174eac4f62015eee25079f5668277 Mon Sep 17 00:00:00 2001 From: vaaaaanquish <6syun9@gmail.com> Date: Fri, 12 Feb 2021 02:46:20 +0900 Subject: [PATCH 4/4] cargo fmt --- src/booster.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/booster.rs b/src/booster.rs index 6979138..ff86b42 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -268,7 +268,7 @@ mod tests { } #[test] - fn num_feature(){ + fn num_feature() { let params = _default_params(); let bst = _train_booster(&params); let num_feature = bst.num_feature().unwrap();