From 35f8f3ab063a4afb3d2e03cb631e3bb78903b791 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 15 Aug 2022 00:52:21 +0200 Subject: [PATCH] simulate_multinomial --- .../Test/src/basic/test_basic_logsumexp.birch | 51 ++++++++++--------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/tests/Test/src/basic/test_basic_logsumexp.birch b/tests/Test/src/basic/test_basic_logsumexp.birch index cfe8316be..113eb192d 100644 --- a/tests/Test/src/basic/test_basic_logsumexp.birch +++ b/tests/Test/src/basic/test_basic_logsumexp.birch @@ -12,7 +12,7 @@ program test_logsumexp() { // Compare with common two-pass algorithm let y <- log_sum_exp_twopass(w); - let w2 <- transform(w, \(x:Real) -> { return 2*w; }); + let w2 <- transform(w, \(x:Real) -> { return 2*x; }); let ess <- y*y/log_sum_exp_twopass(w2); if !check_ess_logsumexp(w, ess, y) { exit(1); @@ -20,7 +20,7 @@ program test_logsumexp() { // Check overflow let v <- transform(w, \(x:Real) -> { return x + 1000.0; }); - let v2 <- transform(v, \(x:Real) -> { return 2*w;} ); + let v2 <- transform(v, \(x:Real) -> { return 2*x; } ); y <- y + 1000.0; ess <- y*y/log_sum_exp_twopass(v2); if !check_ess_logsumexp(w, ess, y) { @@ -42,24 +42,24 @@ program test_logsumexp() { } // Special cases involving -inf, inf, and nan. - let cases <- [ - ([-inf, -inf], nan, -inf), - ([-inf, nan], nan, -inf), - ([nan, -inf], nan, -inf), - ([-inf, 42.0], 1.0, 42.0), - ([nan, 42.0], 1.0, 42.0), - ([42.0, -inf], 1.0, 42.0), - ([42.0, nan], 1.0, 42.0), - ([-inf, inf], 1.0, inf), - ([nan, inf], 1.0, inf), - ([42.0, inf], 1.0, inf), - ([inf, -inf], 1.0, inf), - ([inf, nan], 1.0, inf), - ([inf, 42.0], 1.0, inf), - ([inf, inf], 1.0, inf), - ]; + let cases <- [ [-inf, -inf, nan, -inf], + [-inf, nan, nan, -inf], + [nan, -inf, nan, -inf], + [-inf, 42.0, 1.0, 42.0], + [nan, 42.0, 1.0, 42.0], + [42.0, -inf, 1.0, 42.0], + [42.0, nan, 1.0, 42.0], + [-inf, inf, 1.0, inf], + [nan, inf, 1.0, inf], + [42.0, inf, 1.0, inf], + [inf, -inf, 1.0, inf], + [inf, nan, 1.0, inf], + [inf, 42.0, 1.0, inf], + [inf, inf, 1.0, inf] ]; for n in 1..length(cases) { - (w, y, ess) <- cases[n]; + w <- cases[1..2,n]; + ess <- cases[3,n]; + y <- cases[4,n]; if !check_ess_logsumexp(w, ess, y) { exit(1); } @@ -77,10 +77,10 @@ program test_logsumexp() { * This implementation uses the common two-pass algorithm * that avoids overflow. */ -function log_sum_exp_twopass(w:Real[]) -> Real { +function log_sum_exp_twopass(w:Real[_]) -> Real { if length(w) > 0 { let mx <- max(w); - let r <- transform_reduce(x, 0.0, \(x:Real, y:Real) -> { return x + y; }, + let r <- transform_reduce(w, 0.0, \(x:Real, y:Real) -> { return x + y; }, \(x:Real) -> { return nan_exp(x - mx); }); return mx + log(r); } else { @@ -116,13 +116,14 @@ function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real, relt let result <- true; let y <- log_sum_exp(w); - if !approx_equal(y, y_expected) { - stderr.print("log_sum_exp(" + w + ") = " + y_expected " ≉ " + y + "(reltol = " + reltol + ")\n"); + if !approx_equal(y, y_expected, reltol) { + stderr.print("log_sum_exp(" + w + ") = " + y + " ≉ " + y_expected + "(reltol = " + reltol + ")\n"); result <- false; } - let (ess, y) <- resample_reduce(w); - if !approx_equal(ess, ess_expected) || !approx_equal(y, y_expected) { + ess:Real; + (ess, y) <- resample_reduce(w); + if !approx_equal(ess, ess_expected, reltol) || !approx_equal(y, y_expected, reltol) { stderr.print("resample_reduce(" + w + ") = (" + ess + ", " + y + ") ≉ (" + ess_expected + ", " + y_expected + ") (reltol = " + reltol + ")\n"); result <- false; }