Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dims argument to LogSumExpAtom #692

Merged
merged 5 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions src/atoms/LogSumExpAtom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,26 @@
# Use of this source code is governed by a BSD-style license that can be found
# in the LICENSE file or at https://opensource.org/license/bsd-2-clause

"""
LogSumExpAtom(x::AbstractExpr, dims::Int = 0)
odow marked this conversation as resolved.
Show resolved Hide resolved
odow marked this conversation as resolved.
Show resolved Hide resolved

Represents the expression `log.(sum(exp.(x); dims))`.
"""
mutable struct LogSumExpAtom <: AbstractExpr
children::Tuple{AbstractExpr}
size::Tuple{Int,Int}
dims::Union{Colon,Int}

function LogSumExpAtom(x::AbstractExpr)
function LogSumExpAtom(x::AbstractExpr, dims::Union{Colon,Int} = :)
@assert dims == Colon() || 1 <= dims <= 2
if sign(x) == ComplexSign()
error(
"[LogSumExpAtom] the argument should be real but it's instead complex",
)
end
return new((x,), (1, 1))
m = dims == 2 ? size(x, 1) : 1
n = dims == 1 ? size(x, 2) : 1
return new((x,), (m, n), dims)
end
end

Expand All @@ -28,22 +37,29 @@ curvature(::LogSumExpAtom) = ConvexVexity()
function evaluate(x::LogSumExpAtom)
_x = evaluate(x.children[1])
max_x = maximum(_x)
return max_x + log(sum(exp.(_x .- max_x)))
return max_x .+ log.(sum(exp.(_x .- max_x); x.dims))
odow marked this conversation as resolved.
Show resolved Hide resolved
end

logsumexp(x::AbstractExpr) = LogSumExpAtom(x)
logsumexp(x::AbstractExpr; dims = Colon()) = LogSumExpAtom(x, dims)

function new_conic_form!(context::Context, e::LogSumExpAtom)
# log(sum(exp(x))) <= t <=> sum(exp(x)) <= exp(t) <=> sum(exp(x - t)) <= 1
t = Variable()
z = sum(exp(e.children[1] - t * ones(size(e.children[1]))))
add_constraint!(context, 1 >= z)
x = only(e.children)
t = Variable(size(e))
y = if e.dims == 1 # t is a row-vector
ones(size(x, 1), 1) * t
elseif e.dims == 2 # t is a col-vector
t * ones(1, size(x, 2))
else
t * ones(size(x))
end
add_constraint!(context, 1 >= sum(exp(x - y); dims = e.dims))
return conic_form!(context, t)
end

function logisticloss(e::AbstractExpr)
if length(e) == 1
return logsumexp([e; 0])
end
return sum(logsumexp([e[i]; 0]) for i in 1:length(e))
return sum(logsumexp(hcat(vec(e), zeros(length(e))); dims = 2))
end
43 changes: 42 additions & 1 deletion src/problem_depot/problems/exp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,55 @@ end
) where {T,test}
y = Variable(5)
p = minimize(logsumexp(y), y >= 1; numeric_type = T)

if test
@test problem_vexity(p) == ConvexVexity()
end
handle_problem!(p)
if test
@test p.optval ≈ log(exp(1) * 5) atol = atol rtol = rtol
end

odow marked this conversation as resolved.
Show resolved Hide resolved
y = Variable(5, 2)
p = minimize(
sum(Convex.logsumexp(y; dims = 1)),
y[:, 1] >= 1,
y[:, 2] >= 2;
numeric_type = T,
)
handle_problem!(p)
if test
@test evaluate(y[:, 1]) ≈ ones(5) atol = atol rtol = rtol
@test evaluate(y[:, 2]) ≈ 2 * ones(5) atol = atol rtol = rtol
@test ≈(
p.optval,
log(exp(1) * 5) + log(exp(2) * 5);
atol = atol,
rtol = rtol,
)
end
p = minimize(logsumexp(y), y[:, 1] >= 1, y[:, 2] >= 2; numeric_type = T)
handle_problem!(p)
if test
@test evaluate(y[:, 1]) ≈ ones(5) atol = atol rtol = rtol
@test evaluate(y[:, 2]) ≈ 2 * ones(5) atol = atol rtol = rtol
@test p.optval ≈ log(exp(1) * 5 + exp(2) * 5) atol = atol rtol = rtol
end

