diff --git a/plan/expression_rewriter.go b/plan/expression_rewriter.go index 2426da6162aee..485db23d82c6d 100644 --- a/plan/expression_rewriter.go +++ b/plan/expression_rewriter.go @@ -151,15 +151,11 @@ func popRowArg(ctx sessionctx.Context, e expression.Expression) (ret expression. } // 1. If op are EQ or NE or NullEQ, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2) -// 2. If op are LE or GE, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to -// `IF( (a0 op b0) EQ 0, 0, -// IF ( (a1 op b1) EQ 0, 0, a2 op b2))` -// 3. If op are LT or GT, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to +// 2. Else constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to // `IF( a0 NE b0, a0 op b0, -// IF( a1 NE b1, -// a1 op b1, -// a2 op b2) -// )` +// IF ( isNull(a0 NE b0), Null, +// IF ( a1 NE b1, a1 op b1, +// IF ( isNull(a1 NE b1), Null, a2 op b2))))` func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) { lLen, rLen := getRowLen(l), getRowLen(r) if lLen == 1 && rLen == 1 { @@ -180,15 +176,10 @@ func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, return expression.ComposeCNFCondition(er.ctx, funcs...), nil default: larg0, rarg0 := getRowArg(l, 0), getRowArg(r, 0) - var expr1, expr2, expr3 expression.Expression - if op == ast.LE || op == ast.GE { - expr1 = expression.NewFunctionInternal(er.ctx, op, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) - expr1 = expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), expr1, expression.Zero) - expr2 = expression.Zero - } else if op == ast.LT || op == ast.GT { - expr1 = expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) - expr2 = expression.NewFunctionInternal(er.ctx, op, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) - } + var expr1, expr2, expr3, expr4, expr5 expression.Expression + expr1 = expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) + expr2 = expression.NewFunctionInternal(er.ctx, op, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) + expr3 = expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr1) var err error l, err = popRowArg(er.ctx, l) if err != nil { @@ -198,11 +189,15 @@ func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, if err != nil { return nil, errors.Trace(err) } - expr3, err = er.constructBinaryOpFunction(l, r, op) + expr4, err = er.constructBinaryOpFunction(l, r, op) + if err != nil { + return nil, errors.Trace(err) + } + expr5, err = expression.NewFunction(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), expr3, expression.Null, expr4) if err != nil { return nil, errors.Trace(err) } - return expression.NewFunction(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), expr1, expr2, expr3) + return expression.NewFunction(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), expr1, expr2, expr5) } }