Skip to content

Commit

Permalink
Write logsumexp in a more generic form (#45)
Browse files Browse the repository at this point in the history
This uses Julia's reduction functionality to improve the `logsumexp`
functionality. It no longer relies on 1-based indexing, is fast
on CuArrays, and there is now a generic fallback for iterable collections.

I've also renamed the 2-arg `logsumexp` to `logaddexp`, since Julia
convention is to use different functions for reductions and
reducers (e.g. `max` vs `maximum`).
  • Loading branch information
simonbyrne authored and andreasnoack committed Jun 26, 2018
1 parent e66dd97 commit ae40bc1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 20 deletions.
3 changes: 2 additions & 1 deletion src/StatsFuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ export
invsoftplus, # alias of logexpm1
log1pmx, # log(1 + x) - x
logmxp1, # log(x) - x + 1
logsumexp, # log(exp(x) + exp(y)) or log(sum(exp(x)))
logaddexp, # log(exp(x) + exp(y))
logsumexp, # log(sum(exp(x)))
softmax, # exp(x_i) / sum(exp(x)), for i
softmax!, # inplace softmax

Expand Down
40 changes: 24 additions & 16 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,34 +164,42 @@ end


"""
logsumexp(x::Real, y::Real)
logaddexp(x::Real, y::Real)
Return `log(exp(x) + exp(y))`, avoiding intermediate overflow/undeflow.
Return `log(exp(x) + exp(y))`, avoiding intermediate overflow/undeflow, and handling non-finite values.
"""
function logsumexp(x::T, y::T) where T<:Real
x == y && abs(x) == Inf && return x
function logaddexp(x::T, y::T) where T<:Real
# x or y is NaN => NaN
# x or y is +Inf => +Inf
# x or y is -Inf => other value
isfinite(x) && isfinite(y) || return max(x,y)
x > y ? x + log1p(exp(y - x)) : y + log1p(exp(x - y))
end
logaddexp(x::Real, y::Real) = logaddexp(promote(x, y)...)

logsumexp(x::Real, y::Real) = logsumexp(promote(x, y)...)
Base.@deprecate logsumexp(x::Real, y::Real) logaddexp(x,y)

"""
logsumexp(x::AbstractArray{T}) where T<:Real
logsumexp(X)
Return `log(sum(exp, x))`, evaluated avoiding intermediate overflow/undeflow.
Compute `log(sum(exp, X))`, evaluated avoiding intermediate overflow/undeflow.
`X` should be an iterator of real numbers.
"""
function logsumexp(x::AbstractArray{T}) where T<:Real
S = typeof(exp(zero(T))) # because of 0.4.0
isempty(x) && return -S(Inf)
u = maximum(x)
abs(u) == Inf && return any(isnan, x) ? S(NaN) : u
s = zero(S)
for i = 1:length(x)
@inbounds s += exp(x[i] - u)
function logsumexp(X)
isempty(X) && return log(sum(X))
reduce(logaddexp, X)
end
function logsumexp(X::AbstractArray{T}) where {T<:Real}
isempty(X) && return log(zero(T))
u = maximum(X)
isfinite(u) || return float(u)
let u=u # avoid https://github.com/JuliaLang/julia/issues/15276
u + log(sum(x -> exp(x-u), X))
end
log(s) + u
end


"""
softmax!(r::AbstractArray, x::AbstractArray)
Expand Down
7 changes: 4 additions & 3 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ end
end

@testset "logsumexp" begin
@test logsumexp(2.0, 3.0) log(exp(2.0) + exp(3.0))
@test logsumexp(10002, 10003) 10000 + logsumexp(2.0, 3.0)
@test logaddexp(2.0, 3.0) log(exp(2.0) + exp(3.0))
@test logaddexp(10002, 10003) 10000 + logaddexp(2.0, 3.0)

@test logsumexp([1.0, 2.0, 3.0]) 3.40760596444438
@test logsumexp((1.0, 2.0, 3.0)) 3.40760596444438
@test logsumexp([1.0, 2.0, 3.0] .+ 1000.) 1003.40760596444438

let cases = [([-Inf, -Inf], -Inf), # correct handling of all -Inf
Expand All @@ -76,8 +77,8 @@ end
([NaN, -Inf], NaN), # NaN propagation
([0, 0], log(2.0))] # non-float arguments
for (arguments, result) in cases
@test logaddexp(arguments...) result
@test logsumexp(arguments) result
@test logsumexp(arguments...) result
end
end
end
Expand Down

0 comments on commit ae40bc1

Please sign in to comment.