Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove normalization of confidence scores in intent classification #130

Merged
merged 2 commits into from
Apr 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file.
## [Unreleased]
### Fixed
- Fix handling of ambiguous utterances in `DeterministicIntentParser`
- Stop normalizing confidence scores when there is an intents filter

## [0.64.1] - 2019-03-01
### Fixed
Expand Down
49 changes: 2 additions & 47 deletions src/intent_classifier/log_reg_intent_classifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ impl LogRegIntentClassifier {
let logreg = self.logreg.as_ref().unwrap(); // Checked above

let features = featurizer.transform(input)?;
let filtered_out_indexes =
get_filtered_out_intents_indexes(&self.intent_list, opt_intents_set.as_ref());
let scores = logreg.run(&features.view(), filtered_out_indexes)?;
let scores = logreg.run(&features.view())?;

Ok(self
.intent_list
Expand Down Expand Up @@ -162,34 +160,11 @@ impl LogRegIntentClassifier {
}
}

fn get_filtered_out_intents_indexes(
intents_list: &[Option<IntentName>],
intents_filter: Option<&HashSet<&str>>,
) -> Option<Vec<usize>> {
intents_filter.map(|filter| {
intents_list
.iter()
.enumerate()
.filter_map(|(i, opt_intent)| {
if let Some(intent) = opt_intent {
if !filter.contains(&**intent) {
Some(i)
} else {
None
}
} else {
None
}
})
.collect()
})
}

#[cfg(test)]
mod tests {
use super::*;

use maplit::{hashmap, hashset};
use maplit::hashmap;
use ndarray::array;

use crate::intent_classifier::TfidfVectorizer;
Expand Down Expand Up @@ -437,24 +412,4 @@ mod tests {
assert_eq!(Some("MakeCoffee".to_string()), result2.intent_name);
assert_eq!(None, result3.intent_name);
}

#[test]
fn test_get_filtered_out_intents_indexes() {
// Given
let intents_list = vec![
Some("intent1".to_string()),
Some("intent2".to_string()),
Some("intent3".to_string()),
None,
];
let intents_filter = hashset!["intent1", "intent3"];

// When
let filtered_indexes =
get_filtered_out_intents_indexes(&intents_list, Some(&intents_filter));

// Then
let expected_indexes = Some(vec![1]);
assert_eq!(expected_indexes, filtered_indexes);
}
}
45 changes: 3 additions & 42 deletions src/intent_classifier/logreg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,7 @@ impl MulticlassLogisticRegression {
})
}

pub fn run(
&self,
features: &ArrayView1<f32>,
filtered_out_indexes: Option<Vec<usize>>,
) -> Result<Array1<f32>> {
pub fn run(&self, features: &ArrayView1<f32>) -> Result<Array1<f32>> {
let reshaped_features = features.into_shape((1, self.nb_features()))?;
let reshaped_features = stack![Axis(1), array![[1.]], reshaped_features];
let mut result = reshaped_features
Expand All @@ -53,15 +49,6 @@ impl MulticlassLogisticRegression {
if self.is_binary() {
return Ok(arr1(&[1.0 - result[0], result[0]]));
}
if let Some(indexes) = filtered_out_indexes {
if !indexes.is_empty() {
for index in indexes {
result[index] = 0.0;
}
let divider = result.scalar_sum();
result /= divider;
}
}
Ok(result)
}
}
Expand Down Expand Up @@ -91,7 +78,7 @@ mod tests {
let regression = MulticlassLogisticRegression::new(intercept, weights).unwrap();

// When
let predictions = regression.run(&features.view(), None).unwrap();
let predictions = regression.run(&features.view()).unwrap();

// Then
let expected_predictions = array![0.7109495, 0.3384968, 0.8710191];
Expand All @@ -108,36 +95,10 @@ mod tests {
let regression = MulticlassLogisticRegression::new(intercept, weights).unwrap();

// When
let predictions = regression.run(&features.view(), None).unwrap();
let predictions = regression.run(&features.view()).unwrap();

// Then
let expected_predictions = array![0.2890504, 0.7109495];
assert_epsilon_eq_array1(&predictions, &expected_predictions, 1e-06);
}

#[test]
fn test_multiclass_logistic_regression_with_filtered_out_indexes() {
// Given
let intercept = array![0.98, 0.32, -0.76];
let weights = array![
[2.5, -0.6, 0.5],
[1.2, 1.2, -2.7],
[1.5, 0.1, -3.2],
[-0.9, 1.4, 1.8]
];

let features = array![0.4, -2.3, 1.9, 1.3];

let filtered_out_indexes = Some(vec![2]);
let regression = MulticlassLogisticRegression::new(intercept, weights).unwrap();

// When
let predictions = regression
.run(&features.view(), filtered_out_indexes)
.unwrap();

// Then
let expected_predictions = array![0.67745198, 0.32254802, 0.0];
assert_epsilon_eq_array1(&predictions, &expected_predictions, 1e-06);
}
}