Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VcatAtom #607

Merged
merged 4 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 11 additions & 48 deletions src/atoms/affine/HcatAtom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,41 +29,15 @@ monotonicity(x::HcatAtom) = ntuple(_ -> Nondecreasing(), length(x.children))

curvature(::HcatAtom) = ConstVexity()

evaluate(x::HcatAtom) = hcat(map(evaluate, x.children)...)
evaluate(x::HcatAtom) = reduce(hcat, collect(map(evaluate, x.children)))

function new_conic_form!(context::Context{T}, x::HcatAtom) where {T}
objectives = map(c -> conic_form!(context, c), AbstractTrees.children(x))
# Suppose the child objectives for two children e1 (2 x 1) and e2 (2 x 2)
# look something like
# e1: x => 1 2 3
# 4 5 6
# y => 2 4
# 7 8
# e2: x => 1 1 1
# 2 2 2
# 3 3 3
# 4 4 4
# The objective of [e1 e2] will look like
# x => 1 2 3
# 4 5 6
# 1 1 1
# 2 2 2
# 3 3 3
# 4 4 4
# y => 2 4
# 7 8
# 0 0
# 0 0
# 0 0
# 0 0
# builds the objective by aggregating a list of coefficients for each
# variable from each child objective, and then vertically concatenating them
return operate(vcat, T, sign(x), objectives...)
args = map(c -> conic_form!(context, c), AbstractTrees.children(x))
# MOI represents matrices by concatenating their columns, so even though
# this is an HcatAtom, we built the conic form by vcat'ing the arguments.
return operate(vcat, T, sign(x), args...)
end
# TODO: fix piracy!

# * `Value` is not owned by Convex.jl
# * splatting creates zero-argument functions, which again are not owned by Convex.jl
Base.hcat(args::AbstractExpr...) = HcatAtom(args...)

function Base.hcat(args::Union{AbstractExpr,Value}...)
Expand All @@ -73,26 +47,15 @@ function Base.hcat(args::Union{AbstractExpr,Value}...)
return HcatAtom(args...)
end

# TODO: implement vertical concatenation in a more efficient way
Base.vcat(args::AbstractExpr...) = transpose(HcatAtom(map(transpose, args)...))

function Base.vcat(args::Union{AbstractExpr,Value}...)
if all(Base.Fix2(isa, Value), args)
return Base.cat(args..., dims = Val(1))
end
return transpose(HcatAtom(map(transpose, args)...))
end

function Base.hvcat(
rows::Tuple{Vararg{Int}},
args::Union{AbstractExpr,Value}...,
)
nbr = length(rows)
rs = Vector{Any}(undef, nbr)
a = 1
for i in 1:nbr
rs[i] = HcatAtom(args[a:a-1+rows[i]]...)
a += rows[i]
output_rows = Vector{HcatAtom}(undef, length(rows))
offset = 0
for (i, n) in enumerate(rows)
output_rows[i] = HcatAtom(args[offset.+(1:n)]...)
offset += n
end
return vcat(rs...)
return vcat(output_rows...)
end
61 changes: 61 additions & 0 deletions src/atoms/affine/VcatAtom.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2014: Madeleine Udell and contributors
#
# Use of this source code is governed by a BSD-style license that can be found
# in the LICENSE file or at https://opensource.org/license/bsd-2-clause

mutable struct VcatAtom <: AbstractExpr
children::Tuple
size::Tuple{Int,Int}

function VcatAtom(args...)
args = convert.(AbstractExpr, args)
num_rows, num_cols = 0, args[1].size[2]
for arg in args
if arg.size[2] != num_cols
msg = "[VcatAtom] cannot stack expressions of incompatible size. Got $(arg.size[2]) expected $num_cols."
throw(DimensionMismatch(msg))
end
num_rows += arg.size[1]
end
return new(args, (num_rows, num_cols))
end
end

head(io::IO, ::VcatAtom) = print(io, "vcat")

Base.sign(x::VcatAtom) = sum(map(sign, x.children))

monotonicity(x::VcatAtom) = ntuple(_ -> Nondecreasing(), length(x.children))

curvature(::VcatAtom) = ConstVexity()

evaluate(x::VcatAtom) = reduce(vcat, collect(map(evaluate, x.children)))

function new_conic_form!(context::Context{T}, x::VcatAtom) where {T}
# Converting a VcatAtom to conic form is non-trivial. Consider two matrices:
# x = [1 3; 2 4]
# y = [5 7; 6 8]
# with VcatAtom(x, y). The desired outcome is [1, 2, 5, 6, 3, 4, 7, 8].
# If we naively convert the children to conic form and then vcat, we will
# get:
# vcat([1, 2, 3, 4], [5, 6, 7, 8]) = [1, 2, 3, 4, 5, 6, 7, 8]
# which is not what we are after. We need to first transpose each child to
# get:
# x^T, y^T = [1 2; 3 4], [5 6; 7 8])
# then hcat them to get:
# hcat(x^T, y^T) = [1 2 5 6; 3 4 7 8]
# then transpose this to get:
# hcat(x^T, y^T)^T = [1 3; 2 4; 5 7; 6 8]
# so our final conic form produces the desired
# [1, 2, 5, 6, 3, 4, 7, 8]
blegat marked this conversation as resolved.
Show resolved Hide resolved
return conic_form!(context, transpose(reduce(hcat, transpose.(x.children))))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is not nice, but I don't know an alternative.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One option is to manually build the operator we want to do:

