Skip to content

Commit

Permalink
[DROOLS-6627] Verify Scorecard reasoncode when result is null (apache…
Browse files Browse the repository at this point in the history
…#3875)

* [DROOLS-6625] Managing missing "required" input data

* [DROOLS-6625] Managing best-effort conversion of input data

* [DROOLS-6625] Managing invalid values - TODO: integration tests

* [DROOLS-6625] Managing invalid values

* [DROOLS-6625] Managing missing values

* [DROOLS-6625] Validate input data

* [DROOLS-6635] Move testing sources in specific files

* [DROOLS-6625] Fix merge with base branch

* [DROOLS-6625] Fix as per PR suggestion

* [DROOLS-6635] Fix merge

* [DROOLS-6627] Verify Scorecard reasoncode when result is null

* [DROOLS-6635] Fix merge with 7.x

* [DROOLS-6627] Fix merge

* [DROOLS-6627] Fixed as per PR request
  • Loading branch information
gitgabrio committed Oct 6, 2021
1 parent 0100bd0 commit b40751b
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public abstract class KiePMMLDroolsModel extends KiePMMLModel implements IsDrool
private static final Logger logger = LoggerFactory.getLogger(KiePMMLDroolsModel.class);

private static final AgendaEventListener agendaEventListener = getAgendaEventListener(logger);
private static final long serialVersionUID = 5471400949048174357L;

/**
* Map between the original field name and the generated type.
Expand All @@ -75,6 +76,7 @@ public Object evaluate(final Object knowledgeBase, Map<String, Object> requestDa
String fullClassName = this.getClass().getName();
String packageName = fullClassName.contains(".") ?
fullClassName.substring(0, fullClassName.lastIndexOf('.')) : "";
outputFieldsMap.clear();
KiePMMLSessionUtils.Builder builder = KiePMMLSessionUtils.builder((KieBase) knowledgeBase, name, packageName,
toReturn)
.withObjectsInSession(requestData, fieldTypeMap)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.kie.pmml.models.drools.scorecard.tests;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
Expand All @@ -35,6 +34,8 @@
import org.kie.pmml.api.runtime.PMMLRuntime;
import org.kie.pmml.models.tests.AbstractPMMLTest;

import static org.junit.Assert.assertFalse;

@RunWith(Parameterized.class)
public class SimpleScorecardCategoricalTest extends AbstractPMMLTest {

Expand All @@ -43,7 +44,7 @@ public class SimpleScorecardCategoricalTest extends AbstractPMMLTest {
private static final String TARGET_FIELD = "Score";
private static final String REASON_CODE1_FIELD = "Reason Code 1";
private static final String REASON_CODE2_FIELD = "Reason Code 2";
private static final String[] CATEGORY = new String[] { "classA", "classB", "classC", "classD", "classE", "NA" };
private static final String[] CATEGORY = new String[]{"classA", "classB", "classC", "classD", "classE", "NA"};
private static PMMLRuntime pmmlRuntime;

private String input1;
Expand All @@ -52,15 +53,16 @@ public class SimpleScorecardCategoricalTest extends AbstractPMMLTest {
private String reasonCode1;
private String reasonCode2;

public SimpleScorecardCategoricalTest(String input1, String input2, double score, String reasonCode1, String reasonCode2) {
public SimpleScorecardCategoricalTest(String input1, String input2, double score, String reasonCode1,
String reasonCode2) {
this.input1 = input1;
this.input2 = input2;
this.score = score;
this.reasonCode1 = reasonCode1;
this.reasonCode2 = reasonCode2;
}

@BeforeClass
@BeforeClass
public static void setupClass() {
pmmlRuntime = getPMMLRuntime(FILE_NAME);
}
Expand Down Expand Up @@ -93,6 +95,16 @@ public void testSimpleScorecardCategoricalVerifyNoException() {
getSamples().stream().map(sample -> evaluate(pmmlRuntime, sample, MODEL_NAME)).forEach(Assert::assertNotNull);
}

@Test
public void testSimpleScorecardCategoricalVerifyNoReasonCodeWithoutScore() {
getSamples().stream().map(sample -> evaluate(pmmlRuntime, sample, MODEL_NAME))
.filter(pmml4Result -> pmml4Result.getResultVariables().get(TARGET_FIELD) == null)
.forEach(pmml4Result -> {
assertFalse(pmml4Result.getResultVariables().containsKey(REASON_CODE1_FIELD));
assertFalse(pmml4Result.getResultVariables().containsKey(REASON_CODE2_FIELD));
});
}

private List<Map<String, Object>> getSamples() {
return IntStream.range(0, 10).boxed().map(i -> new HashMap<String, Object>() {{
put("input1", CATEGORY[i % CATEGORY.length]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,13 @@ public Object evaluate(final Object knowledgeBase, final Map<String, Object> req
if (localTransformations != null) {
derivedFields.addAll(localTransformations.getDerivedFields());
}
outputFieldsMap.clear();
return characteristics.evaluate(defineFunctions, derivedFields, kiePMMLOutputFields, requestData,
outputFieldsMap,
initialScore,
reasonCodeAlgorithm,
useReasonCodes,
baselineScore).orElse(null);
outputFieldsMap,
initialScore,
reasonCodeAlgorithm,
useReasonCodes,
baselineScore).orElse(null);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import org.kie.pmml.api.runtime.PMMLRuntime;
import org.kie.pmml.models.tests.AbstractPMMLTest;

import static org.junit.Assert.assertFalse;

@RunWith(Parameterized.class)
public class SimpleScorecardCategoricalTest extends AbstractPMMLTest {

Expand All @@ -42,7 +44,7 @@ public class SimpleScorecardCategoricalTest extends AbstractPMMLTest {
private static final String TARGET_FIELD = "Score";
private static final String REASON_CODE1_FIELD = "Reason Code 1";
private static final String REASON_CODE2_FIELD = "Reason Code 2";
private static final String[] CATEGORY = new String[] { "classA", "classB", "classC", "classD", "classE", "NA" };
private static final String[] CATEGORY = new String[]{"classA", "classB", "classC", "classD", "classE", "NA"};
private static PMMLRuntime pmmlRuntime;

private String input1;
Expand All @@ -51,7 +53,8 @@ public class SimpleScorecardCategoricalTest extends AbstractPMMLTest {
private String reasonCode1;
private String reasonCode2;

public SimpleScorecardCategoricalTest(String input1, String input2, double score, String reasonCode1, String reasonCode2) {
public SimpleScorecardCategoricalTest(String input1, String input2, double score, String reasonCode1,
String reasonCode2) {
this.input1 = input1;
this.input2 = input2;
this.score = score;
Expand Down Expand Up @@ -89,7 +92,18 @@ public void testSimpleScorecardCategorical() {

@Test
public void testSimpleScorecardCategoricalVerifyNoException() {
getSamples().stream().map(sample -> evaluate(pmmlRuntime, sample, MODEL_NAME)).forEach(Assert::assertNotNull);
getSamples().stream().map(sample -> evaluate(pmmlRuntime, sample, MODEL_NAME))
.forEach(Assert::assertNotNull);
}

@Test
public void testSimpleScorecardCategoricalVerifyNoReasonCodeWithoutScore() {
getSamples().stream().map(sample -> evaluate(pmmlRuntime, sample, MODEL_NAME))
.filter(pmml4Result -> pmml4Result.getResultVariables().get(TARGET_FIELD) == null)
.forEach(pmml4Result -> {
assertFalse(pmml4Result.getResultVariables().containsKey(REASON_CODE1_FIELD));
assertFalse(pmml4Result.getResultVariables().containsKey(REASON_CODE2_FIELD));
});
}

private List<Map<String, Object>> getSamples() {
Expand All @@ -98,5 +112,4 @@ private List<Map<String, Object>> getSamples() {
put("input2", CATEGORY[Math.abs(CATEGORY.length - i) % CATEGORY.length]);
}}).collect(Collectors.toList());
}

}

0 comments on commit b40751b

Please sign in to comment.