Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] adds new feature_processors field for data frame analytics #60528

Merged
merged 15 commits into from
Aug 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.FieldAliasMapper;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -46,6 +50,7 @@ public class Classification implements DataFrameAnalysis {
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");

private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";

Expand All @@ -59,6 +64,7 @@ public class Classification implements DataFrameAnalysis {
*/
public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;

@SuppressWarnings("unchecked")
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
Expand All @@ -70,14 +76,21 @@ private static ConstructingObjectParser<Classification, Void> createParser(boole
(ClassAssignmentObjective) a[8],
(Integer) a[9],
(Double) a[10],
(Long) a[11]));
(Long) a[11],
(List<PreProcessor>) a[12]));
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
BoostedTreeParams.declareFields(parser);
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
parser.declareString(optionalConstructorArg(), ClassAssignmentObjective::fromString, CLASS_ASSIGNMENT_OBJECTIVE);
parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
parser.declareNamedObjects(optionalConstructorArg(),
(p, c, n) -> lenient ?
p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
(classification) -> {/*TODO should we throw if this is not set?*/},
FEATURE_PROCESSORS);
return parser;
}

Expand Down Expand Up @@ -117,14 +130,16 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
private final int numTopClasses;
private final double trainingPercent;
private final long randomizeSeed;
private final List<PreProcessor> featureProcessors;

public Classification(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@Nullable String predictionFieldName,
@Nullable ClassAssignmentObjective classAssignmentObjective,
@Nullable Integer numTopClasses,
@Nullable Double trainingPercent,
@Nullable Long randomizeSeed) {
@Nullable Long randomizeSeed,
@Nullable List<PreProcessor> featureProcessors) {
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
}
Expand All @@ -139,10 +154,11 @@ public Classification(String dependentVariable,
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
}

/**
 * Convenience constructor: builds a {@code Classification} for the given
 * dependent variable with default boosted-tree parameters and every optional
 * setting (prediction field, objective, top classes, training percent,
 * randomize seed, feature processors) left unset.
 *
 * @param dependentVariable the name of the field to predict
 */
public Classification(String dependentVariable) {
    // Delegates to the canonical 8-arg constructor; a null featureProcessors
    // argument is normalized to an empty immutable list there.
    this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
}

public Classification(StreamInput in) throws IOException {
Expand All @@ -161,6 +177,11 @@ public Classification(StreamInput in) throws IOException {
} else {
randomizeSeed = Randomness.get().nextLong();
}
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
} else {
featureProcessors = Collections.emptyList();
}
}

public String getDependentVariable() {
Expand Down Expand Up @@ -191,6 +212,10 @@ public long getRandomizeSeed() {
return randomizeSeed;
}

/**
 * Returns the feature pre-processors configured for this analysis.
 *
 * @return an unmodifiable list of {@link PreProcessor}s; empty (never
 *         {@code null}) when no processors were configured
 */
public List<PreProcessor> getFeatureProcessors() {
    return featureProcessors;
}

@Override
public String getWriteableName() {
return NAME.getPreferredName();
Expand All @@ -209,6 +234,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
out.writeOptionalLong(randomizeSeed);
}
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeNamedWriteableList(featureProcessors);
}
}

@Override
Expand All @@ -227,6 +255,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (version.onOrAfter(Version.V_7_6_0)) {
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
}
if (featureProcessors.isEmpty() == false) {
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
}
builder.endObject();
return builder;
}
Expand All @@ -247,6 +278,10 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
}
params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
if (featureProcessors.isEmpty() == false) {
params.put(FEATURE_PROCESSORS.getPreferredName(),
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
}
return params;
}

Expand Down Expand Up @@ -388,14 +423,15 @@ public boolean equals(Object o) {
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
&& Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(featureProcessors, that.featureProcessors)
&& trainingPercent == that.trainingPercent
&& randomizeSeed == that.randomizeSeed;
}

@Override
public int hashCode() {
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
numTopClasses, trainingPercent, randomizeSeed);
numTopClasses, trainingPercent, randomizeSeed, featureProcessors);
}

public enum ClassAssignmentObjective {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -28,6 +32,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
Expand All @@ -42,12 +47,14 @@ public class Regression implements DataFrameAnalysis {
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
public static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");

private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1";

private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);

@SuppressWarnings("unchecked")
private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
Expand All @@ -59,14 +66,21 @@ private static ConstructingObjectParser<Regression, Void> createParser(boolean l
(Double) a[8],
(Long) a[9],
(LossFunction) a[10],
(Double) a[11]));
(Double) a[11],
(List<PreProcessor>) a[12]));
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
BoostedTreeParams.declareFields(parser);
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
parser.declareString(optionalConstructorArg(), LossFunction::fromString, LOSS_FUNCTION);
parser.declareDouble(optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
parser.declareNamedObjects(optionalConstructorArg(),
(p, c, n) -> lenient ?
p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
(regression) -> {/*TODO should we throw if this is not set?*/},
benwtrent marked this conversation as resolved.
Show resolved Hide resolved
FEATURE_PROCESSORS);
return parser;
}

