Skip to content

Commit

Permalink
simulate_multinomial
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Aug 14, 2022
1 parent 3763cd0 commit 35f8f3a
Showing 1 changed file with 26 additions and 25 deletions.
51 changes: 26 additions & 25 deletions tests/Test/src/basic/test_basic_logsumexp.birch
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ program test_logsumexp() {

// Compare with common two-pass algorithm
let y <- log_sum_exp_twopass(w);
let w2 <- transform(w, \(x:Real) -> { return 2*w; });
let w2 <- transform(w, \(x:Real) -> { return 2*x; });
let ess <- y*y/log_sum_exp_twopass(w2);
if !check_ess_logsumexp(w, ess, y) {
exit(1);
}

// Check overflow
let v <- transform(w, \(x:Real) -> { return x + 1000.0; });
let v2 <- transform(v, \(x:Real) -> { return 2*w;} );
let v2 <- transform(v, \(x:Real) -> { return 2*x; } );
y <- y + 1000.0;
ess <- y*y/log_sum_exp_twopass(v2);
if !check_ess_logsumexp(w, ess, y) {
Expand All @@ -42,24 +42,24 @@ program test_logsumexp() {
}

// Special cases involving -inf, inf, and nan.
let cases <- [
([-inf, -inf], nan, -inf),
([-inf, nan], nan, -inf),
([nan, -inf], nan, -inf),
([-inf, 42.0], 1.0, 42.0),
([nan, 42.0], 1.0, 42.0),
([42.0, -inf], 1.0, 42.0),
([42.0, nan], 1.0, 42.0),
([-inf, inf], 1.0, inf),
([nan, inf], 1.0, inf),
([42.0, inf], 1.0, inf),
([inf, -inf], 1.0, inf),
([inf, nan], 1.0, inf),
([inf, 42.0], 1.0, inf),
([inf, inf], 1.0, inf),
];
let cases <- [ [-inf, -inf, nan, -inf],
[-inf, nan, nan, -inf],
[nan, -inf, nan, -inf],
[-inf, 42.0, 1.0, 42.0],
[nan, 42.0, 1.0, 42.0],
[42.0, -inf, 1.0, 42.0],
[42.0, nan, 1.0, 42.0],
[-inf, inf, 1.0, inf],
[nan, inf, 1.0, inf],
[42.0, inf, 1.0, inf],
[inf, -inf, 1.0, inf],
[inf, nan, 1.0, inf],
[inf, 42.0, 1.0, inf],
[inf, inf, 1.0, inf] ];
for n in 1..length(cases) {
(w, y, ess) <- cases[n];
w <- cases[1..2,n];
ess <- cases[3,n];
y <- cases[4,n];
if !check_ess_logsumexp(w, ess, y) {
exit(1);
}
Expand All @@ -77,10 +77,10 @@ program test_logsumexp() {
* This implementation uses the common two-pass algorithm
* that avoids overflow.
*/
function log_sum_exp_twopass(w:Real[]) -> Real {
function log_sum_exp_twopass(w:Real[_]) -> Real {
if length(w) > 0 {
let mx <- max(w);
let r <- transform_reduce(x, 0.0, \(x:Real, y:Real) -> { return x + y; },
let r <- transform_reduce(w, 0.0, \(x:Real, y:Real) -> { return x + y; },
\(x:Real) -> { return nan_exp(x - mx); });
return mx + log(r);
} else {
Expand Down Expand Up @@ -116,13 +116,14 @@ function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real, relt
let result <- true;

let y <- log_sum_exp(w);
if !approx_equal(y, y_expected) {
stderr.print("log_sum_exp(" + w + ") = " + y_expected " ≉ " + y + "(reltol = " + reltol + ")\n");
if !approx_equal(y, y_expected, reltol) {
stderr.print("log_sum_exp(" + w + ") = " + y + " ≉ " + y_expected + "(reltol = " + reltol + ")\n");
result <- false;
}

let (ess, y) <- resample_reduce(w);
if !approx_equal(ess, ess_expected) || !approx_equal(y, y_expected) {
ess:Real;
(ess, y) <- resample_reduce(w);
if !approx_equal(ess, ess_expected, reltol) || !approx_equal(y, y_expected, reltol) {
stderr.print("resample_reduce(" + w + ") = (" + ess + ", " + y + ") ≉ (" + ess_expected + ", " + y_expected + ") (reltol = " + reltol + ")\n");
result <- false;
}
Expand Down

0 comments on commit 35f8f3a

Please sign in to comment.