Skip to content

Commit

Permalink
[ML] adding baseline field to total_feature_importance objects (#63098
Browse files Browse the repository at this point in the history
) (#63125)

This adds a new `baseline` field to the feature importance values. 

This field contains the baseline importance for a given feature and class.
  • Loading branch information
benwtrent authored Oct 1, 2020
1 parent fbf552d commit 95242ec
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public class TotalFeatureImportance implements ToXContentObject {
public static final ParseField IMPORTANCE = new ParseField("importance");
public static final ParseField CLASSES = new ParseField("classes");
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
public static final ParseField BASELINE = new ParseField("baseline");
public static final ParseField MIN = new ParseField("min");
public static final ParseField MAX = new ParseField("max");

Expand Down Expand Up @@ -102,22 +103,25 @@ public static class Importance implements ToXContentObject {

public static final ConstructingObjectParser<Importance, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
a -> new Importance((double)a[0], (double)a[1], (double)a[2], (Double)a[3]));

static {
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE);
}

private final double meanMagnitude;
private final double min;
private final double max;
private final Double baseline;

public Importance(double meanMagnitude, double min, double max) {
public Importance(double meanMagnitude, double min, double max, Double baseline) {
this.meanMagnitude = meanMagnitude;
this.min = min;
this.max = max;
this.baseline = baseline;
}

@Override
Expand All @@ -127,12 +131,13 @@ public boolean equals(Object o) {
Importance that = (Importance) o;
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
Double.compare(that.min, min) == 0 &&
Double.compare(that.max, max) == 0;
Double.compare(that.max, max) == 0 &&
Objects.equals(that.baseline, baseline);
}

@Override
public int hashCode() {
return Objects.hash(meanMagnitude, min, max);
return Objects.hash(meanMagnitude, min, max, baseline);
}

@Override
Expand All @@ -141,6 +146,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
builder.field(MIN.getPreferredName(), min);
builder.field(MAX.getPreferredName(), max);
if (baseline != null) {
builder.field(BASELINE.getPreferredName(), baseline);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ public static TotalFeatureImportance randomInstance() {
}

private static TotalFeatureImportance.Importance randomImportance() {
return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
return new TotalFeatureImportance.Importance(
randomDouble(),
randomDouble(),
randomDouble(),
randomBoolean() ? null : randomDouble());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
public static final ParseField MIN = new ParseField("min");
public static final ParseField MAX = new ParseField("max");
public static final ParseField BASELINE = new ParseField("baseline");

// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ConstructingObjectParser<TotalFeatureImportance, Void> LENIENT_PARSER = createParser(true);
Expand Down Expand Up @@ -124,27 +125,31 @@ public static class Importance implements ToXContentObject, Writeable {
private static ConstructingObjectParser<Importance, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<Importance, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields,
a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
a -> new Importance((double)a[0], (double)a[1], (double)a[2], (Double)a[3]));
parser.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
parser.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
parser.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE);
return parser;
}

private final double meanMagnitude;
private final double min;
private final double max;
private final Double baseline;

public Importance(double meanMagnitude, double min, double max) {
public Importance(double meanMagnitude, double min, double max, Double baseline) {
this.meanMagnitude = meanMagnitude;
this.min = min;
this.max = max;
this.baseline = baseline;
}

public Importance(StreamInput in) throws IOException {
this.meanMagnitude = in.readDouble();
this.min = in.readDouble();
this.max = in.readDouble();
this.baseline = in.readOptionalDouble();
}

@Override
Expand All @@ -154,19 +159,21 @@ public boolean equals(Object o) {
Importance that = (Importance) o;
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
Double.compare(that.min, min) == 0 &&
Double.compare(that.max, max) == 0;
Double.compare(that.max, max) == 0 &&
Objects.equals(that.baseline, baseline);
}

@Override
public int hashCode() {
return Objects.hash(meanMagnitude, min, max);
return Objects.hash(meanMagnitude, min, max, baseline);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(meanMagnitude);
out.writeDouble(min);
out.writeDouble(max);
out.writeOptionalDouble(baseline);
}

@Override
Expand All @@ -179,6 +186,9 @@ private Map<String, Object> asMap() {
map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
map.put(MIN.getPreferredName(), min);
map.put(MAX.getPreferredName(), max);
if (baseline != null) {
map.put(BASELINE.getPreferredName(), baseline);
}
return map;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@
},
"mean_magnitude": {
"type": "double"
},
"baseline": {
"type": "double"
}
}
},
Expand All @@ -105,6 +108,9 @@
},
"mean_magnitude": {
"type": "double"
},
"baseline": {
"type": "double"
}
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ public static TotalFeatureImportance randomInstance() {
}

private static TotalFeatureImportance.Importance randomImportance() {
return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
return new TotalFeatureImportance.Importance(
randomDouble(),
randomDouble(),
randomDouble(),
randomBoolean() ? null : randomDouble());
}

@Before
Expand Down

0 comments on commit 95242ec

Please sign in to comment.