From dacc801a767f21dfa94d373117b25cff5bc19144 Mon Sep 17 00:00:00 2001 From: Thibaut Cuvelier Date: Thu, 7 May 2020 03:27:09 +0200 Subject: [PATCH] Make tests pass and simplify a bit. --- src/macros.jl | 39 +++++++++++++-------------------------- test/macros.jl | 19 +++++++------------ 2 files changed, 20 insertions(+), 38 deletions(-) diff --git a/src/macros.jl b/src/macros.jl index 589dfa70329..573a71454f8 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -163,23 +163,26 @@ end # - code for parsing (which must be called first), if needed; otherwise, :(). # - code for building the required constraints, if needed; otherwise, :(). # - the symbol of the variable that replaces the expression (<: VariableRef). -_rewrite_expr(_error, ::Val{:call}, op::Val, args...) = :(), :(), Expr(:call, op, args...) -_rewrite_expr(_error::Function, head::Val, args...) = :(), :(), Expr(head, args...) +_rewrite_expr(_error::Function, ::Val{:call}, ::Val{OP}, args...) where OP = :(), :(), Expr(:call, OP, args...) +_rewrite_expr(_error::Function, ::Val{HEAD}, args...) where HEAD = :(), :(), Expr(HEAD, args...) +_rewrite_expr(_error::Function, e::Symbol) = :(), :(), e +_rewrite_expr(_error::Function, e::Number) = :(), :(), e _rewrite_expr(_error::Function, e::Expr) = _rewrite_expr(_error, Val(e.head), e.args...) _rewrite_expr(_error::Function, head::Val{:call}, args...) = _rewrite_expr(_error, Val(:call), Val(args[1]), args[2:end]...) + function _rewrite_expr(_error::Function, ::Val{:call}, op::Union{Val{:+}, Val{:-}, Val{:*}, Val{:/}}, args...) parse_code = :() build_code = :() new_args = [] for arg in args - p, b, a = _rewrite_expr(a) + p, b, a = _rewrite_expr(_error, arg) parse_code = :($parse_code; $p) build_code = :($build_code; $b) push!(new_args, a) end - return parse_code, build_code, Expr(:call, op, new_args...) + return parse_code, build_code, Expr(:call, typeof(op).parameters[1], new_args...) end _functionize(v::VariableRef) = convert(AffExpr, v) @@ -187,8 +190,8 @@ _functionize(v::AbstractArray{VariableRef}) = _functionize.(v) _functionize(x) = x _functionize(::MutableArithmetics.Zero) = 0.0 function parse_one_operator_constraint(_error::Function, vectorized::Bool, sense::Val, lhs, rhs) - parse_code_rhs, build_code_rhs, new_rhs = _rewrite_expr(_error, lhs) - parse_code_lhs, build_code_lhs, new_lhs = _rewrite_expr(_error, rhs) + parse_code_rhs, build_code_rhs, new_rhs = _rewrite_expr(_error, rhs) + parse_code_lhs, build_code_lhs, new_lhs = _rewrite_expr(_error, lhs) # Simple comparison - move everything to the LHS. # @@ -196,28 +199,12 @@ function parse_one_operator_constraint(_error::Function, vectorized::Bool, sense # and the `rhs` is a summation with no terms. `_build_call` should be passed a # `GenericAffExpr` or a `GenericQuadExpr`, and not a `VariableRef` as would be the case # without `_functionize`. - - # TODO: bug in MutableArithmetics? The returned code does not find $new_rhs (ERROR: UndefVarError: #642###976 not defined) when only using the first code path. - parse_code, variable = nothing, nothing - if rhs == new_rhs && lhs == new_lhs - if vectorized - func = :($lhs .- $rhs) - else - func = :($lhs - $rhs) - end - variable, parse_code = _MA.rewrite(func) - elseif rhs != new_rhs && lhs == new_lhs - variable, parse_code = _MA.rewrite(lhs) - parse_code = :($parse_code; $variable = _MA.operate!(-, $variable, $new_rhs)) - elseif lhs != new_lhs && rhs == new_rhs - variable, parse_code = _MA.rewrite(rhs) - parse_code = :($parse_code; $variable = _MA.operate!(-, $new_lhs, $variable)) + if vectorized + func = :($new_lhs .- $new_rhs) else - @assert rhs != new_rhs - @assert lhs != new_lhs - variable = gensym() - parse_code = :($parse_code; $variable = _MA.operate!(-, $new_lhs, $new_rhs)) + func = :($new_lhs - $new_rhs) end + variable, parse_code = _MA.rewrite(func) set = sense_to_set(_error, sense) build_code = _build_call(_error, vectorized, :(_functionize($variable)), set) diff --git a/test/macros.jl b/test/macros.jl index bd979476107..84c26255349 100644 --- a/test/macros.jl +++ b/test/macros.jl @@ -192,12 +192,7 @@ function custom_function_test(ModelType::Type{<:JuMP.AbstractModel}) end end -JuMP.expression_to_rewrite(head::Val{:donothing}, var) = true -JuMP.expression_to_rewrite(head::Val{:donothing}, var) = true -function JuMP.rewrite_call_expression(errorf::Function, head::Val{:donothing}, var) - return :(), :(), esc(var) -end - +JuMP._rewrite_expr(_error::Function, ::Val{:call}, ::Val{:donothing}, arg) = :(), :(), arg function build_constraint_test(ModelType::Type{<:JuMP.AbstractModel}) @testset "Extension of @constraint with rewrite_call_expression #2229" begin @testset "Simple: only the function in rhs/lhs" begin @@ -240,8 +235,8 @@ function build_constraint_test(ModelType::Type{<:JuMP.AbstractModel}) cref = @constraint(model, x == 1 + donothing(y)) c = JuMP.constraint_object(cref) - @test JuMP.isequal_canonical(c.func, x - y - 1) - @test c.set == MOI.EqualTo(0.0) + @test JuMP.isequal_canonical(c.func, x - y) + @test c.set == MOI.EqualTo(1.0) # LHS model = ModelType() @@ -250,8 +245,8 @@ function build_constraint_test(ModelType::Type{<:JuMP.AbstractModel}) cref = @constraint(model, donothing(x) - 1 == y) c = JuMP.constraint_object(cref) - @test JuMP.isequal_canonical(c.func, x - y - 1) - @test c.set == MOI.EqualTo(0.0) + @test JuMP.isequal_canonical(c.func, x - y) + @test c.set == MOI.EqualTo(1.0) # Both sides. model = ModelType() @@ -260,8 +255,8 @@ function build_constraint_test(ModelType::Type{<:JuMP.AbstractModel}) cref = @constraint(model, donothing(x) - 0.5 == donothing(y) + 0.5) c = JuMP.constraint_object(cref) - @test JuMP.isequal_canonical(c.func, x - y - 1) - @test c.set == MOI.EqualTo(0.0) + @test JuMP.isequal_canonical(c.func, x - y) + @test c.set == MOI.EqualTo(1.0) end end end