Skip to content

Commit

Permalink
[ML] adds new feature_processors field for data frame analytics (#60528)
Browse files Browse the repository at this point in the history
feature_processors allow users to create custom features from
individual document fields.

These `feature_processors` are the same objects as the trained model's `pre_processors`.

They are passed to the native process and the native process then appends them to the
pre_processor array in the inference model.

closes #59327
  • Loading branch information
benwtrent authored Aug 14, 2020
1 parent 69f7066 commit de3107a
Show file tree
Hide file tree
Showing 44 changed files with 1,590 additions and 193 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.FieldAliasMapper;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -46,6 +50,7 @@ public class Classification implements DataFrameAnalysis {
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");

private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";

Expand All @@ -59,6 +64,7 @@ public class Classification implements DataFrameAnalysis {
*/
public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;

@SuppressWarnings("unchecked")
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
Expand All @@ -70,14 +76,21 @@ private static ConstructingObjectParser<Classification, Void> createParser(boole
(ClassAssignmentObjective) a[8],
(Integer) a[9],
(Double) a[10],
(Long) a[11]));
(Long) a[11],
(List<PreProcessor>) a[12]));
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
BoostedTreeParams.declareFields(parser);
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
parser.declareString(optionalConstructorArg(), ClassAssignmentObjective::fromString, CLASS_ASSIGNMENT_OBJECTIVE);
parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
parser.declareNamedObjects(optionalConstructorArg(),
(p, c, n) -> lenient ?
p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
(classification) -> {/*TODO should we throw if this is not set?*/},
FEATURE_PROCESSORS);
return parser;
}

Expand Down Expand Up @@ -117,14 +130,16 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
private final int numTopClasses;
private final double trainingPercent;
private final long randomizeSeed;
private final List<PreProcessor> featureProcessors;

public Classification(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@Nullable String predictionFieldName,
@Nullable ClassAssignmentObjective classAssignmentObjective,
@Nullable Integer numTopClasses,
@Nullable Double trainingPercent,
@Nullable Long randomizeSeed) {
@Nullable Long randomizeSeed,
@Nullable List<PreProcessor> featureProcessors) {
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
}
Expand All @@ -139,10 +154,11 @@ public Classification(String dependentVariable,
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
}

public Classification(String dependentVariable) {
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
}

public Classification(StreamInput in) throws IOException {
Expand All @@ -161,6 +177,11 @@ public Classification(StreamInput in) throws IOException {
} else {
randomizeSeed = Randomness.get().nextLong();
}
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
} else {
featureProcessors = Collections.emptyList();
}
}

public String getDependentVariable() {
Expand Down Expand Up @@ -191,6 +212,10 @@ public long getRandomizeSeed() {
return randomizeSeed;
}

/**
 * Returns the feature processors configured for this analysis.
 * Never {@code null}: the constructor substitutes an immutable empty list when none
 * were supplied, and otherwise wraps the supplied list with
 * {@code Collections.unmodifiableList}, so callers must not attempt to mutate it.
 */
public List<PreProcessor> getFeatureProcessors() {
return featureProcessors;
}

@Override
public String getWriteableName() {
return NAME.getPreferredName();
Expand All @@ -209,6 +234,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
out.writeOptionalLong(randomizeSeed);
}
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeNamedWriteableList(featureProcessors);
}
}

@Override
Expand All @@ -227,6 +255,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (version.onOrAfter(Version.V_7_6_0)) {
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
}
if (featureProcessors.isEmpty() == false) {
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
}
builder.endObject();
return builder;
}
Expand All @@ -247,6 +278,10 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
}
params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
if (featureProcessors.isEmpty() == false) {
params.put(FEATURE_PROCESSORS.getPreferredName(),
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
}
return params;
}

Expand Down Expand Up @@ -388,14 +423,15 @@ public boolean equals(Object o) {
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
&& Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(featureProcessors, that.featureProcessors)
&& trainingPercent == that.trainingPercent
&& randomizeSeed == that.randomizeSeed;
}

@Override
public int hashCode() {
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
numTopClasses, trainingPercent, randomizeSeed);
numTopClasses, trainingPercent, randomizeSeed, featureProcessors);
}

public enum ClassAssignmentObjective {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -28,6 +32,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
Expand All @@ -42,12 +47,14 @@ public class Regression implements DataFrameAnalysis {
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
public static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");

private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1";

private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);

@SuppressWarnings("unchecked")
private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
Expand All @@ -59,14 +66,21 @@ private static ConstructingObjectParser<Regression, Void> createParser(boolean l
(Double) a[8],
(Long) a[9],
(LossFunction) a[10],
(Double) a[11]));
(Double) a[11],
(List<PreProcessor>) a[12]));
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
BoostedTreeParams.declareFields(parser);
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
parser.declareString(optionalConstructorArg(), LossFunction::fromString, LOSS_FUNCTION);
parser.declareDouble(optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
parser.declareNamedObjects(optionalConstructorArg(),
(p, c, n) -> lenient ?
p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
(regression) -> {/*TODO should we throw if this is not set?*/},
FEATURE_PROCESSORS);
return parser;
}

