Skip to content

Commit

Permalink
Add dims argument to LogSumExpAtom (#692)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed May 27, 2024
1 parent 2d863c7 commit 459f86b
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 13 deletions.
32 changes: 24 additions & 8 deletions src/atoms/LogSumExpAtom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,26 @@
# Use of this source code is governed by a BSD-style license that can be found
# in the LICENSE file or at https://opensource.org/license/bsd-2-clause

"""
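    LogSumExpAtom(x::AbstractExpr, dims::Union{Colon,Int} = :)

Represents the expression `log.(sum(exp.(x); dims))`.

With `dims = :` (the default) the atom has size `(1, 1)`. With `dims = 1` it has
size `(1, size(x, 2))`, and with `dims = 2` it has size `(size(x, 1), 1)`.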
"""
mutable struct LogSumExpAtom <: AbstractExpr
    # The single expression `x` that the atom is applied to.
    children::Tuple{AbstractExpr}
    # Size of the atom itself (not of `x`); fixed by `dims` in the constructor.
    size::Tuple{Int,Int}
    # `:` reduces over every entry of `x`; `1` or `2` reduces along that
    # dimension only.
    dims::Union{Colon,Int}

    # Build the atom for `x`. Asserts that an integer `dims` is `1` or `2`, and
    # calls `error` when `sign(x)` is `ComplexSign()`. `dims` defaults to `:`,
    # so the one-argument call `LogSumExpAtom(x)` keeps working.
    function LogSumExpAtom(x::AbstractExpr, dims::Union{Colon,Int} = :)
        @assert dims == Colon() || 1 <= dims <= 2
        if sign(x) == ComplexSign()
            error(
                "[LogSumExpAtom] the argument should be real but it's instead complex",
            )
        end
        # Only the dimension that is not reduced keeps its length; with
        # `dims = :` both collapse to 1.
        m = dims == 2 ? size(x, 1) : 1
        n = dims == 1 ? size(x, 2) : 1
        return new((x,), (m, n), dims)
    end
end

Expand All @@ -28,22 +37,29 @@ curvature(::LogSumExpAtom) = ConvexVexity()
# Numerically evaluate `log.(sum(exp.(v); dims = x.dims))`, where `v` is the
# evaluated child of `x`.
#
# The exponentials are shifted by a maximum before summing so that `exp` cannot
# overflow. The shift has to be the maximum of each slice being reduced: with a
# single global maximum, a slice whose entries are all far below it underflows
# to `log(0) = -Inf` (for example `[0 1000]` with `dims = 1`).
function evaluate(x::LogSumExpAtom)
    _x = evaluate(x.children[1])
    if x.dims isa Colon
        # Full reduction: one shift for everything. The `dims` keyword is not
        # passed here so this branch does not depend on `_x` being an array.
        # NOTE(review): presumably a 1x1 child can evaluate to a plain number;
        # confirm against `evaluate` for variables.
        max_x = maximum(_x)
        return max_x + log(sum(exp.(_x .- max_x)))
    end
    # Per-slice shift; `max_x` has the same shape as the reduced result and
    # broadcasts against `_x`.
    max_x = maximum(_x; dims = x.dims)
    return max_x .+ log.(sum(exp.(_x .- max_x); dims = x.dims))
end

# `logsumexp(x; dims = :)` builds the atom representing
# `log.(sum(exp.(x); dims))`. `dims` is forwarded unchanged to `LogSumExpAtom`,
# which accepts `:`, `1` or `2`. Calling `logsumexp(x)` without the keyword
# reduces over every entry, so a separate one-argument method is not needed
# (it would share this method's positional signature and be overwritten).
logsumexp(x::AbstractExpr; dims = Colon()) = LogSumExpAtom(x, dims)

# Add the conic reformulation of `e` to `context` and return the conic form of
# its epigraph variable `t`.
function new_conic_form!(context::Context, e::LogSumExpAtom)
    # log(sum(exp(x))) <= t <=> sum(exp(x)) <= exp(t) <=> sum(exp(x - t)) <= 1
    x = only(e.children)
    # One epigraph variable per entry of the atom (a scalar when `dims` is `:`).
    t = Variable(size(e))
    # Expand `t` to the shape of `x`, so that every entry of `x` is paired with
    # the epigraph variable of the slice it belongs to.
    y = if e.dims == 1 # t is a row-vector
        ones(size(x, 1), 1) * t
    elseif e.dims == 2 # t is a col-vector
        t * ones(1, size(x, 2))
    else
        t * ones(size(x))
    end
    add_constraint!(context, 1 >= sum(exp(x - y); dims = e.dims))
    return conic_form!(context, t)
end

# Logistic loss of `e`: the sum over the entries `e[i]` of
# `log(1 + exp(e[i]))`, written as `logsumexp` of the pair `(e[i], 0)`.
#
# Each entry is placed in its own row next to a zero, and the `length(e) x 2`
# matrix is reduced along `dims = 2`. This uses one `LogSumExpAtom` for the
# whole expression instead of one atom per entry, and also covers the
# single-element case.
function logisticloss(e::AbstractExpr)
    return sum(logsumexp(hcat(vec(e), zeros(length(e))); dims = 2))
end
42 changes: 41 additions & 1 deletion src/problem_depot/problems/exp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,54 @@ end
) where {T,test}
y = Variable(5)
p = minimize(logsumexp(y), y >= 1; numeric_type = T)

if test
@test problem_vexity(p) == ConvexVexity()
end
handle_problem!(p)
if test
@test p.optval ≈ log(exp(1) * 5) atol = atol rtol = rtol
end
y = Variable(5, 2)
p = minimize(
sum(Convex.logsumexp(y; dims = 1)),
y[:, 1] >= 1,
y[:, 2] >= 2;
numeric_type = T,
)
handle_problem!(p)
if test
@test evaluate(y[:, 1]) ≈ ones(5) atol = atol rtol = rtol
@test evaluate(y[:, 2]) ≈ 2 * ones(5) atol = atol rtol = rtol
@test ≈(
p.optval,
log(exp(1) * 5) + log(exp(2) * 5);
atol = atol,
rtol = rtol,
)
end
p = minimize(logsumexp(y), y[:, 1] >= 1, y[:, 2] >= 2; numeric_type = T)
handle_problem!(p)
if test
@test evaluate(y[:, 1]) ≈ ones(5) atol = atol rtol = rtol
@test evaluate(y[:, 2]) ≈ 2 * ones(5) atol = atol rtol = rtol
@test p.optval ≈ log(exp(1) * 5 + exp(2) * 5) atol = atol rtol = rtol
end

x = Variable(2, 3)
v = Convex.logsumexp(x; dims = 1)
p = minimize(sum(v), x >= [1 2 3; 4 5 6]; numeric_type = T)
handle_problem!(p)
if test
@test evaluate(x) ≈ [1 2 3; 4 5 6] atol = atol rtol = rtol
@test ≈(
evaluate(v),
log.(sum(exp, evaluate(x); dims = 1));
atol = atol,
rtol = rtol,
)
@test vexity(v) == Convex.ConvexVexity()
end
return
end

@add_problem exp function exp_logistic_loss_atom(
Expand Down
50 changes: 46 additions & 4 deletions test/test_atoms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1115,14 +1115,13 @@ function test_LogSumExpAtom()
return logisticloss(Variable())
end
target = """
variables: x1, x1_, t, z1, z2, t_, z1_, z2_
variables: x1, x1_, t, t_, z1, z1_, z2, z2_
minobjective: 1.0 * t + 1.0 * t_
[1.0 + -1.0*z1 + -1.0*z2, 1.0 + -1.0*z1_ + -1.0*z2_] in Nonnegatives(2)
[1.0 * x1 + -1.0 * t, 1.0, 1.0 * z1] in ExponentialCone()
[-1.0 * t, 1.0, 1.0 * z2] in ExponentialCone()
[1.0 * x1_ + -1.0 * t_, 1.0, 1.0 * z1_] in ExponentialCone()
[-1.0 * t, 1.0, 1.0 * z2] in ExponentialCone()
[-1.0 * t_, 1.0, 1.0 * z2_] in ExponentialCone()
[1.0 + -1.0*z1 + -1.0*z2] in Nonnegatives(1)
[1.0 + -1.0*z1_ + -1.0*z2_] in Nonnegatives(1)
"""
_test_atom(target) do context
return logisticloss(Variable(2))
Expand All @@ -1137,6 +1136,49 @@ function test_LogSumExpAtom()
atom = logsumexp(x)
x.value = [1.0 1_000.0]
@test evaluate(atom) ≈ 1_000.0
x = Variable(2, 3)
x.value = [1 2 3; 4 5 6]
@test evaluate(logsumexp(x; dims = :)) ≈ 6.456193316018123
@test ≈(
evaluate(logsumexp(x; dims = 1)),
[4.04859 5.04859 6.04859],
atol = 1e-5,
)
@test ≈(
evaluate(logsumexp(x; dims = 2)),
[3.40760596444438, 6.407605964444381],
atol = 1e-5,
)
target = """
variables: x11, x12, x21, x22, t1, t2, y11, y12, y21, y22
minobjective: [1.0 * t1, 1.0 * t2]
[-1.0 + x11, -2 + x12, -3 + x21, -4 + x22] in Nonnegatives(4)
[1.0 + -1.0 * y11 + -1.0 * y21, 1.0 + -1.0 * y12 + -1.0 * y22] in Nonnegatives(2)
[1.0 * x11 + -1.0 * t1, 1.0, y11] in ExponentialCone()
[1.0 * x12 + -1.0 * t2, 1.0, y12] in ExponentialCone()
[1.0 * x21 + -1.0 * t1, 1.0, y21] in ExponentialCone()
[1.0 * x22 + -1.0 * t2, 1.0, y22] in ExponentialCone()
"""
_test_atom(target) do context
x = Variable(2, 2)
add_constraint!(context, x >= [1 3; 2 4])
return logsumexp(x; dims = 2)
end
target = """
variables: x11, x12, x21, x22, t1, t2, y11, y12, y21, y22
minobjective: [1.0 * t1, 1.0 * t2]
[-1.0 + x11, -2 + x12, -3 + x21, -4 + x22] in Nonnegatives(4)
[1.0 + -1.0 * y11 + -1.0 * y12, 1.0 + -1.0 * y21 + -1.0 * y22] in Nonnegatives(2)
[1.0 * x11 + -1.0 * t1, 1.0, y11] in ExponentialCone()
[1.0 * x12 + -1.0 * t1, 1.0, y12] in ExponentialCone()
[1.0 * x21 + -1.0 * t2, 1.0, y21] in ExponentialCone()
[1.0 * x22 + -1.0 * t2, 1.0, y22] in ExponentialCone()
"""
_test_atom(target) do context
x = Variable(2, 2)
add_constraint!(context, x >= [1 3; 2 4])
return logsumexp(x; dims = 1)
end
return
end

Expand Down

0 comments on commit 459f86b

Please sign in to comment.