Skip to content

Commit

Permalink
Make tests pass and simplify a bit.
Browse files Browse the repository at this point in the history
  • Loading branch information
dourouc05 committed Jun 20, 2020
1 parent 8e8063d commit dacc801
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 38 deletions.
39 changes: 13 additions & 26 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,61 +163,48 @@ end
# - code for parsing (which must be called first), if needed; otherwise, :().
# - code for building the required constraints, if needed; otherwise, :().
# - the symbol of the variable that replaces the expression (<: VariableRef).
_rewrite_expr(_error, ::Val{:call}, op::Val, args...) = :(), :(), Expr(:call, op, args...)
_rewrite_expr(_error::Function, head::Val, args...) = :(), :(), Expr(head, args...)
_rewrite_expr(_error::Function, ::Val{:call}, ::Val{OP}, args...) where OP = :(), :(), Expr(:call, OP, args...)
_rewrite_expr(_error::Function, ::Val{HEAD}, args...) where HEAD = :(), :(), Expr(HEAD, args...)
_rewrite_expr(_error::Function, e::Symbol) = :(), :(), e
_rewrite_expr(_error::Function, e::Number) = :(), :(), e

_rewrite_expr(_error::Function, e::Expr) = _rewrite_expr(_error, Val(e.head), e.args...)
_rewrite_expr(_error::Function, head::Val{:call}, args...) =
_rewrite_expr(_error, Val(:call), Val(args[1]), args[2:end]...)

function _rewrite_expr(_error::Function, ::Val{:call}, op::Union{Val{:+}, Val{:-}, Val{:*}, Val{:/}}, args...)
parse_code = :()
build_code = :()
new_args = []
for arg in args
p, b, a = _rewrite_expr(a)
p, b, a = _rewrite_expr(_error, arg)
parse_code = :($parse_code; $p)
build_code = :($build_code; $b)
push!(new_args, a)
end
return parse_code, build_code, Expr(:call, op, new_args...)
return parse_code, build_code, Expr(:call, typeof(op).parameters[1], new_args...)
end

_functionize(v::VariableRef) = convert(AffExpr, v)
_functionize(v::AbstractArray{VariableRef}) = _functionize.(v)
_functionize(x) = x
_functionize(::MutableArithmetics.Zero) = 0.0
function parse_one_operator_constraint(_error::Function, vectorized::Bool, sense::Val, lhs, rhs)
parse_code_rhs, build_code_rhs, new_rhs = _rewrite_expr(_error, lhs)
parse_code_lhs, build_code_lhs, new_lhs = _rewrite_expr(_error, rhs)
parse_code_rhs, build_code_rhs, new_rhs = _rewrite_expr(_error, rhs)
parse_code_lhs, build_code_lhs, new_lhs = _rewrite_expr(_error, lhs)

# Simple comparison - move everything to the LHS.
#
# `_functionize` deals with the pathological case where the `lhs` is a `VariableRef`
# and the `rhs` is a summation with no terms. `_build_call` should be passed a
# `GenericAffExpr` or a `GenericQuadExpr`, and not a `VariableRef` as would be the case
# without `_functionize`.

# TODO: bug in MutableArithmetics? The returned code does not find $new_rhs (ERROR: UndefVarError: #642###976 not defined) when only using the first code path.
parse_code, variable = nothing, nothing
if rhs == new_rhs && lhs == new_lhs
if vectorized
func = :($lhs .- $rhs)
else
func = :($lhs - $rhs)
end
variable, parse_code = _MA.rewrite(func)
elseif rhs != new_rhs && lhs == new_lhs
variable, parse_code = _MA.rewrite(lhs)
parse_code = :($parse_code; $variable = _MA.operate!(-, $variable, $new_rhs))
elseif lhs != new_lhs && rhs == new_rhs
variable, parse_code = _MA.rewrite(rhs)
parse_code = :($parse_code; $variable = _MA.operate!(-, $new_lhs, $variable))
if vectorized
func = :($new_lhs .- $new_rhs)
else
@assert rhs != new_rhs
@assert lhs != new_lhs
variable = gensym()
parse_code = :($parse_code; $variable = _MA.operate!(-, $new_lhs, $new_rhs))
func = :($new_lhs - $new_rhs)
end
variable, parse_code = _MA.rewrite(func)

set = sense_to_set(_error, sense)
build_code = _build_call(_error, vectorized, :(_functionize($variable)), set)
Expand Down
19 changes: 7 additions & 12 deletions test/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,7 @@ function custom_function_test(ModelType::Type{<:JuMP.AbstractModel})
end
end

JuMP.expression_to_rewrite(head::Val{:donothing}, var) = true
JuMP.expression_to_rewrite(head::Val{:donothing}, var) = true
function JuMP.rewrite_call_expression(errorf::Function, head::Val{:donothing}, var)
return :(), :(), esc(var)
end

JuMP._rewrite_expr(_error::Function, ::Val{:call}, ::Val{:donothing}, arg) = :(), :(), arg
function build_constraint_test(ModelType::Type{<:JuMP.AbstractModel})
@testset "Extension of @constraint with rewrite_call_expression #2229" begin
@testset "Simple: only the function in rhs/lhs" begin
Expand Down Expand Up @@ -240,8 +235,8 @@ function build_constraint_test(ModelType::Type{<:JuMP.AbstractModel})
cref = @constraint(model, x == 1 + donothing(y))

c = JuMP.constraint_object(cref)
@test JuMP.isequal_canonical(c.func, x - y - 1)
@test c.set == MOI.EqualTo(0.0)
@test JuMP.isequal_canonical(c.func, x - y)
@test c.set == MOI.EqualTo(1.0)

# LHS
model = ModelType()
Expand All @@ -250,8 +245,8 @@ function build_constraint_test(ModelType::Type{<:JuMP.AbstractModel})
cref = @constraint(model, donothing(x) - 1 == y)

c = JuMP.constraint_object(cref)
@test JuMP.isequal_canonical(c.func, x - y - 1)
@test c.set == MOI.EqualTo(0.0)
@test JuMP.isequal_canonical(c.func, x - y)
@test c.set == MOI.EqualTo(1.0)

# Both sides.
model = ModelType()
Expand All @@ -260,8 +255,8 @@ function build_constraint_test(ModelType::Type{<:JuMP.AbstractModel})
cref = @constraint(model, donothing(x) - 0.5 == donothing(y) + 0.5)

c = JuMP.constraint_object(cref)
@test JuMP.isequal_canonical(c.func, x - y - 1)
@test c.set == MOI.EqualTo(0.0)
@test JuMP.isequal_canonical(c.func, x - y)
@test c.set == MOI.EqualTo(1.0)
end
end
end
Expand Down

0 comments on commit dacc801

Please sign in to comment.