Expand All @@ -90,14 +104,16 @@ public static Regression fromXContent(XContentParser parser, boolean ignoreUnkno
private final long randomizeSeed;
private final LossFunction lossFunction;
private final Double lossFunctionParameter;
private final List<PreProcessor> featureProcessors;

public Regression(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@Nullable String predictionFieldName,
@Nullable Double trainingPercent,
@Nullable Long randomizeSeed,
@Nullable LossFunction lossFunction,
@Nullable Double lossFunctionParameter) {
@Nullable Double lossFunctionParameter,
@Nullable List<PreProcessor> featureProcessors) {
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
}
Expand All @@ -112,10 +128,11 @@ public Regression(String dependentVariable,
throw ExceptionsHelper.badRequestException("[{}] must be a positive double", LOSS_FUNCTION_PARAMETER.getPreferredName());
}
this.lossFunctionParameter = lossFunctionParameter;
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
}

/**
 * Convenience constructor: builds a {@code Regression} for the given dependent
 * variable with default boosted-tree parameters and every optional setting
 * (prediction field, training percent, randomize seed, loss function and its
 * parameter, feature processors) left unset.
 *
 * @param dependentVariable the name of the field to predict
 */
public Regression(String dependentVariable) {
    // Delegates to the canonical 8-arg constructor; a null featureProcessors
    // argument is normalized to an empty immutable list there.
    this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
}

public Regression(StreamInput in) throws IOException {
Expand All @@ -126,6 +143,11 @@ public Regression(StreamInput in) throws IOException {
randomizeSeed = in.readOptionalLong();
lossFunction = in.readEnum(LossFunction.class);
lossFunctionParameter = in.readOptionalDouble();
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
} else {
featureProcessors = Collections.emptyList();
}
}

public String getDependentVariable() {
Expand Down Expand Up @@ -156,6 +178,10 @@ public Double getLossFunctionParameter() {
return lossFunctionParameter;
}

/**
 * Returns the feature pre-processors configured for this analysis.
 *
 * @return an unmodifiable list of {@link PreProcessor}s; empty (never
 *         {@code null}) when no processors were configured
 */
public List<PreProcessor> getFeatureProcessors() {
    return featureProcessors;
}

@Override
public String getWriteableName() {
return NAME.getPreferredName();
Expand All @@ -170,6 +196,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalLong(randomizeSeed);
out.writeEnum(lossFunction);
out.writeOptionalDouble(lossFunctionParameter);
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeNamedWriteableList(featureProcessors);
}
}

@Override
Expand All @@ -190,6 +219,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (lossFunctionParameter != null) {
builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
}
if (featureProcessors.isEmpty() == false) {
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
}
builder.endObject();
return builder;
}
Expand All @@ -207,6 +239,10 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
if (lossFunctionParameter != null) {
params.put(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
}
if (featureProcessors.isEmpty() == false) {
params.put(FEATURE_PROCESSORS.getPreferredName(),
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
}
return params;
}

Expand Down Expand Up @@ -290,13 +326,14 @@ public boolean equals(Object o) {
&& trainingPercent == that.trainingPercent
&& randomizeSeed == that.randomizeSeed
&& lossFunction == that.lossFunction
&& Objects.equals(featureProcessors, that.featureProcessors)
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
}

@Override
public int hashCode() {
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
lossFunctionParameter);
lossFunctionParameter, featureProcessors);
}

public enum LossFunction {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,23 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {

// PreProcessing Lenient
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, OneHotEncoding.NAME,
OneHotEncoding::fromXContentLenient));
(p, c) -> OneHotEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
TargetMeanEncoding::fromXContentLenient));
(p, c) -> TargetMeanEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, FrequencyEncoding.NAME,
FrequencyEncoding::fromXContentLenient));
(p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
CustomWordEmbedding::fromXContentLenient));
(p, c) -> CustomWordEmbedding.fromXContentLenient(p)));

// PreProcessing Strict
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
OneHotEncoding::fromXContentStrict));
(p, c) -> OneHotEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
TargetMeanEncoding::fromXContentStrict));
(p, c) -> TargetMeanEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, FrequencyEncoding.NAME,
FrequencyEncoding::fromXContentStrict));
(p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
CustomWordEmbedding::fromXContentStrict));
(p, c) -> CustomWordEmbedding.fromXContentStrict(p)));

// Model Lenient
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(b
TRAINED_MODEL);
parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
(p, c, n) -> ignoreUnknownFields ?
p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT) :
p.namedObject(StrictlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT),
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
PREPROCESSORS);
return parser;
Expand Down
Loading