Skip to content

Commit

Permalink
Add loglogistic, logitexp, log1mlogistic and logit1mexp
Browse files Browse the repository at this point in the history
This takes advantage of `LogExpFunctions`'s accurate
implementations of `log1pexp` and `log1mexp`, combined with negation
in the log-odds domain to provide more accurate and less expensive
implementations of the function compositions.
  • Loading branch information
andrewjradcliffe committed May 3, 2024
1 parent c81ab8f commit 8fa3801
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/LogExpFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import LinearAlgebra

export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax,
softmax!, logcosh, logabssinh, cloglog, cexpexp
softmax!, logcosh, logabssinh, cloglog, cexpexp,
loglogistic, logitexp, log1mlogistic, logit1mexp

# expm1(::Float16) is not defined in older Julia versions,
# hence for better Float16 support we use an internal function instead
Expand Down
69 changes: 69 additions & 0 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,72 @@ $(SIGNATURES)
Compute the complementary double exponential, `1 - exp(-exp(x))`.
"""
cexpexp(x) = -_expm1(-exp(x))

#=
this uses the identity:
log(logistic(x)) = -log(1 + exp(-x))
=#
"""
loglogistic(x)
The natural logarithm of the `logistic` function, computed more
carefully and with fewer calls than than the composition
`log(logistic(x))`.
Its inverse is the [`logitexp`](@ref) function.
"""
loglogistic(x::AbstractFloat) = -log1pexp(-x) #
loglogistic(x::T) where {T<:Real} = -log1pexp(-convert(promote_type(Float64, T), x))

#=
this uses the identity:
logit(exp(x)) = log(exp(x) / (1 + exp(x))) = log(exp(x)) - log(1 - exp(x))
=#
"""
logitexp(x)
The logit of the exponential of `x`, computed more carefully and
with fewer function calls than `logit(exp(x))`
Its inverse is the [`loglogistic`](@ref) function.
"""
logitexp(x::Real) = x - log1mexp(x)

#=
this uses the identity:
log(logistic(-x)) = -log(1 + exp(x))
that is, negation in the log-odds domain.
=#

"""
log1mlogistic(x)
The natural logarithm of the 1 minus the inverse logit function,
computed more carefully and with fewer function calls than `log(1 -
logistic(x))`.
Its inverse is the [`logit1mexp`](@ref) function.
"""
log1mlogistic(x::Real) = -log1pexp(x)

#=
this uses the same identity as `logitexp`, followed by negation on the
log-odds scale, i.e. -logit(exp(x)) = log(1 - exp(x)) - log(exp(x))
=#

"""
logit1mexp(x)
The logit of 1 minus the exponential of `x`, computed more carefully
and with fewer function calls than `logit(1 - exp(x))`.
Its inverse is the [`log1mlogistic`](@ref) function.
"""
logit1mexp(x::Real) = log1mexp(x) - x
72 changes: 72 additions & 0 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,75 @@ end
@test cexpexp(-Inf) == 0.0
@test cexpexp(0) == (ℯ - 1) /
end

@testset "loglogistic: $T" for T in (Float16, Float32, Float64)
lim1 = T === Float16 ? -14.0 : -50.0
lim2 = T === Float16 ? -10.0 : -37.0
xs = T[Inf, -Inf, 0.0, lim1, lim2]
for x in xs
@test loglogistic(x) == log(logistic(x))
end

ϵ = eps(T)
xs = T[ϵ, 1.0, 18.0, 33.3, 50.0]
for x in xs
lhs = loglogistic(x)
rhs = log(logistic(x))
@test abs(lhs - rhs) < ϵ
end

# misc
@test loglogistic(T(Inf)) == -zero(T)
@test loglogistic(-T(Inf)) == -T(Inf)
@test loglogistic(-T(103.0)) == -T(103.0)
@test abs(loglogistic(T(35.0))) < 3eps(T)
@test abs(loglogistic(T(103.0))) < eps(T)
@test isfinite(loglogistic(-T(745.0)))
@test isfinite(loglogistic(T(50.0)))
@test isfinite(loglogistic(T(745.0)))
end


@testset "logitexp: $T" for T in (Float16, Float32, Float64)
ϵ = eps(T)
xs = T[ϵ, ϵ, 0.2, 0.4, 0.8, 1.0 - ϵ, 1.0 - ϵ]
neg_xs = -xs
for x in xs
@test abs(logitexp(loglogistic(x)) - x) < ϵ
end
for x in neg_xs
@test abs(logitexp(loglogistic(x)) - x) < 2ϵ
end
xs = T[-Inf, 0.0, Inf]
for x in xs
@test logitexp(loglogistic(x)) == x
end
end

@testset "log1mlogistic: $T" for T in (Float16, Float32, Float64)
@test log1mlogistic(T(Inf)) == -T(Inf)
@test log1mlogistic(-T(Inf)) == -zero(T)
@test log1mlogistic(-T(103.0)) < eps(T)
@test abs(log1mlogistic(T(35.0))) == T(35.0)
@test abs(log1mlogistic(T(103.0))) == T(103.0)
@test isfinite(log1mlogistic(-T(745.0)))
@test isfinite(log1mlogistic(T(50.0)))
@test isfinite(log1mlogistic(T(745.0)))
end


@testset "logit1mexp: $T" for T in (Float16, Float32, Float64)
ϵ = eps(T)
xs = T[ϵ, ϵ, 0.2, 0.4, 0.8, 1.0 - ϵ, 1.0 - ϵ]
neg_xs = -xs
for x in xs
@test abs(logit1mexp(log1mlogistic(x)) - x) < 2ϵ
end
for x in neg_xs
@test abs(logit1mexp(log1mlogistic(x)) - x) < ϵ
end
xs = T[-Inf, 0.0, Inf]
for x in xs
@test logit1mexp(log1mlogistic(x)) == x
end
end

0 comments on commit 8fa3801

Please sign in to comment.