return reshape(P * vec(x), size(x)...)

Copy link
Collaborator

@ericphanson ericphanson Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, I think that permutedims_matrix function already does what we need:

julia> x = [1 3; 2 4]
2×2 Matrix{Int64}:
 1  3
 2  4

julia> y = [5 7; 6 8]
2×2 Matrix{Int64}:
 5  7
 6  8

julia> M = permutedims_matrix((size(x, 1),size(x,2),size(y,1)), (1,3,2))
8×8 SparseMatrixCSC{Bool, Int64} with 8 stored entries:
 1              
   1            
         1      
           1    
     1          
       1        
             1  
               1

julia> reshape(M * vcat(vec(x),vec(y)), size(x,1) + size(y,1), size(x,2))
4×2 Matrix{Int64}:
 1  3
 2  4
 5  7
 6  8

Here I'm using the fact that permutedims_matrix is actually the matrix implementation of X -> vec(permutedims(reshape(X, dims), p)) where here dims = (size(x, 1),size(x,2),size(y,1) which corresponds to concatenating x and y in a new 3rd dimension, then (1,3,2) does the transposing business to swap the last 2 dimensions. We end up with a vector of course, but I reshape it to the intended output dimensions to show we got it right.

To actually operate on the vectorized level in Convex IIIC we'd need to do something like z = operate(vcat, x, y), then generate M and apply it on the vectorized level with operate(*, M, z), I think. We don't need to bother with the final reshaping since the dimensions are stored on the AbstractExpr level not the vectorized level.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've kept as-is for now. It wasn't obvious how to generalize this to the n-ary case, and what we currently have works.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can try this refactoring in a separate PR

end

Base.vcat(args::AbstractExpr...) = VcatAtom(args...)

function Base.vcat(args::Union{AbstractExpr,Value}...)
if all(Base.Fix2(isa, Value), args)
return Base.cat(args..., dims = Val(1))

Check warning on line 58 in src/atoms/affine/VcatAtom.jl

View check run for this annotation

Codecov / codecov/patch

src/atoms/affine/VcatAtom.jl#L58

Added line #L58 was not covered by tests
end
return VcatAtom(args...)
end
78 changes: 62 additions & 16 deletions test/test_atoms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -379,22 +379,14 @@ function test_HcatAtom()
x = Variable()
return hcat(x, x)
end
_test_atom(target) do context
x = Variable()
return vcat(x, x)
end
target = """
variables: x1, x2
minobjective: [1.0 * x1, 1.0 * x2, 2.0]
"""
_test_atom(target) do context
x = Variable(2)
x = Variable(1, 2)
y = constant(2)
return vcat(x, y)
end
_test_atom(target) do context
x = Variable(2)
return vcat(x, 2)
return hcat(x, y)
end
_test_atom(target) do context
x = Variable(1, 2)
Expand All @@ -406,12 +398,6 @@ function test_HcatAtom()
),
hcat(Variable(2), constant(2)),
)
@test_throws(
DimensionMismatch(
"[HcatAtom] cannot stack expressions of incompatible size. Got 2 expected 1.",
),
vcat(Variable(2, 1), Variable(1, 2)),
)
return
end

Expand Down Expand Up @@ -731,6 +717,66 @@ function test_SumAtom()
return
end

### affine/VcatAtom

function test_VcatAtom()
target = """
variables: x
minobjective: [1.0 * x, 1.0 * x]
"""
_test_atom(target) do context
x = Variable()
return vcat(x, x)
end
target = """
variables: x1, x2
minobjective: [1.0 * x1, 1.0 * x2, 2.0]
"""
_test_atom(target) do context
x = Variable(2)
y = constant(2)
return vcat(x, y)
end
_test_atom(target) do context
x = Variable(2)
return vcat(x, 2)
end
target = """
variables: x1, x2
minobjective: [1.0 * x1, 2.0, 1.0 * x2, 3.0]
"""
_test_atom(target) do context
x = Variable(1, 2)
y = constant([2 3])
return vcat(x, y)
end
target = """
variables: x1, x2, x3
minobjective: [2.0, 1.0 * x1, 2.0, 3.0, 1.0 * x2, 3.0, 4.0, 1.0 * x3, 4.0]
"""
_test_atom(target) do context
x = Variable(1, 3)
y = constant([2 3 4])
return vcat(y, x, y)
end
target = """
variables: x1, x2, x3, x4
minobjective: [x1, x2, 2.0, x3, x4, 3.0]
"""
_test_atom(target) do context
x = Variable(2, 2)
y = constant([2 3])
return vcat(x, y)
end
@test_throws(
DimensionMismatch(
"[VcatAtom] cannot stack expressions of incompatible size. Got 2 expected 1.",
),
vcat(Variable(2, 1), Variable(1, 2)),
)
return
end

### exp_+_sdp_cone/LogDetAtom

function test_LogDetAtom()
Expand Down
Loading