From 95242ecceed6f20a559aca15e39b3a5de1ac3e34 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 1 Oct 2020 09:48:07 -0400 Subject: [PATCH] [ML] adding `baseline` field to total_feature_importance objects (#63098) (#63125) This adds a new `baseline` field to the feature importance values. This field contains the baseline importance for a given feature and class. --- .../metadata/TotalFeatureImportance.java | 16 ++++++++++++---- .../metadata/TotalFeatureImportanceTests.java | 6 +++++- .../metadata/TotalFeatureImportance.java | 18 ++++++++++++++---- .../core/ml/inference_index_template.json | 6 ++++++ .../metadata/TotalFeatureImportanceTests.java | 6 +++++- 5 files changed, 42 insertions(+), 10 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java index 882dc046d6d64..7f981c8327c39 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java @@ -40,6 +40,7 @@ public class TotalFeatureImportance implements ToXContentObject { public static final ParseField IMPORTANCE = new ParseField("importance"); public static final ParseField CLASSES = new ParseField("classes"); public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude"); + public static final ParseField BASELINE = new ParseField("baseline"); public static final ParseField MIN = new ParseField("min"); public static final ParseField MAX = new ParseField("max"); @@ -102,22 +103,25 @@ public static class Importance implements ToXContentObject { public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, true, - a -> new Importance((double)a[0], (double)a[1], (double)a[2])); + a -> new Importance((double)a[0], (double)a[1], (double)a[2], (Double)a[3])); static { PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE); PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN); PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX); + PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE); } private final double meanMagnitude; private final double min; private final double max; + private final Double baseline; - public Importance(double meanMagnitude, double min, double max) { + public Importance(double meanMagnitude, double min, double max, Double baseline) { this.meanMagnitude = meanMagnitude; this.min = min; this.max = max; + this.baseline = baseline; } @Override @@ -127,12 +131,13 @@ public boolean equals(Object o) { Importance that = (Importance) o; return Double.compare(that.meanMagnitude, meanMagnitude) == 0 && Double.compare(that.min, min) == 0 && - Double.compare(that.max, max) == 0; + Double.compare(that.max, max) == 0 && + Objects.equals(that.baseline, baseline); } @Override public int hashCode() { - return Objects.hash(meanMagnitude, min, max); + return Objects.hash(meanMagnitude, min, max, baseline); } @Override @@ -141,6 +146,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude); builder.field(MIN.getPreferredName(), min); builder.field(MAX.getPreferredName(), max); + if (baseline != null) { + builder.field(BASELINE.getPreferredName(), baseline); + } builder.endObject(); return builder; } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java index adbf9ab052d72..5f185df6e6a60 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java @@ -50,7 +50,11 @@ public static TotalFeatureImportance randomInstance() { } private static TotalFeatureImportance.Importance randomImportance() { - return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble()); + return new TotalFeatureImportance.Importance( + randomDouble(), + randomDouble(), + randomDouble(), + randomBoolean() ? null : randomDouble()); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java index 8676af6ff5ca0..8f40072c8177c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java @@ -35,6 +35,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable { public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude"); public static final ParseField MIN = new ParseField("min"); public static final ParseField MAX = new ParseField("max"); + public static final ParseField BASELINE = new ParseField("baseline"); // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); @@ -124,27 +125,31 @@ public static class Importance implements ToXContentObject, Writeable { private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, - a -> new Importance((double)a[0], (double)a[1], (double)a[2])); + a -> new Importance((double)a[0], (double)a[1], (double)a[2], (Double)a[3])); parser.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE); parser.declareDouble(ConstructingObjectParser.constructorArg(), MIN); parser.declareDouble(ConstructingObjectParser.constructorArg(), MAX); + parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE); return parser; } private final double meanMagnitude; private final double min; private final double max; + private final Double baseline; - public Importance(double meanMagnitude, double min, double max) { + public Importance(double meanMagnitude, double min, double max, Double baseline) { this.meanMagnitude = meanMagnitude; this.min = min; this.max = max; + this.baseline = baseline; } public Importance(StreamInput in) throws IOException { this.meanMagnitude = in.readDouble(); this.min = in.readDouble(); this.max = in.readDouble(); + this.baseline = in.readOptionalDouble(); } @Override @@ -154,12 +159,13 @@ public boolean equals(Object o) { Importance that = (Importance) o; return Double.compare(that.meanMagnitude, meanMagnitude) == 0 && Double.compare(that.min, min) == 0 && - Double.compare(that.max, max) == 0; + Double.compare(that.max, max) == 0 && + Objects.equals(that.baseline, baseline); } @Override public int hashCode() { - return Objects.hash(meanMagnitude, min, max); + return Objects.hash(meanMagnitude, min, max, baseline); } @Override @@ -167,6 +173,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeDouble(meanMagnitude); out.writeDouble(min); out.writeDouble(max); + out.writeOptionalDouble(baseline); } @Override @@ -179,6 +186,9 @@ private Map asMap() { map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude); map.put(MIN.getPreferredName(), min); map.put(MAX.getPreferredName(), max); + if (baseline != null) { + map.put(BASELINE.getPreferredName(), baseline); + } return map; } } diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json index 00f5eb2a90fe2..f5fb2768a8de5 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json @@ -85,6 +85,9 @@ }, "mean_magnitude": { "type": "double" + }, + "baseline": { + "type": "double" } } }, @@ -105,6 +108,9 @@ }, "mean_magnitude": { "type": "double" + }, + "baseline": { + "type": "double" } } }, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java index fa68e71e8cc5d..ea5ccde3b9ca7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java @@ -41,7 +41,11 @@ public static TotalFeatureImportance randomInstance() { } private static TotalFeatureImportance.Importance randomImportance() { - return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble()); + return new TotalFeatureImportance.Importance( + randomDouble(), + randomDouble(), + randomDouble(), + randomBoolean() ? null : randomDouble()); } @Before