Skip to content

Commit

Permalink
[ML] Partition-wise maximum scores (#32748)
Browse files Browse the repository at this point in the history
Added infrastructure to push through the 'person name field value' to
the normalizer process. This is required by the normalizer to retrieve
the maximum scores for individual partitions.
  • Loading branch information
edsavage authored Aug 13, 2018
1 parent 4d20e69 commit d147cd7
Show file tree
Hide file tree
Showing 13 changed files with 74 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ public String getPersonFieldName() {
return bucketInfluencer.getInfluencerFieldName();
}

@Override
public String getPersonFieldValue() {
return null;
}

@Override
public String getFunctionName() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ public String getPersonFieldName() {
return null;
}

@Override
public String getPersonFieldValue() {
return null;
}

@Override
public String getFunctionName() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ public String getPersonFieldName() {
return influencer.getInfluencerFieldName();
}

@Override
public String getPersonFieldValue() {
return influencer.getInfluencerFieldValue();
}

@Override
public String getFunctionName() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@ public void writeRecord(String[] record) throws IOException {
result.setPartitionFieldName(record[1]);
result.setPartitionFieldValue(record[2]);
result.setPersonFieldName(record[3]);
result.setFunctionName(record[4]);
result.setValueFieldName(record[5]);
result.setProbability(Double.parseDouble(record[6]));
result.setNormalizedScore(factor * Double.parseDouble(record[7]));
result.setPersonFieldValue(record[4]);
result.setFunctionName(record[5]);
result.setValueFieldName(record[6]);
result.setProbability(Double.parseDouble(record[7]));
result.setNormalizedScore(factor * Double.parseDouble(record[8]));
} catch (NumberFormatException | ArrayIndexOutOfBoundsException e) {
throw new IOException("Unable to write to no-op normalizer", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public Normalizable(String indexName) {

abstract String getPersonFieldName();

abstract String getPersonFieldValue();

abstract String getFunctionName();

abstract String getValueFieldName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public void normalize(Integer bucketSpan, boolean perPartitionNormalization,
NormalizerResult.PARTITION_FIELD_NAME_FIELD.getPreferredName(),
NormalizerResult.PARTITION_FIELD_VALUE_FIELD.getPreferredName(),
NormalizerResult.PERSON_FIELD_NAME_FIELD.getPreferredName(),
NormalizerResult.PERSON_FIELD_VALUE_FIELD.getPreferredName(),
NormalizerResult.FUNCTION_NAME_FIELD.getPreferredName(),
NormalizerResult.VALUE_FIELD_NAME_FIELD.getPreferredName(),
NormalizerResult.PROBABILITY_FIELD.getPreferredName(),
Expand Down Expand Up @@ -108,6 +109,7 @@ private static void writeNormalizableAndChildrenRecursively(Normalizable normali
Strings.coalesceToEmpty(normalizable.getPartitionFieldName()),
Strings.coalesceToEmpty(normalizable.getPartitionFieldValue()),
Strings.coalesceToEmpty(normalizable.getPersonFieldName()),
Strings.coalesceToEmpty(normalizable.getPersonFieldValue()),
Strings.coalesceToEmpty(normalizable.getFunctionName()),
Strings.coalesceToEmpty(normalizable.getValueFieldName()),
Double.toString(normalizable.getProbability()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.ml.job.process.normalizer;

import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -26,6 +27,7 @@ public class NormalizerResult implements ToXContentObject, Writeable {
static final ParseField PARTITION_FIELD_NAME_FIELD = new ParseField("partition_field_name");
static final ParseField PARTITION_FIELD_VALUE_FIELD = new ParseField("partition_field_value");
static final ParseField PERSON_FIELD_NAME_FIELD = new ParseField("person_field_name");
static final ParseField PERSON_FIELD_VALUE_FIELD = new ParseField("person_field_value");
static final ParseField FUNCTION_NAME_FIELD = new ParseField("function_name");
static final ParseField VALUE_FIELD_NAME_FIELD = new ParseField("value_field_name");
static final ParseField PROBABILITY_FIELD = new ParseField("probability");
Expand All @@ -39,6 +41,7 @@ public class NormalizerResult implements ToXContentObject, Writeable {
PARSER.declareString(NormalizerResult::setPartitionFieldName, PARTITION_FIELD_NAME_FIELD);
PARSER.declareString(NormalizerResult::setPartitionFieldValue, PARTITION_FIELD_VALUE_FIELD);
PARSER.declareString(NormalizerResult::setPersonFieldName, PERSON_FIELD_NAME_FIELD);
PARSER.declareString(NormalizerResult::setPersonFieldValue, PERSON_FIELD_VALUE_FIELD);
PARSER.declareString(NormalizerResult::setFunctionName, FUNCTION_NAME_FIELD);
PARSER.declareString(NormalizerResult::setValueFieldName, VALUE_FIELD_NAME_FIELD);
PARSER.declareDouble(NormalizerResult::setProbability, PROBABILITY_FIELD);
Expand All @@ -49,6 +52,7 @@ public class NormalizerResult implements ToXContentObject, Writeable {
private String partitionFieldName;
private String partitionFieldValue;
private String personFieldName;
private String personFieldValue;
private String functionName;
private String valueFieldName;
private double probability;
Expand All @@ -62,6 +66,9 @@ public NormalizerResult(StreamInput in) throws IOException {
partitionFieldName = in.readOptionalString();
partitionFieldValue = in.readOptionalString();
personFieldName = in.readOptionalString();
if (in.getVersion().onOrAfter(Version.V_6_5_0)) {
personFieldValue = in.readOptionalString();
}
functionName = in.readOptionalString();
valueFieldName = in.readOptionalString();
probability = in.readDouble();
Expand All @@ -74,6 +81,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(partitionFieldName);
out.writeOptionalString(partitionFieldValue);
out.writeOptionalString(personFieldName);
if (out.getVersion().onOrAfter(Version.V_6_5_0)) {
out.writeOptionalString(personFieldValue);
}
out.writeOptionalString(functionName);
out.writeOptionalString(valueFieldName);
out.writeDouble(probability);
Expand All @@ -87,6 +97,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(PARTITION_FIELD_NAME_FIELD.getPreferredName(), partitionFieldName);
builder.field(PARTITION_FIELD_VALUE_FIELD.getPreferredName(), partitionFieldValue);
builder.field(PERSON_FIELD_NAME_FIELD.getPreferredName(), personFieldName);
builder.field(PERSON_FIELD_VALUE_FIELD.getPreferredName(), personFieldValue);
builder.field(FUNCTION_NAME_FIELD.getPreferredName(), functionName);
builder.field(VALUE_FIELD_NAME_FIELD.getPreferredName(), valueFieldName);
builder.field(PROBABILITY_FIELD.getPreferredName(), probability);
Expand Down Expand Up @@ -127,6 +138,14 @@ public void setPersonFieldName(String personFieldName) {
this.personFieldName = personFieldName;
}

public String getPersonFieldValue() {
return personFieldValue;
}

public void setPersonFieldValue(String personFieldValue) {
this.personFieldValue = personFieldValue;
}

public String getFunctionName() {
return functionName;
}
Expand Down Expand Up @@ -161,7 +180,7 @@ public void setNormalizedScore(double normalizedScore) {

@Override
public int hashCode() {
return Objects.hash(level, partitionFieldName, partitionFieldValue, personFieldName,
return Objects.hash(level, partitionFieldName, partitionFieldValue, personFieldName, personFieldValue,
functionName, valueFieldName, probability, normalizedScore);
}

Expand All @@ -184,6 +203,7 @@ public boolean equals(Object other) {
&& Objects.equals(this.partitionFieldName, that.partitionFieldName)
&& Objects.equals(this.partitionFieldValue, that.partitionFieldValue)
&& Objects.equals(this.personFieldName, that.personFieldName)
&& Objects.equals(this.personFieldValue, that.personFieldValue)
&& Objects.equals(this.functionName, that.functionName)
&& Objects.equals(this.valueFieldName, that.valueFieldName)
&& this.probability == that.probability
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ public String getPersonFieldName() {
return null;
}

@Override
public String getPersonFieldValue() {
return null;
}

@Override
public String getFunctionName() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ public String getPersonFieldName() {
return over != null ? over : record.getByFieldName();
}

@Override
public String getPersonFieldValue() {
String over = record.getOverFieldValue();
return over != null ? over : record.getByFieldValue();
}

@Override
public String getFunctionName() {
return record.getFunction();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,18 @@ public void testGetPartitionFieldName() {
assertNull(new BucketInfluencerNormalizable(bucketInfluencer, INDEX_NAME).getPartitionFieldName());
}

public void testGetPartitionFieldValue() {
assertNull(new BucketInfluencerNormalizable(bucketInfluencer, INDEX_NAME).getPartitionFieldValue());
}

public void testGetPersonFieldName() {
assertEquals("airline", new BucketInfluencerNormalizable(bucketInfluencer, INDEX_NAME).getPersonFieldName());
}

public void testGetPersonFieldValue() {
assertNull(new BucketInfluencerNormalizable(bucketInfluencer, INDEX_NAME).getPersonFieldValue());
}

public void testGetFunctionName() {
assertNull(new BucketInfluencerNormalizable(bucketInfluencer, INDEX_NAME).getFunctionName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ public void testGetPersonFieldName() {
assertNull(new BucketNormalizable(bucket, INDEX_NAME).getPersonFieldName());
}

public void testGetPersonFieldValue() {
assertNull(new BucketNormalizable(bucket, INDEX_NAME).getPersonFieldValue());
}

public void testGetFunctionName() {
assertNull(new BucketNormalizable(bucket, INDEX_NAME).getFunctionName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ public void testGetPersonFieldName() {
assertEquals("airline", new InfluencerNormalizable(influencer, INDEX_NAME).getPersonFieldName());
}

public void testGetPersonFieldValue() {
assertEquals("AAL", new InfluencerNormalizable(influencer, INDEX_NAME).getPersonFieldValue());
}

public void testGetFunctionName() {
assertNull(new InfluencerNormalizable(influencer, INDEX_NAME).getFunctionName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public void testDefaultConstructor() {
assertNull(msg.getPartitionFieldName());
assertNull(msg.getPartitionFieldValue());
assertNull(msg.getPersonFieldName());
assertNull(msg.getPersonFieldValue());
assertNull(msg.getFunctionName());
assertNull(msg.getValueFieldName());
assertEquals(0.0, msg.getProbability(), EPSILON);
Expand All @@ -32,6 +33,7 @@ protected NormalizerResult createTestInstance() {
msg.setPartitionFieldName("part");
msg.setPartitionFieldValue("something");
msg.setPersonFieldName("person");
msg.setPersonFieldValue("fred");
msg.setFunctionName("mean");
msg.setValueFieldName("value");
msg.setProbability(0.005);
Expand Down

0 comments on commit d147cd7

Please sign in to comment.