Skip to content

Commit

Permalink
Make m_PredictionFieldType field of an enum type
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek committed Dec 5, 2019
1 parent 2d085dc commit becdf58
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 17 deletions.
8 changes: 7 additions & 1 deletion include/api/CDataFrameTrainBoostedTreeClassifierRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ namespace api {
class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
: public CDataFrameTrainBoostedTreeRunner {
public:
enum EPredictionFieldType {
E_PredictionFieldTypeString,
E_PredictionFieldTypeInt,
E_PredictionFieldTypeBool
};

static const CDataFrameAnalysisConfigReader& parameterReader();

//! This is not intended to be called directly: use CDataFrameTrainBoostedTreeClassifierRunnerFactory.
Expand Down Expand Up @@ -59,7 +65,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final

private:
std::size_t m_NumTopClasses;
std::string m_PredictionFieldType;
EPredictionFieldType m_PredictionFieldType;
};

//! \brief Makes a core::CDataFrame boosted tree classification runner.
Expand Down
33 changes: 21 additions & 12 deletions lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ using TSizeVec = std::vector<std::size_t>;
// Configuration
const std::string NUM_TOP_CLASSES{"num_top_classes"};
const std::string PREDICTION_FIELD_TYPE{"prediction_field_type"};
const std::string PREDICTION_FIELD_TYPE_STRING{"string"};
const std::string PREDICTION_FIELD_TYPE_INT{"int"};
const std::string PREDICTION_FIELD_TYPE_BOOL{"bool"};
const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"};

// Output
Expand All @@ -49,10 +46,16 @@ const std::string CLASS_PROBABILITY_FIELD_NAME{"class_probability"};
const CDataFrameAnalysisConfigReader&
CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() {
static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
const std::string typeString{"string"};
const std::string typeInt{"int"};
const std::string typeBool{"bool"};
auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader();
theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(PREDICTION_FIELD_TYPE,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
CDataFrameAnalysisConfigReader::E_OptionalParameter,
{{typeString, int{E_PredictionFieldTypeString}},
{typeInt, int{E_PredictionFieldTypeInt}},
{typeBool, int{E_PredictionFieldTypeBool}}});
theReader.addParameter(BALANCED_CLASS_LOSS,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
return theReader;
Expand All @@ -67,7 +70,7 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier

m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0});
m_PredictionFieldType =
parameters[PREDICTION_FIELD_TYPE].fallback(PREDICTION_FIELD_TYPE_STRING);
parameters[PREDICTION_FIELD_TYPE].fallback(E_PredictionFieldTypeString);
this->boostedTreeFactory().balanceClassTrainingLoss(
parameters[BALANCED_CLASS_LOSS].fallback(true));

Expand Down Expand Up @@ -170,20 +173,26 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue(
const std::string& categoryValue,
core::CRapidJsonConcurrentLineWriter& writer) const {

if (m_PredictionFieldType == PREDICTION_FIELD_TYPE_INT) {
double doubleValue;
double doubleValue;
switch (m_PredictionFieldType) {
case E_PredictionFieldTypeString:
writer.String(categoryValue);
break;
case E_PredictionFieldTypeInt:
if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
writer.Int64(static_cast<std::int64_t>(doubleValue));
return;
} else {
writer.String(categoryValue);
}
} else if (m_PredictionFieldType == PREDICTION_FIELD_TYPE_BOOL) {
double doubleValue;
break;
case E_PredictionFieldTypeBool:
if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
writer.Bool(static_cast<std::int64_t>(doubleValue) == 1.0);
return;
} else {
writer.String(categoryValue);
}
break;
}
writer.String(categoryValue);
}

CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,14 @@ void testWriteOneRow(const std::string& dependentVariableField,
"classification", dependentVariableField, rows.size(),
columnNames.size(), 13000000, 0, 0, categoricalColumns)};
rapidjson::Document jsonParameters;
jsonParameters.Parse("{"
" \"dependent_variable\": \"" + dependentVariableField + "\","
" \"prediction_field_type\": \"" + predictionFieldType + "\""
"}");
if (predictionFieldType.empty()) {
jsonParameters.Parse("{\"dependent_variable\": \"" + dependentVariableField + "\"}");
} else {
jsonParameters.Parse("{"
" \"dependent_variable\": \"" + dependentVariableField + "\","
" \"prediction_field_type\": \"" + predictionFieldType + "\""
"}");
}
const auto parameters{
api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)};
api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters);
Expand Down

0 comments on commit becdf58

Please sign in to comment.