Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rewrite_call_expression through _rewrite_expr. #2241

Closed
wants to merge 14 commits into from
39 changes: 35 additions & 4 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,25 +159,56 @@ function parse_one_operator_constraint(_error::Function, args...)
_unknown_constraint_expr(_error)
end

# When `_rewrite_expr` is called with an expression, it returns three things:
# - code for parsing (which must be called first), if needed; otherwise, :().
# - code for building the required constraints, if needed; otherwise, :().
# - the symbol of the variable that replaces the expression (<: VariableRef).
_rewrite_expr(_error::Function, ::Val{:call}, ::Val{OP}, args...) where OP = :(), :(), Expr(:call, OP, args...)
_rewrite_expr(_error::Function, ::Val{HEAD}, args...) where HEAD = :(), :(), Expr(HEAD, args...)
_rewrite_expr(_error::Function, e::Symbol) = :(), :(), e
_rewrite_expr(_error::Function, e::Number) = :(), :(), e

_rewrite_expr(_error::Function, e::Expr) = _rewrite_expr(_error, Val(e.head), e.args...)
_rewrite_expr(_error::Function, head::Val{:call}, args...) =
_rewrite_expr(_error, Val(:call), Val(args[1]), args[2:end]...)

function _rewrite_expr(_error::Function, ::Val{:call}, op::Union{Val{:+}, Val{:-}, Val{:*}, Val{:/}}, args...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why have a method with OP above which does not rewrite recursively and this one which rewrite recursively ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous one is a fallback for unimplemented operators. I don't know if it's the most sensible default: I try to enable this code as rarely as possible, to avoid breaking things.

parse_code = :()
build_code = :()
new_args = []
for arg in args
p, b, a = _rewrite_expr(_error, arg)
parse_code = :($parse_code; $p)
build_code = :($build_code; $b)
push!(new_args, a)
end
return parse_code, build_code, Expr(:call, typeof(op).parameters[1], new_args...)
end

_functionize(v::VariableRef) = convert(AffExpr, v)
_functionize(v::AbstractArray{VariableRef}) = _functionize.(v)
_functionize(x) = x
_functionize(::MutableArithmetics.Zero) = 0.0
function parse_one_operator_constraint(_error::Function, vectorized::Bool, sense::Val, lhs, rhs)
parse_code_rhs, build_code_rhs, new_rhs = _rewrite_expr(_error, rhs)
parse_code_lhs, build_code_lhs, new_lhs = _rewrite_expr(_error, lhs)

# Simple comparison - move everything to the LHS.
#
# `_functionize` deals with the pathological case where the `lhs` is a `VariableRef`
# and the `rhs` is a summation with no terms. `_build_call` should be passed a
# `GenericAffExpr` or a `GenericQuadExpr`, and not a `VariableRef` as would be the case
# without `_functionize`.
if vectorized
func = :($lhs .- $rhs)
func = :($new_lhs .- $new_rhs)
else
func = :($lhs - $rhs)
func = :($new_lhs - $new_rhs)
end
set = sense_to_set(_error, sense)
variable, parse_code = _MA.rewrite(func)
return parse_code, _build_call(_error, vectorized, :(_functionize($variable)), set)

set = sense_to_set(_error, sense)
build_code = _build_call(_error, vectorized, :(_functionize($variable)), set)
return :($parse_code_lhs; $parse_code_rhs; $parse_code), :($build_code_lhs; $build_code_rhs; $build_code)
end

function parse_constraint_expr(_error::Function, expr::Expr)
Expand Down
102 changes: 102 additions & 0 deletions test/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,107 @@ function custom_function_test(ModelType::Type{<:JuMP.AbstractModel})
end
end

JuMP._rewrite_expr(_error::Function, ::Val{:call}, ::Val{:donothing}, arg) = :(), :(), arg
JuMP._rewrite_expr(_error::Function, ::Val{:call}, ::Val{:&}, arg1, arg2) = :(), :(), arg1
JuMP._rewrite_expr(_error::Function, ::Val{:call}, ::Val{:|}, arg1, arg2) = :(), :(), arg2
function build_constraint_test(ModelType::Type{<:JuMP.AbstractModel})
@testset "Extension of @constraint with rewrite_call_expression #2229" begin
@testset "Simple: only the function in rhs/lhs" begin
# RHS
model = ModelType()
@variable(model, x)
@variable(model, y)
cref = @constraint(model, x == donothing(y))

c = JuMP.constraint_object(cref)
@test JuMP.isequal_canonical(c.func, x - y)
@test c.set == MOI.EqualTo(0.0)

# LHS
model = ModelType()
@variable(model, x)
@variable(model, y)
cref = @constraint(model, donothing(x) == y)

c = JuMP.constraint_object(cref)
@test JuMP.isequal_canonical(c.func, x - y)
@test c.set == MOI.EqualTo(0.0)

# Both sides.
model = ModelType()
@variable(model, x)
@variable(model, y)
cref = @constraint(model, donothing(x) == donothing(y))

c = JuMP.constraint_object(cref)
@test JuMP.isequal_canonical(c.func, x - y)
@test c.set == MOI.EqualTo(0.0)
end

@testset "Complex: rewrite within the rhs/lhs expressions" begin
# RHS
model = ModelType()
@variable(model, x)
@variable(model, y)
cref = @constraint(model, x == 1 + donothing(y))

c = JuMP.constraint_object(cref)
@test JuMP.isequal_canonical(c.func, x - y)
@test c.set == MOI.EqualTo(1.0)

# LHS
model = ModelType()
@variable(model, x)
@variable(model, y)
cref = @constraint(model, donothing(x) - 1 == y)

c = JuMP.constraint_object(cref)
@test JuMP.isequal_canonical(c.func, x - y)
@test c.set == MOI.EqualTo(1.0)

# Both sides.
model = ModelType()
@variable(model, x)
@variable(model, y)
cref = @constraint(model, donothing(x) - 0.5 == donothing(y) + 0.5)

c = JuMP.constraint_object(cref)
@test JuMP.isequal_canonical(c.func, x - y)
@test c.set == MOI.EqualTo(1.0)
end

@testset "Binary logical operators" begin
# Simple case.
model = ModelType()
@variable(model, x)
@variable(model, y)
cref = @constraint(model, x & y == x | y)
# \ x / \ y /

c = JuMP.constraint_object(cref)
@test JuMP.isequal_canonical(c.func, x - y)
@test c.set == MOI.EqualTo(0.0)

# Complex case. Not for custom models because of _functionize does
# not return a GenericAffExpr for MyVariableRef arguments.
if ModelType == JuMP.Model
model = ModelType()
@variable(model, w)
@variable(model, x)
@variable(model, y)
@variable(model, z)
cref = @constraint(model, (w & x & y) | z == true)
# \___ w ___/ /
# \____ z ____/

c = JuMP.constraint_object(cref)
@test JuMP.isequal_canonical(c.func, JuMP._functionize(z))
@test c.set == MOI.EqualTo(1.0)
end
end
end
end

function macros_test(ModelType::Type{<:JuMP.AbstractModel}, VariableRefType::Type{<:JuMP.AbstractVariableRef})
@testset "build_constraint on variable" begin
m = ModelType()
Expand Down Expand Up @@ -375,6 +476,7 @@ function macros_test(ModelType::Type{<:JuMP.AbstractModel}, VariableRefType::Typ
build_constraint_keyword_test(ModelType)
custom_expression_test(ModelType)
custom_function_test(ModelType)
build_constraint_test(ModelType)
end

@testset "Macros for JuMP.Model" begin
Expand Down