Expand All @@ -90,14 +104,16 @@ public static Regression fromXContent(XContentParser parser, boolean ignoreUnkno
private final long randomizeSeed;
private final LossFunction lossFunction;
private final Double lossFunctionParameter;
private final List<PreProcessor> featureProcessors;

public Regression(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@Nullable String predictionFieldName,
@Nullable Double trainingPercent,
@Nullable Long randomizeSeed,
@Nullable LossFunction lossFunction,
@Nullable Double lossFunctionParameter) {
@Nullable Double lossFunctionParameter,
@Nullable List<PreProcessor> featureProcessors) {
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
}
Expand All @@ -112,10 +128,11 @@ public Regression(String dependentVariable,
throw ExceptionsHelper.badRequestException("[{}] must be a positive double", LOSS_FUNCTION_PARAMETER.getPreferredName());
}
this.lossFunctionParameter = lossFunctionParameter;
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
}

public Regression(String dependentVariable) {
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
}

public Regression(StreamInput in) throws IOException {
Expand All @@ -126,6 +143,11 @@ public Regression(StreamInput in) throws IOException {
randomizeSeed = in.readOptionalLong();
lossFunction = in.readEnum(LossFunction.class);
lossFunctionParameter = in.readOptionalDouble();
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
} else {
featureProcessors = Collections.emptyList();
}
}

public String getDependentVariable() {
Expand Down Expand Up @@ -156,6 +178,10 @@ public Double getLossFunctionParameter() {
return lossFunctionParameter;
}

/**
 * Returns the feature processors configured for this regression analysis.
 * Never {@code null}: the constructor defaults to an immutable empty list and
 * otherwise wraps the supplied list with {@code Collections.unmodifiableList},
 * so the returned list is read-only.
 */
public List<PreProcessor> getFeatureProcessors() {
return featureProcessors;
}

@Override
public String getWriteableName() {
return NAME.getPreferredName();
Expand All @@ -170,6 +196,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalLong(randomizeSeed);
out.writeEnum(lossFunction);
out.writeOptionalDouble(lossFunctionParameter);
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeNamedWriteableList(featureProcessors);
}
}

@Override
Expand All @@ -190,6 +219,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (lossFunctionParameter != null) {
builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
}
if (featureProcessors.isEmpty() == false) {
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
}
builder.endObject();
return builder;
}
Expand All @@ -207,6 +239,10 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
if (lossFunctionParameter != null) {
params.put(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
}
if (featureProcessors.isEmpty() == false) {
params.put(FEATURE_PROCESSORS.getPreferredName(),
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
}
return params;
}

Expand Down Expand Up @@ -290,13 +326,14 @@ public boolean equals(Object o) {
&& trainingPercent == that.trainingPercent
&& randomizeSeed == that.randomizeSeed
&& lossFunction == that.lossFunction
&& Objects.equals(featureProcessors, that.featureProcessors)
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
}

@Override
public int hashCode() {
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
lossFunctionParameter);
lossFunctionParameter, featureProcessors);
}

public enum LossFunction {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,23 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {

// PreProcessing Lenient
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, OneHotEncoding.NAME,
OneHotEncoding::fromXContentLenient));
(p, c) -> OneHotEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
TargetMeanEncoding::fromXContentLenient));
(p, c) -> TargetMeanEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, FrequencyEncoding.NAME,
FrequencyEncoding::fromXContentLenient));
(p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
CustomWordEmbedding::fromXContentLenient));
(p, c) -> CustomWordEmbedding.fromXContentLenient(p)));

// PreProcessing Strict
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
OneHotEncoding::fromXContentStrict));
(p, c) -> OneHotEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
TargetMeanEncoding::fromXContentStrict));
(p, c) -> TargetMeanEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, FrequencyEncoding.NAME,
FrequencyEncoding::fromXContentStrict));
(p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
CustomWordEmbedding::fromXContentStrict));
(p, c) -> CustomWordEmbedding.fromXContentStrict(p)));

// Model Lenient
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(b
TRAINED_MODEL);
parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
(p, c, n) -> ignoreUnknownFields ?
p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT) :
p.namedObject(StrictlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT),
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
PREPROCESSORS);
return parser;
Expand Down
Loading

0 comments on commit de3107a

Please sign in to comment.