x = Variable(2, 3)
v = Convex.logsumexp(x; dims = 1)
p = minimize(sum(v), x >= [1 2 3; 4 5 6]; numeric_type = T)
handle_problem!(p)
if test
@test evaluate(x) ≈ [1 2 3; 4 5 6] atol = atol rtol = rtol
@test ≈(
evaluate(v),
log.(sum(exp, evaluate(x); dims = 1));
atol = atol,
rtol = rtol,
)
@test vexity(v) == Convex.ConvexVexity()
end
return
end

@add_problem exp function exp_logistic_loss_atom(
Expand Down
50 changes: 46 additions & 4 deletions test/test_atoms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1115,14 +1115,13 @@ function test_LogSumExpAtom()
return logisticloss(Variable())
end
target = """
variables: x1, x1_, t, z1, z2, t_, z1_, z2_
variables: x1, x1_, t, t_, z1, z1_, z2, z2_
minobjective: 1.0 * t + 1.0 * t_
[1.0 + -1.0*z1 + -1.0*z2, 1.0 + -1.0*z1_ + -1.0*z2_] in Nonnegatives(2)
[1.0 * x1 + -1.0 * t, 1.0, 1.0 * z1] in ExponentialCone()
[-1.0 * t, 1.0, 1.0 * z2] in ExponentialCone()
[1.0 * x1_ + -1.0 * t_, 1.0, 1.0 * z1_] in ExponentialCone()
[-1.0 * t, 1.0, 1.0 * z2] in ExponentialCone()
[-1.0 * t_, 1.0, 1.0 * z2_] in ExponentialCone()
[1.0 + -1.0*z1 + -1.0*z2] in Nonnegatives(1)
[1.0 + -1.0*z1_ + -1.0*z2_] in Nonnegatives(1)
"""
_test_atom(target) do context
return logisticloss(Variable(2))
Expand All @@ -1137,6 +1136,49 @@ function test_LogSumExpAtom()
atom = logsumexp(x)
x.value = [1.0 1_000.0]
@test evaluate(atom) ≈ 1_000.0
x = Variable(2, 3)
x.value = [1 2 3; 4 5 6]
@test evaluate(logsumexp(x; dims = :)) ≈ 6.456193316018123
@test ≈(
evaluate(logsumexp(x; dims = 1)),
[4.04859 5.04859 6.04859],
atol = 1e-5,
)
@test ≈(
evaluate(logsumexp(x; dims = 2)),
[3.40760596444438, 6.407605964444381],
atol = 1e-5,
)
target = """
variables: x11, x12, x21, x22, t1, t2, y11, y12, y21, y22
minobjective: [1.0 * t1, 1.0 * t2]
[-1.0 + x11, -2 + x12, -3 + x21, -4 + x22] in Nonnegatives(4)
[1.0 + -1.0 * y11 + -1.0 * y21, 1.0 + -1.0 * y12 + -1.0 * y22] in Nonnegatives(2)
[1.0 * x11 + -1.0 * t1, 1.0, y11] in ExponentialCone()
[1.0 * x12 + -1.0 * t2, 1.0, y12] in ExponentialCone()
[1.0 * x21 + -1.0 * t1, 1.0, y21] in ExponentialCone()
[1.0 * x22 + -1.0 * t2, 1.0, y22] in ExponentialCone()
"""
_test_atom(target) do context
x = Variable(2, 2)
add_constraint!(context, x >= [1 3; 2 4])
return logsumexp(x; dims = 2)
end
target = """
variables: x11, x12, x21, x22, t1, t2, y11, y12, y21, y22
minobjective: [1.0 * t1, 1.0 * t2]
[-1.0 + x11, -2 + x12, -3 + x21, -4 + x22] in Nonnegatives(4)
[1.0 + -1.0 * y11 + -1.0 * y12, 1.0 + -1.0 * y21 + -1.0 * y22] in Nonnegatives(2)
[1.0 * x11 + -1.0 * t1, 1.0, y11] in ExponentialCone()
[1.0 * x12 + -1.0 * t1, 1.0, y12] in ExponentialCone()
[1.0 * x21 + -1.0 * t2, 1.0, y21] in ExponentialCone()
[1.0 * x22 + -1.0 * t2, 1.0, y22] in ExponentialCone()
"""
_test_atom(target) do context
x = Variable(2, 2)
add_constraint!(context, x >= [1 3; 2 4])
return logsumexp(x; dims = 1)
end
return
end

Expand Down
Loading