Skip to content

Commit

Permalink
Merge pull request #2735 from jamescowens/implement_fraction_class_im…
Browse files Browse the repository at this point in the history
…provement

util: Enhance Fraction class overflow resistance
  • Loading branch information
jamescowens committed Feb 8, 2024
2 parents 861eaf3 + 4deec19 commit 0ca17e2
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 6 deletions.
41 changes: 41 additions & 0 deletions src/test/util_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,24 @@ BOOST_AUTO_TEST_CASE(util_Fraction_addition_with_internal_simplification)
BOOST_CHECK_EQUAL(sum.IsSimplified(), true);
}

BOOST_AUTO_TEST_CASE(util_Fraction_addition_with_internal_gcd_simplification)
{
Fraction lhs(1, 6);
Fraction rhs(2, 15);

// gcd(6, 15) = 3, so this really is
//
// 1 * (15/3) + 2 * (6/3) 1 * 5 + 2 * 2 3
// ---------------------- = ------------- = --
// 3 * (6/3) * (15/3) 3 * 2 * 5 10

Fraction sum = lhs + rhs;

BOOST_CHECK_EQUAL(sum.GetNumerator(), 3);
BOOST_CHECK_EQUAL(sum.GetDenominator(), 10);
BOOST_CHECK_EQUAL(sum.IsSimplified(), true);
}

BOOST_AUTO_TEST_CASE(util_Fraction_subtraction)
{
Fraction lhs(2, 3);
Expand Down Expand Up @@ -1421,6 +1439,29 @@ BOOST_AUTO_TEST_CASE(util_Fraction_multiplication_with_internal_simplification)
BOOST_CHECK_EQUAL(product.IsSimplified(), true);
}

BOOST_AUTO_TEST_CASE(util_Fraction_multiplication_with_cross_simplification_overflow_resistance)
{

Fraction lhs(std::numeric_limits<int64_t>::max() - 3, std::numeric_limits<int64_t>::max() - 1, false);
Fraction rhs((std::numeric_limits<int64_t>::max() - 1) / (int64_t) 2, (std::numeric_limits<int64_t>::max() - 3) / (int64_t) 2);

Fraction product;

// This should NOT overflow
bool overflow = false;
try {
product = lhs * rhs;
} catch (std::overflow_error& e) {
overflow = true;
}

BOOST_CHECK_EQUAL(overflow, false);

if (!overflow) {
BOOST_CHECK(product == Fraction(1));
}
}

BOOST_AUTO_TEST_CASE(util_Fraction_division_with_internal_simplification)
{
Fraction lhs(-2, 3);
Expand Down
90 changes: 84 additions & 6 deletions src/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,60 @@ class Fraction {
return Fraction(overflow_add(slhs.GetNumerator(), srhs.GetNumerator()), slhs.GetDenominator(), true);
}

// Otherwise do the full pattern of getting a common denominator and adding, then simplify...
return Fraction(overflow_add(overflow_mult(slhs.GetNumerator(), srhs.GetDenominator()),
overflow_mult(slhs.GetDenominator(), srhs.GetNumerator())),
overflow_mult(slhs.GetDenominator(), srhs.GetDenominator()),
// Now the more complex case. In general, fraction addition follows this pattern:
//
// a c a * (d/g) + c * (b/g)
// - + - , g = gcd(b, d) => --------------------- where {(b/g), (d/g)} will be elements of the counting numbers.
// b d g * (b/g) * (d/g)
//
// (b/g) and (d/g) are divisible with no remainders precisely because of the definition of gcd.
//
// We have already covered the trivial common denominator case above before bothering to compute the gcd of the
// denominator.
int64_t denom_gcd = std::gcd(slhs.GetDenominator(), srhs.GetDenominator());

// We have two special cases. One is where g = b (i.e. d is actually a multiple of b). In this case,
// the expression simplifies to
//
// a * (d/b) + c
// -------------
// d
if (denom_gcd == slhs.GetDenominator()) {
return Fraction(overflow_add(overflow_mult(slhs.GetNumerator(), srhs.GetDenominator() / slhs.GetDenominator()),
srhs.GetNumerator()),
srhs.GetDenominator(),
true);
}

// The other is where g = d (i.e. b is actually a multiple of d). In this case,
// the expression simplifies to
//
// a + c * (b/d)
// -------------
// b
if (denom_gcd == srhs.GetDenominator()) {
return Fraction(overflow_add(overflow_mult(srhs.GetNumerator(), slhs.GetDenominator() / srhs.GetDenominator()),
slhs.GetNumerator()),
slhs.GetDenominator(),
true);
}

// Otherwise do the full pattern of getting a common denominator (pulling out the gcd of the denominators),
// and adding, then simplify...
//
// This approach is more complex than
//
// a * d + c * b
// -------------
// b * d
//
// but has the advantage of being more resistant to overflows, especially when the two denominators are related by a large
// gcd. In particular in Gridcoin's application with Allocations, the largest denominator of the allocations is 10000, so
// every allocation denominator in reduced form must be divisible evenly into 10000. This means the majority of fraction
// additions will be the two simpler cases above.
return Fraction(overflow_add(overflow_mult(slhs.GetNumerator(), srhs.GetDenominator() / denom_gcd),
overflow_mult(srhs.GetNumerator(), slhs.GetDenominator() / denom_gcd)),
overflow_mult(denom_gcd, overflow_mult(slhs.GetDenominator() / denom_gcd, srhs.GetDenominator() / denom_gcd)),
true);
}

Expand All @@ -385,8 +435,36 @@ class Fraction {
Fraction slhs(*this, true);
Fraction srhs(rhs, true);

return Fraction(overflow_mult(slhs.GetNumerator(), srhs.GetNumerator()),
overflow_mult(slhs.GetDenominator(), srhs.GetDenominator()),
// Gcd's can be used in multiplication for better overflow resistance as well.
//
// Consider
// a c
// - * -, where a/b and c/d are already simplified (i.e. gcd(a, b) = gcd(c, d) = 1.
// b d
//
// We can have g = gcd(a, d) and h = gcd(c, b), which is with the numerators reversed, since multiplication is
// commutative. This means we have
//
// (c / h) (a / g)
// ------- * ------- .
// (b / h) (d / g)
//
// If we form Fraction(c, b, true) and Fraction(a, d, true), the simplification will determine and divide the numerator and
// denominator by h and g respectively.
//
// A specific example is instructive.
//
// 1998 1000 999 1000 1000 999 1 1
// ---- * ---- = ---- * ---- = ---- * --- = - * -
// 2000 999 1000 999 1000 999 1 1
//
// This is a formal form of what grade school teachers called factor cancellation. :).

Fraction sxlhs(srhs.GetNumerator(), slhs.GetDenominator(), true);
Fraction sxrhs(slhs.GetNumerator(), srhs.GetDenominator(), true);

return Fraction(overflow_mult(sxlhs.GetNumerator(), sxrhs.GetNumerator()),
overflow_mult(sxlhs.GetDenominator(), sxrhs.GetDenominator()),
true);
}

Expand Down

0 comments on commit 0ca17e2

Please sign in to comment.