Skip to content

Commit

Permalink
Add onepass algorithm for logsumexp (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Sep 23, 2020
1 parent aa87839 commit 59989e1
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 22 deletions.
105 changes: 88 additions & 17 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ end
Return `log(exp(x) + exp(y))`, avoiding intermediate overflow/underflow, and handling non-finite values.
"""
function logaddexp(x::Real, y::Real)
    # For x == y the gap must be exactly zero; this matters for x = y = ±Inf, where the
    # raw difference x - y is NaN.
    diff = x - y
    gap = x == y ? zero(diff) : abs(diff)
    return max(x, y) + log1pexp(-gap)
end
Expand All @@ -224,28 +224,99 @@ logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y))
"""
logsumexp(X)
Compute `log(sum(exp, X))`, evaluated avoiding intermediate overflow/undeflow.
Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate over- and
underflow.
`X` should be an iterator of real numbers. The result is computed using a single pass over
the data.
# References
[Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
"""
logsumexp(X) = _logsumexp_onepass(X)

`X` should be an iterator of real numbers.
"""
function logsumexp(X)
logsumexp(X::AbstractArray{<:Real}; dims=:)
Compute `log.(sum(exp.(X); dims=dims))` in a numerically stable way that avoids
intermediate over- and underflow.
The result is computed using a single pass over the data.
# References
[Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
"""
logsumexp(X::AbstractArray{<:Real}; dims=:) = _logsumexp(X, dims)

# Reduction over the whole array (`dims = :`): defer to the generic one-pass routine.
function _logsumexp(X::AbstractArray{<:Real}, ::Colon)
    return _logsumexp_onepass(X)
end

# Reduction along `dims`: fold every slice into a `(maximum, remainder)` pair, then turn
# each pair into `maximum + log1p(remainder)`.
function _logsumexp(X::AbstractArray{<:Real}, dims)
    # Build the -Inf seed from the float element type instead of evaluating
    # log(zero(eltype(X))), which causes issues with ForwardDiff (#82).
    F = float(eltype(X))
    seed = (F(-Inf), zero(F))
    partials = reduce(_logsumexp_onepass_op, X; dims=dims, init=seed)
    return first.(partials) .+ log1p.(last.(partials))
end

# Compute log-sum-exp of the iterable `X` with exactly one traversal: reduce `X` to either
# a single element or a `(maximum, remainder)` pair and convert that into the final value.
function _logsumexp_onepass(X)
    # fallback for empty collections
    isempty(X) && return log(sum(X))

    # `X` must be traversed only once from here on: any additional reduction over `X`
    # would double the work and would exhaust stateful iterators before this one runs.
    state = _logsumexp_onepass_reduce(X, Base.IteratorEltype(X))
    return _logsumexp_onepass_result(state)
end
function logsumexp(X::AbstractArray{T}; dims=:) where {T<:Real}
# Do not use log(zero(T)) directly to avoid issues with ForwardDiff (#82)
u = reduce(max, X, dims=dims, init=oftype(log(zero(T)), -Inf))
u isa AbstractArray || isfinite(u) || return float(u)
let u=u # avoid https://github.com/JuliaLang/julia/issues/15276
# TODO: remove the branch when JuliaLang/julia#31020 is merged.
if u isa AbstractArray
u .+ log.(sum(exp.(X .- u); dims=dims))
else
u + log(sum(x -> exp(x-u), X))
end
end

# Function barrier that converts the outcome of the one-pass reduction into the final
# value. A bare element (reduction over a single element without an initial state) is
# only promoted to floating point.
function _logsumexp_onepass_result(x)
    return float(x)
end

# A `(maximum, remainder)` pair is collapsed to `maximum + log1p(remainder)`.
function _logsumexp_onepass_result(state::Tuple)
    xmax, r = state
    return xmax + log1p(r)
end

# Iterables that report an element type: when that type is concrete, seed the reduction
# with the `(-Inf, 0)` state in the matching floating point type.
function _logsumexp_onepass_reduce(X, ::Base.HasEltype)
    T = eltype(X)
    if isconcretetype(T)
        F = float(T)
        return reduce(_logsumexp_onepass_op, X; init=(F(-Inf), zero(F)))
    end
    # abstract element type: skip the type computations and reduce without a seed
    return _logsumexp_onepass_reduce(X, Base.EltypeUnknown())
end

# Iterables without a known element type: reduce without an initial state.
function _logsumexp_onepass_reduce(X, ::Base.EltypeUnknown)
    return reduce(_logsumexp_onepass_op, X)
end

## Reductions for one-pass algorithm: avoid expensive multiplications if numbers are reduced

# Combine two plain numbers into a `(maximum, remainder)` state, where the remainder is
# `exp(-|x1 - x2|)`.
function _logsumexp_onepass_op(x1, x2)
    gap = x1 - x2
    # equal inputs (including equal infinities, where `gap` is NaN) get an exponent of zero
    a = if x1 == x2
        zero(gap)
    else
        -abs(gap)
    end
    xmax = oftype(a, x1 > x2 ? x1 : x2)
    return xmax, exp(a)
end

# Fold a single number `x` into an existing `(maximum, remainder)` state.
function _logsumexp_onepass_op(x, (xmax, r)::Tuple)
    gap = x - xmax
    # equal values (including equal infinities, where `gap` is NaN) get an exponent of zero
    a = x == xmax ? zero(gap) : -abs(gap)
    if x > xmax
        # `x` becomes the new maximum: rescale everything accumulated so far
        return oftype(a, x), (r + one(r)) * exp(a)
    end
    # maximum unchanged: only add the contribution of `x`
    return oftype(a, xmax), r + exp(a)
end

# Same reduction with the arguments in the opposite order.
function _logsumexp_onepass_op(xmax_r::Tuple, x)
    return _logsumexp_onepass_op(x, xmax_r)
end

# Merge two `(maximum, remainder)` states into one. The state with the smaller maximum is
# rescaled by `exp(-|xmax1 - xmax2|)` before being added to the other.
function _logsumexp_onepass_op((xmax1, r1)::Tuple, (xmax2, r2)::Tuple)
    gap = xmax1 - xmax2
    # equal maxima (including equal infinities, where `gap` is NaN) get an exponent of zero
    a = xmax1 == xmax2 ? zero(gap) : -abs(gap)
    if xmax1 > xmax2
        return oftype(a, xmax1), r1 + (r2 + one(r2)) * exp(a)
    end
    return oftype(a, xmax2), r2 + (r1 + one(r1)) * exp(a)
end

"""
softmax!(r::AbstractArray, x::AbstractArray)
Expand Down
19 changes: 14 additions & 5 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,18 @@ end
@test logaddexp(2.0, 3.0) log(exp(2.0) + exp(3.0))
@test logaddexp(10002, 10003) 10000 + logaddexp(2.0, 3.0)

@test logsumexp([1.0, 2.0, 3.0]) 3.40760596444438
@test logsumexp((1.0, 2.0, 3.0)) 3.40760596444438
@test @inferred(logsumexp([1.0])) == 1.0
@test @inferred(logsumexp((x for x in [1.0]))) == 1.0
@test @inferred(logsumexp([1.0, 2.0, 3.0])) 3.40760596444438
@test @inferred(logsumexp((1.0, 2.0, 3.0))) 3.40760596444438
@test logsumexp([1.0, 2.0, 3.0] .+ 1000.) 1003.40760596444438

@test logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=1) [3.40760596444438 1003.40760596444438]
@test logsumexp([[1.0 2.0 3.0]; [1.0 2.0 3.0] .+ 1000.]; dims=2) [3.40760596444438, 1003.40760596444438]
@test logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=[1,2]) [1003.4076059644444]
@test @inferred(logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=1)) [3.40760596444438 1003.40760596444438]
@test @inferred(logsumexp([[1.0 2.0 3.0]; [1.0 2.0 3.0] .+ 1000.]; dims=2)) [3.40760596444438, 1003.40760596444438]
@test @inferred(logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=[1,2])) [1003.4076059644444]

# check underflow
@test logsumexp([1e-20, log(1e-20)]) 2e-20

let cases = [([-Inf, -Inf], -Inf), # correct handling of all -Inf
([-Inf, -Inf32], -Inf), # promotion
Expand Down Expand Up @@ -137,6 +142,10 @@ end
@test isnan(logsumexp([NaN, 9.0]))
@test isnan(logsumexp([NaN, Inf]))
@test isnan(logsumexp([NaN, -Inf]))

# logsumexp with general iterables (issue #63)
xs = range(-500, stop = 10, length = 1000)
@test @inferred(logsumexp(x for x in xs)) == logsumexp(xs)
end

@testset "softmax" begin
Expand Down

0 comments on commit 59989e1

Please sign in to comment.