Use numerically stable one-pass algorithm for log-sum-exp #18

Merged: lawmurray merged 24 commits into lawmurray:numeric on Aug 21, 2022

devmotion
Contributor

A while ago I wrote a Julia implementation of a numerically stable one-pass algorithm for log-sum-exp (original PR: JuliaStats/StatsFuns.jl#97, later extracted to https://github.com/JuliaStats/LogExpFunctions.jl/blob/master/src/logsumexp.jl). This one-pass algorithm is faster (at least for the inputs I benchmarked) and also fixes some underflow issues of the standard implementation of log-sum-exp, which are present in Birch too. I therefore thought it might be useful for Birch and switched log_sum_exp and resample_reduce in the standard library to the one-pass algorithm.

I checked the implementation with the following Birch code: https://gist.github.com/devmotion/a6c3561f6c593160744147e7c5165f62

The output indicates that the PR fixes the underflow issue and improves performance:

❯ birch log_sum_exp_underflow -onepass false
f: log_sum_exp
f([1e-20, log(1e-20)]) = 1.00000000000000e-20 (correct: ~1.999999999999999999985e-20)
❯ birch log_sum_exp_underflow -onepass true
f: log_sum_exp_onepass
f([1e-20, log(1e-20)]) = 2.00000000000000e-20 (correct: ~1.999999999999999999985e-20)
❯ birch resample_reduce_underflow -onepass false
f: resample_reduce
f([1e-20, log(1e-20)]) = (1.0, 1.00000000000000e-20) (correct: (_, ~1.999999999999999999985e-20))
❯ birch resample_reduce_underflow -onepass true
f: resample_reduce_onepass
f([1e-20, log(1e-20)]) = (1.0, 2.00000000000000e-20) (correct: (_, ~1.999999999999999999985e-20))
❯ birch log_sum_exp_timings
current: 7.37634658271839e+00 (result), 1.28010000000000e-05 (time)
onepass: 7.37634658271839e+00 (result), 6.99000000000000e-06 (time)
❯ birch log_sum_exp_timings
current: 7.40357961126728e+00 (result), 1.23700000000000e-05 (time)
onepass: 7.40357961126728e+00 (result), 6.95700000000000e-06 (time)
❯ birch resample_reduce_timings
current: (3.40619836613369e+02, 7.41604189731166e+00) (result), 1.69120000000000e-05 (time)
onepass: (3.40619836613369e+02, 7.41604189731166e+00) (result), 7.15000000000000e-06 (time)
❯ birch resample_reduce_timings
current: (2.37522674833629e+02, 7.43909556079362e+00) (result), 1.29710000000000e-05 (time)
onepass: (2.37522674833630e+02, 7.43909556079362e+00) (result), 8.29000000000000e-06 (time)
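
The underflow in the first example can be reproduced outside Birch. The following is a minimal Python sketch, not the Birch code from this PR: both function names are hypothetical, and the sketch assumes finite, non-empty input. It shows that the textbook shift-by-max formula loses the small term because `1 + 1e-20` rounds to `1` in double precision, while routing the remaining terms through `log1p` keeps it.

```python
import math

def log_sum_exp_two_pass(xs):
    """Textbook version: shift by the max, sum every term, take the log."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_sum_exp_log1p(xs):
    """Same shift, but the max's own term (exactly 1) is left to log1p."""
    i = max(range(len(xs)), key=lambda k: xs[k])  # index of the maximum
    m = xs[i]
    rest = sum(math.exp(x - m) for k, x in enumerate(xs) if k != i)
    return m + math.log1p(rest)

xs = [1e-20, math.log(1e-20)]
print(log_sum_exp_two_pass(xs))  # 1e-20: 1 + 1e-20 rounds to 1, so the log is 0
print(log_sum_exp_log1p(xs))     # approximately 2e-20
```

The second function still needs two passes; it only isolates the numerical-stability part of the change.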

I guess it might be good to add the underflow example and possibly some other tests. Where should they be added?

lawmurray changed the base branch from master to numeric August 2, 2022 02:21
@lawmurray
Owner

Looks nice @devmotion, thanks for contributing this. I'm rebasing to the numeric branch, which has been long running, as I intend to merge it back to master soon. That will also run the full test suite, specifically so we can confirm that the log-normalizing constants for the various examples are consistent.

For tests, you can add them to libraries/StandardTest/src/basic on master, or tests/Test/src/basic on numeric. They should be consistent with other tests there, as in exit with a nonzero code on failure, in which case they can output some diagnostics, but otherwise stay silent.
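
The convention described above (silent on success, diagnostics plus a nonzero exit code on failure) might look roughly like this. It is a Python sketch for illustration only, not a Birch test: `log_sum_exp` here is a hypothetical stand-in for the function under test, and the cases are made up for the example.

```python
import math
import sys

def log_sum_exp(xs):
    """Hypothetical stand-in for the function under test."""
    m = max(xs, default=-math.inf)
    if math.isinf(m):
        return m  # empty input, all -inf, or an inf present
    return m + math.log(sum(math.exp(x - m) for x in xs))

def main():
    """Return 0 if every case passes, 1 otherwise; print only on failure."""
    cases = [
        ([0.0, 0.0], math.log(2.0)),
        ([-math.inf, -math.inf], -math.inf),
        ([], -math.inf),
    ]
    failed = False
    for xs, expected in cases:
        got = log_sum_exp(xs)
        if not (got == expected or math.isclose(got, expected, rel_tol=1e-12)):
            print(f"log_sum_exp({xs}) = {got}, expected {expected}", file=sys.stderr)
            failed = True
    return 1 if failed else 0
```

A standalone script would end with `sys.exit(main())` so that the test runner sees the exit code.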

Finally, is there a reference to explain the numerical stability aspect of this to a future reader? The blog post linked in the code comments has a nice explanation of the streaming part, but not the numerical stability part, which I think you may have added.

@lawmurray
Owner

lawmurray commented Aug 2, 2022

Normalizing constant tests are failing, e.g. test_z_beta in https://app.circleci.com/pipelines/github/lawmurray/Birch/1185/workflows/3df77928-2ab6-4577-837c-2cf02b77767f/jobs/18530

These tests are new on branch numeric, and don't exist on master, so this was not caught until now. They're all those beginning with test_z. These tests use importance sampling to estimate the normalizing constant of a probability distribution, and use resample_reduce() internally, so likely something is at fault there (possibly introduced by the rebase and merge, will need to investigate).

Note that it's the test-opensuse-bench job on CircleCI that runs the full test suite. On other platforms it's just the smoke tests.
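
To illustrate why the normalizing constant tests depend on log-sum-exp, here is a small Python sketch of an importance sampling estimate of a log-normalizing constant. It is not the Birch test code: the target (an unnormalized standard Gaussian), the proposal (a Gaussian with standard deviation 2), and all names are assumptions chosen for the example. The estimate is the log-sum-exp of the log-weights minus `log(n)`.

```python
import math
import random

def log_sum_exp(xs):
    """Shift-by-max log-sum-exp; assumes finite, non-empty input."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def estimate_log_z(n, seed=0):
    """Estimate log Z for the unnormalized density exp(-x^2 / 2)."""
    rng = random.Random(seed)
    sigma = 2.0  # proposal standard deviation (an assumption of this sketch)
    log_w = []
    for _ in range(n):
        x = rng.gauss(0.0, sigma)
        log_target = -0.5 * x * x
        log_proposal = -0.5 * (x / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))
        log_w.append(log_target - log_proposal)
    # log of the average weight: log-sum-exp minus log(n)
    return log_sum_exp(log_w) - math.log(n)

print(estimate_log_z(20000))             # estimate
print(0.5 * math.log(2.0 * math.pi))     # exact value, log(sqrt(2*pi))
```

Any error in the log-sum-exp reduction shifts the estimate directly, which is why such tests are sensitive to changes in that function.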

@devmotion
Contributor Author

> Looks nice @devmotion, thanks for contributing this. I'm rebasing to the numeric branch, which has been long running, as I intend to merge it back to master soon.

Thanks!

> For tests, you can add them to libraries/StandardTest/src/basic on master, or tests/Test/src/basic on numeric. They should be consistent with other tests there, as in exit with a nonzero code on failure, in which case they can output some diagnostics, but otherwise stay silent.

OK, I'll add tests when I'm back from vacation next week.

> Finally, is there a reference to explain the numerical stability aspect of this to a future reader? The blog post linked in the code comments has a nice explanation of the streaming part, but not the numerical stability part, which I think you may have added.

True, I added that in the Julia implementation but it was not part of the blog post. The idea is to only add the + 1 for the maximum element in the end but not during the reduction. I'll add some explanation.
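
The idea of deferring the `+ 1` can be sketched in Python as follows. This is an illustration under assumptions, not the Birch or Julia implementation: the function name is hypothetical, and the NaN and infinity handling is one possible choice. The running state is the current maximum `xmax` and `r`, the sum of `exp(x - xmax)` over everything seen so far except one copy of the maximum itself. That excluded term equals exactly 1 and is only added at the end, through `log1p`.

```python
import math

def log_sum_exp_onepass(xs):
    """Streaming log-sum-exp over any iterable, in a single pass."""
    xmax = -math.inf  # running maximum (the empty sum has log-sum-exp -inf)
    r = 0.0           # sum of exp(x - xmax), excluding one copy of the max
    for x in xs:
        if math.isnan(x):
            return math.nan  # propagate NaN explicitly
        if x == xmax:
            # Equal values, including equal infinities, where x - xmax
            # would evaluate to nan.
            r += 1.0
        elif x > xmax:
            # New maximum: the old maximum becomes an ordinary term, so its
            # deferred 1 is added now, then everything is rescaled.
            r = (r + 1.0) * math.exp(xmax - x)
            xmax = x
        else:
            r += math.exp(x - xmax)
    # The maximum's own term is added only here, inside log1p.
    return xmax + math.log1p(r)

print(log_sum_exp_onepass([1e-20, math.log(1e-20)]))  # approximately 2e-20
```

Every argument passed to `exp` is at most zero, so the accumulation cannot overflow, and small contributions survive because they are never added to 1 in floating point before the logarithm is taken.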

@devmotion
Contributor Author

> Normalizing constant tests are failing, e.g. test_z_beta

Oh, I'll check what's the problem there and try to fix it. I wonder if it is caused by the NaN values in z resulting from samples outside of the support of the beta distribution. Maybe one has to deal with NaN a bit more carefully in the reduction.
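
Whether NaN was the actual cause is not stated in this thread, but the general hazard is easy to show. The Python sketch below (hypothetical function names, not Birch code) relies on the fact that every ordered comparison with NaN is false, so a running maximum built on `>` silently skips NaN values unless they are checked for explicitly.

```python
import math

def running_max(xs):
    """Comparison-only running maximum: NaN never wins a comparison."""
    m = -math.inf
    for x in xs:
        if x > m:  # always False when x is NaN
            m = x
    return m

def running_max_nan_aware(xs):
    """Same reduction, but NaN is detected and propagated explicitly."""
    m = -math.inf
    for x in xs:
        if math.isnan(x):
            return math.nan
        if x > m:
            m = x
    return m

print(running_max([1.0, math.nan, 2.0]))            # 2.0, the NaN is ignored
print(running_max_nan_aware([1.0, math.nan, 2.0]))  # nan
```

Which behavior is wanted (propagating NaN, or treating it as a zero-weight sample) is a design decision for the reduction; the sketch only shows that it has to be made deliberately.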

@codecov

codecov bot commented Aug 9, 2022

Codecov Report

Merging #18 (e37aec5) into numeric (92269ac) will increase coverage by 0.03%.
The diff coverage is 88.23%.

@@             Coverage Diff             @@
##           numeric      #18      +/-   ##
===========================================
+ Coverage    81.24%   81.27%   +0.03%     
===========================================
  Files          446      447       +1     
  Lines        18272    18354      +82     
===========================================
+ Hits         14845    14918      +73     
- Misses        3427     3436       +9     
| Impacted Files | Coverage Δ |
| --- | --- |
| tests/Test/src/basic/test_basic_log_sum_exp.birch | 79.59% <79.59%> (ø) |
| libraries/Standard/src/primitive/resample.birch | 87.14% <100.00%> (+1.89%) ⬆️ |
| ...ard/src/distribution/MultinomialDistribution.birch | 61.81% <0.00%> (-3.64%) ⬇️ |
| numbirch/numbirch/array/ArrayShape.hpp | 98.60% <0.00%> (+0.04%) ⬆️ |
| numbirch/numbirch/array/Array.hpp | 94.98% <0.00%> (+0.22%) ⬆️ |
| birch/src/generate/CppGenerator.cpp | 94.85% <0.00%> (+0.36%) ⬆️ |
| birch/src/lexer.lpp | 87.68% <0.00%> (+0.72%) ⬆️ |

@lawmurray
Owner

Thanks for the updates on this @devmotion. I made a few minor tweaks for style (spaces around low precedence operators, rewrote a couple of nested ifs to look a little nicer). I'll be merging numeric back to master shortly, but please stay based on numeric for final changes and I can do another merge to master when you're done to pick up the final changes (e.g. some docs and tests mentioned above). Additional tests, by the way, are not essential, as those normalizing constant tests already cover the functions involved here.

@lawmurray
Owner

Also, once you merge updates on numeric into the PR, we should see the FreeBSD and Codecov tests passing. I merged your FreeBSD fix onto numeric and not just master as before, and tweaked the Codecov settings to ignore small changes in coverage (I think!).

@devmotion
Contributor Author

Sorry for the slow progress here, I was quite busy last week. I added some explanations and tests (revealed an incorrect result of resample_reduce for empty input vectors, it seems).
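
For readers unfamiliar with resample_reduce, the following Python sketch shows how an effective sample size (ESS) and a log-sum-exp can be accumulated in the same single pass. Several things here are assumptions rather than facts from this thread: that the first returned component is the ESS (the outputs earlier in the PR are consistent with that reading), the function name, and the choice to return `(0.0, -inf)` for empty input. The thread does not say what the correct empty-input result is.

```python
import math

def ess_and_log_sum_exp(log_w):
    """One pass over log-weights; returns (ESS, log-sum-exp)."""
    xmax = -math.inf
    r = 0.0  # sum of exp(x - xmax), excluding one copy of the max
    q = 0.0  # sum of exp(2 * (x - xmax)), with the same exclusion
    for x in log_w:
        if math.isnan(x):
            return (math.nan, math.nan)
        if x == xmax:
            r += 1.0
            q += 1.0
        elif x > xmax:
            e = math.exp(xmax - x)
            r = (r + 1.0) * e      # old max becomes an ordinary term
            q = (q + 1.0) * e * e
            xmax = x
        else:
            e = math.exp(x - xmax)
            r += e
            q += e * e
    if xmax == -math.inf:
        # Empty input or all weights zero: an assumption of this sketch.
        return (0.0, -math.inf)
    ess = (1.0 + r) ** 2 / (1.0 + q)  # (sum of w)^2 / (sum of w^2)
    return (ess, xmax + math.log1p(r))

print(ess_and_log_sum_exp([0.0, 0.0, 0.0, 0.0]))  # ESS 4.0, log-sum-exp log(4)
```

The common factor `exp(xmax)` cancels in the ESS ratio, so the ESS can be formed from the shifted sums without ever exponentiating the raw weights.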

Indeed, the if statements were a bit ugly and unsatisfying. The main idea was to perform as few checks as possible and e.g. check for inf only when the next element is larger but not if it is smaller than the current maximum. I guess it's not worth it though (I'm not sure if there are any performance gains at all), and I prefer the simpler and more readable version that you changed it to.

Owner

lawmurray left a comment

Thanks @devmotion. Looks solid, only issue is program name, see comments. Otherwise, I left a few comments mainly style related, but keen to keep a consistent style throughout the code.

* Test log-sum-exp implementations in `log_sum_exp` and
* `resample_reduce`.
*/
program test_logsumexp() {
Owner

Should be test_basic_logsumexp to be picked up and run by the grep in smoke.sh and test.sh. Or I'd even suggest test_basic_log_sum_exp here to be consistent with the name of the function.

Owner

Similar throughout, logsumexp -> log_sum_exp.

Contributor Author

devmotion commented Aug 16, 2022

Ah I had noticed this problem as well and thought I had pushed a commit that fixes the name. It seems I did not 😄

Contributor Author

Oh, actually I think I renamed it to test_basic_logsumexp in bc6d13a, so maybe you were looking at an older commit? In any case, I'll rename logsumexp to log_sum_exp.

Contributor Author

I renamed all occurrences of logsumexp (also in the filename) to log_sum_exp.

tests/Test/src/basic/test_basic_logsumexp.birch
}

// Special cases involving -inf, inf, and nan.
let cases <- [
Owner

Nice! Thanks for including all these edge cases.

@lawmurray
Owner

> Indeed, the if statements were a bit ugly and unsatisfying. The main idea was to perform as few checks as possible and e.g. check for inf only when the next element is larger but not if it is smaller than the current maximum. I guess it's not worth it though (I'm not sure if there are any performance gains at all), and I prefer the simpler and more readable version that you changed it to.

No worries. Just favoring the neater code, and can expect the compiler to optimize it out here.

lawmurray merged commit e6b5c8a into lawmurray:numeric Aug 21, 2022
@lawmurray
Owner

Thanks for all the revisions @devmotion. Looks good. I've merged back to numeric and then master now. Thanks again for all the recent contributions!

devmotion deleted the log_sum_exp branch August 21, 2022 05:37