Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix Nan & Infinity comparison And Value's operator/ & operator% #4893

Merged
merged 2 commits into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 8 additions & 38 deletions src/common/datatypes/Value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2487,12 +2487,7 @@ Value operator/(const Value& lhs, const Value& rhs) {
return lVal / denom;
}
case Value::Type::FLOAT: {
double denom = rhs.getFloat();
if (std::abs(denom) > kEpsilon) {
return lhs.getInt() / denom;
} else {
return Value::kNullDivByZero;
}
return lhs.getInt() / rhs.getFloat();
}
default: {
return Value::kNullBadType;
Expand All @@ -2502,20 +2497,10 @@ Value operator/(const Value& lhs, const Value& rhs) {
case Value::Type::FLOAT: {
switch (rhs.type()) {
case Value::Type::INT: {
int64_t denom = rhs.getInt();
if (denom != 0) {
return lhs.getFloat() / denom;
} else {
return Value::kNullDivByZero;
}
return lhs.getFloat() / rhs.getInt();
}
case Value::Type::FLOAT: {
double denom = rhs.getFloat();
if (std::abs(denom) > kEpsilon) {
return lhs.getFloat() / denom;
} else {
return Value::kNullDivByZero;
}
return lhs.getFloat() / rhs.getFloat();
}
default: {
return Value::kNullBadType;
Expand Down Expand Up @@ -2548,12 +2533,7 @@ Value operator%(const Value& lhs, const Value& rhs) {
}
}
case Value::Type::FLOAT: {
double denom = rhs.getFloat();
if (std::abs(denom) > kEpsilon) {
return std::fmod(lhs.getInt(), denom);
} else {
return Value::kNullDivByZero;
}
return std::fmod(lhs.getInt(), rhs.getFloat());
}
default: {
return Value::kNullBadType;
Expand All @@ -2563,20 +2543,10 @@ Value operator%(const Value& lhs, const Value& rhs) {
case Value::Type::FLOAT: {
switch (rhs.type()) {
case Value::Type::INT: {
int64_t denom = rhs.getInt();
if (denom != 0) {
return std::fmod(lhs.getFloat(), denom);
} else {
return Value::kNullDivByZero;
}
return std::fmod(lhs.getFloat(), rhs.getInt());
}
case Value::Type::FLOAT: {
double denom = rhs.getFloat();
if (std::abs(denom) > kEpsilon) {
return std::fmod(lhs.getFloat(), denom);
} else {
return Value::kNullDivByZero;
}
return std::fmod(lhs.getFloat(), rhs.getFloat());
}
default: {
return Value::kNullBadType;
Expand Down Expand Up @@ -2877,11 +2847,11 @@ bool operator>(const Value& lhs, const Value& rhs) {
}

bool operator<=(const Value& lhs, const Value& rhs) {
return !(rhs < lhs);
return lhs < rhs || lhs == rhs;
}

bool operator>=(const Value& lhs, const Value& rhs) {
return !(lhs < rhs);
return lhs > rhs || lhs == rhs;
}

Value operator&&(const Value& lhs, const Value& rhs) {
Expand Down
4 changes: 4 additions & 0 deletions src/common/datatypes/Value.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,12 @@ struct Value {
Value toInt() const;
Value toSet() const;

// Expr use this function instead of operator<, because a Value compare to a Null
// return null instead of true or false
Value lessThan(const Value& v) const;

// Expr use this function instead of operator==, because a Value compare to a Null
// return null instead of true or false
Value equal(const Value& v) const;

// Whether the value can be converted to bool implicitly
Expand Down
106 changes: 104 additions & 2 deletions src/common/datatypes/test/ValueTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <robin_hood.h>
#include <thrift/lib/cpp2/protocol/Serializer.h>

#include <cmath>

#include "common/base/Base.h"
#include "common/datatypes/CommonCpp2Ops.h"
#include "common/datatypes/DataSet.h"
Expand Down Expand Up @@ -598,7 +600,8 @@ TEST(Value, Arithmetics) {
EXPECT_EQ((vFloat1.getFloat() / vFloat2.getFloat()), v.getFloat());

v = vFloat1 / vZero;
EXPECT_EQ(Value::Type::NULLVALUE, v.type());
EXPECT_EQ(Value::Type::FLOAT, v.type());
EXPECT_EQ(std::numeric_limits<double>::infinity(), v.getFloat());
v = vInt1 / vZero;
EXPECT_EQ(Value::Type::NULLVALUE, v.type());
}
Expand Down Expand Up @@ -629,7 +632,8 @@ TEST(Value, Arithmetics) {
EXPECT_EQ(std::fmod(vFloat1.getFloat(), vFloat2.getFloat()), v.getFloat());

v = vFloat1 % vZero;
EXPECT_EQ(Value::Type::NULLVALUE, v.type());
EXPECT_EQ(Value::Type::FLOAT, v.type());
EXPECT_TRUE(std::isnan(v.getFloat()));
v = vInt1 % vZero;
EXPECT_EQ(Value::Type::NULLVALUE, v.type());
}
Expand Down Expand Up @@ -682,6 +686,11 @@ TEST(Value, Comparison) {
Value vInt2(2);
Value vFloat1(3.14);
Value vFloat2(2.67);
Value vFloat3(-2.67);
Value vFloatNaN(0 / 0.0);
Value vFloatPositiveInfinity(1 / 0.0);
Value vFloatNegativeInfinity(-1 / 0.0);

Value vStr1("Hello ");
Value vStr2("World");
Value vBool1(false);
Expand Down Expand Up @@ -811,6 +820,99 @@ TEST(Value, Comparison) {
v = vFloat1 <= vFloat2;
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

// NaN comparison
// https://en.wikipedia.org/wiki/NaN#Comparison_with_NaN
// Comparison between NaN and any floating-point value x (including NaN and ±Inf)
// Comparison NaN ≥ x NaN ≤ x NaN > x NaN < x NaN = x NaN ≠ x
// Result False False False False False True
v = vFloatNaN >= vFloat1;
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

v = vFloatNaN > vFloat1;
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

v = vFloatNaN < vFloat1;
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

v = vFloatNaN.lessThan(vFloat1);
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

v = vFloatNaN <= vFloat1;
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

v = vFloatNaN >= vFloat3;
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

v = vFloatNaN > vFloat3;
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

v = vFloatNaN < vFloat3;
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

v = vFloatNaN <= vFloat3;
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

// NaN != any Value
v = vFloatNaN != vFloat3;
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(true, v.getBool());
v = vFloatNaN == vFloat3;
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());
v = vFloatNaN.equal(vFloat3);
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

v = vFloatNaN.equal(Value(0 / 0.0));
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());
}

{
// -Inf
Value v = vFloatPositiveInfinity.lessThan(vFloatNegativeInfinity);
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

v = vFloatNegativeInfinity.lessThan(vFloatPositiveInfinity);
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(true, v.getBool());

v = vFloatNegativeInfinity.lessThan(vInt1);
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(true, v.getBool());

v = vFloatNegativeInfinity.lessThan(vFloat1);
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(true, v.getBool());

// +Inf
v = vFloatPositiveInfinity.lessThan(vInt1);
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

v = vFloatPositiveInfinity.lessThan(vFloat1);
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

// NaN comparison always false
v = vFloatNegativeInfinity.lessThan(vFloatNaN);
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());

v = vFloatPositiveInfinity.lessThan(vFloatNaN);
EXPECT_EQ(Value::Type::BOOL, v.type());
EXPECT_EQ(false, v.getBool());
}

// int and float
Expand Down
4 changes: 2 additions & 2 deletions src/common/expression/RelationalExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ const Value& RelationalExpression::eval(ExpressionContext& ctx) {
result_ = lhs.lessThan(rhs) || lhs.equal(rhs);
break;
case Kind::kRelGT:
result_ = !lhs.lessThan(rhs) && !lhs.equal(rhs);
result_ = rhs.lessThan(lhs);
break;
case Kind::kRelGE:
result_ = !lhs.lessThan(rhs) || lhs.equal(rhs);
result_ = rhs.lessThan(lhs) || lhs.equal(rhs);
break;
case Kind::kRelREG: {
if (lhs.isBadNull() || rhs.isBadNull()) {
Expand Down
12 changes: 6 additions & 6 deletions src/common/expression/test/ArithmeticExpressionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ TEST_F(ArithmeticExpressionTest, TestArithmeticExpression) {
TEST_EXPR(11 * 2, 22);
TEST_EXPR(11 * 2.2, 24.2);
TEST_EXPR(100.4 / 4, 25.1);
TEST_EXPR(10.4 % 0, NullType::DIV_BY_ZERO);
TEST_EXPR(10 % 0.0, NullType::DIV_BY_ZERO);
TEST_EXPR(10.4 % 0.0, NullType::DIV_BY_ZERO);
TEST_EXPR(10.4 % 0, Value(std::numeric_limits<double>::quiet_NaN()));
TEST_EXPR(10 % 0.0, Value(std::numeric_limits<double>::quiet_NaN()));
TEST_EXPR(10.4 % 0.0, Value(std::numeric_limits<double>::quiet_NaN()));
TEST_EXPR(10 / 0, NullType::DIV_BY_ZERO);
TEST_EXPR(12 / 0.0, NullType::DIV_BY_ZERO);
TEST_EXPR(187. / 0.0, NullType::DIV_BY_ZERO);
TEST_EXPR(17. / 0, NullType::DIV_BY_ZERO);
TEST_EXPR(12 / 0.0, std::numeric_limits<double>::infinity());
TEST_EXPR(187. / 0.0, std::numeric_limits<double>::infinity());
TEST_EXPR(17. / 0, std::numeric_limits<double>::infinity());
}
{
TEST_EXPR(1 + 2 + 3.2, 6.2);
Expand Down
7 changes: 6 additions & 1 deletion src/common/expression/test/TestBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ class ExpressionTest : public ::testing::Test {
Expression *ep = yieldSentence->yield()->yields()->back()->expr();
auto eval = Expression::eval(ep, gExpCtxt);
EXPECT_EQ(eval.type(), expected.type()) << "type check failed: " << ep->toString();
EXPECT_EQ(eval, expected) << "check failed: " << ep->toString();
// NaN is not equals to NaN, check equals should use std::isnan()
if (expected.type() == Value::Type::FLOAT && std::isnan(expected.getFloat())) {
EXPECT_TRUE(std::isnan(eval.getFloat())) << "check failed: " << ep->toString();
} else {
EXPECT_EQ(eval, expected) << "check failed: " << ep->toString();
}
}

void testToString(const std::string &exprSymbol, const char *expected) {
Expand Down
6 changes: 0 additions & 6 deletions src/common/function/FunctionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,16 +592,10 @@ FunctionManager::FunctionManager() {
}
case Value::Type::INT: {
auto val = args[0].get().getInt();
if (val < 0) {
return Value::kNullValue;
}
return std::sqrt(val);
}
case Value::Type::FLOAT: {
auto val = args[0].get().getFloat();
if (val < 0) {
return Value::kNullValue;
}
return std::sqrt(val);
}
default: {
Expand Down
2 changes: 2 additions & 0 deletions src/common/function/test/FunctionManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ TEST_F(FunctionManagerTest, functionCall) {
{
TEST_FUNCTION(sqrt, args_["int"], 2.0);
TEST_FUNCTION(sqrt, args_["float"], std::sqrt(1.1));
TEST_FUNCTION(sqrt, {Value(-1)}, std::sqrt(-1));
TEST_FUNCTION(sqrt, {Value(0)}, std::sqrt(0));
}
{
TEST_FUNCTION(cbrt, args_["int"], std::cbrt(4));
Expand Down
Loading