diff --git a/executor/join_test.go b/executor/join_test.go index e79758db5ef96..8b4c37e654f6b 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -876,11 +876,11 @@ func (s *testSuite) TestMergejoinOrder(c *C) { tk.MustExec("insert into t2 select a*100, b*100 from t1;") tk.MustQuery("explain select /*+ TIDB_SMJ(t2) */ * from t1 left outer join t2 on t1.a=t2.a and t1.a!=3 order by t1.a;").Check(testkit.Rows( - "MergeJoin_15 12500.00 root left outer join, left key:test.t1.a, right key:test.t2.a, left cond:[ne(test.t1.a, 3)]", + "MergeJoin_15 10000.00 root left outer join, left key:test.t1.a, right key:test.t2.a, left cond:[ne(test.t1.a, 3)]", "├─TableReader_11 10000.00 root data:TableScan_10", "│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:true, stats:pseudo", - "└─TableReader_13 10000.00 root data:TableScan_12", - " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:true, stats:pseudo", + "└─TableReader_13 6666.67 root data:TableScan_12", + " └─TableScan_12 6666.67 cop table:t2, range:[-inf,3), (3,+inf], keep order:true, stats:pseudo", )) tk.MustExec("set @@tidb_max_chunk_size=1") diff --git a/expression/constant_propagation.go b/expression/constant_propagation.go index 2919c5dc359c8..87b217100627a 100644 --- a/expression/constant_propagation.go +++ b/expression/constant_propagation.go @@ -50,13 +50,115 @@ func (m *multiEqualSet) findRoot(a int) int { return m.parent[a] } -type propagateConstantSolver struct { - colMapper map[string]int // colMapper maps column to its index - unionSet *multiEqualSet // unionSet stores the relations like col_i = col_j - eqList []*Constant // if eqList[i] != nil, it means col_i = eqList[i] - columns []*Column // columns stores all columns appearing in the conditions +type basePropConstSolver struct { + colMapper map[int64]int // colMapper maps column to its index + eqList []*Constant // if eqList[i] != nil, it means col_i = eqList[i] + unionSet *multiEqualSet // unionSet stores the relations like col_i = col_j + columns []*Column // columns stores all columns appearing in the conditions + ctx sessionctx.Context +} + +func (s *basePropConstSolver) getColID(col *Column) int { + return s.colMapper[col.UniqueID] +} + +func (s *basePropConstSolver) insertCol(col *Column) { + _, ok := s.colMapper[col.UniqueID] + if !ok { + s.colMapper[col.UniqueID] = len(s.colMapper) + s.columns = append(s.columns, col) + } +} + +// tryToUpdateEQList tries to update the eqList. When the eqList has store this column with a different constant, like +// a = 1 and a = 2, we set the second return value to false. +func (s *basePropConstSolver) tryToUpdateEQList(col *Column, con *Constant) (bool, bool) { + if con.Value.IsNull() { + return false, true + } + id := s.getColID(col) + oldCon := s.eqList[id] + if oldCon != nil { + return false, !oldCon.Equal(s.ctx, con) + } + s.eqList[id] = con + return true, false +} + +// validEqualCond checks if the cond is an expression like [column eq constant]. +func validEqualCond(cond Expression) (*Column, *Constant) { + if eq, ok := cond.(*ScalarFunction); ok { + if eq.FuncName.L != ast.EQ { + return nil, nil + } + if col, colOk := eq.GetArgs()[0].(*Column); colOk { + if con, conOk := eq.GetArgs()[1].(*Constant); conOk { + return col, con + } + } + if col, colOk := eq.GetArgs()[1].(*Column); colOk { + if con, conOk := eq.GetArgs()[0].(*Constant); conOk { + return col, con + } + } + } + return nil, nil +} + +// tryToReplaceCond aims to replace all occurrences of column 'src' and try to replace it with 'tgt' in 'cond' +// It returns +// bool: if a replacement happened +// bool: if 'cond' contains non-deterministic expression +// Expression: the replaced expression, or original 'cond' if the replacement didn't happen +// +// For example: +// for 'a, b, a < 3', it returns 'true, false, b < 3' +// for 'a, b, sin(a) + cos(a) = 5', it returns 'true, false, returns sin(b) + cos(b) = 5' +// for 'a, b, cast(a) < rand()', it returns 'false, true, cast(a) < rand()' +func tryToReplaceCond(ctx sessionctx.Context, src *Column, tgt *Column, cond Expression) (bool, bool, Expression) { + sf, ok := cond.(*ScalarFunction) + if !ok { + return false, false, cond + } + replaced := false + var args []Expression + if _, ok := unFoldableFunctions[sf.FuncName.L]; ok { + return false, true, cond + } + if _, ok := inequalFunctions[sf.FuncName.L]; ok { + return false, true, cond + } + for idx, expr := range sf.GetArgs() { + if src.Equal(nil, expr) { + replaced = true + if args == nil { + args = make([]Expression, len(sf.GetArgs())) + copy(args, sf.GetArgs()) + } + args[idx] = tgt + } else { + subReplaced, isNonDeterminisitic, subExpr := tryToReplaceCond(ctx, src, tgt, expr) + if isNonDeterminisitic { + return false, true, cond + } else if subReplaced { + replaced = true + if args == nil { + args = make([]Expression, len(sf.GetArgs())) + copy(args, sf.GetArgs()) + } + args[idx] = subExpr + } + } + } + if replaced { + return true, false, NewFunctionInternal(ctx, sf.FuncName.L, sf.GetType(), args...) + } + return false, false, cond +} + +type propConstSolver struct { + basePropConstSolver conditions []Expression - ctx sessionctx.Context } // propagateConstantEQ propagates expressions like 'column = constant' by substituting the constant for column, the @@ -64,7 +166,7 @@ type propagateConstantSolver struct { // a = d & b * 2 = c & c = d + 2 & b = 1 & a = 4, we pick eq cond b = 1 and a = 4 // d = 4 & 2 = c & c = d + 2 & b = 1 & a = 4, we propagate b = 1 and a = 4 and pick eq cond c = 2 and d = 4 // d = 4 & 2 = c & false & b = 1 & a = 4, we propagate c = 2 and d = 4, and do constant folding: c = d + 2 will be folded as false. -func (s *propagateConstantSolver) propagateConstantEQ() { +func (s *propConstSolver) propagateConstantEQ() { s.eqList = make([]*Constant, len(s.columns)) visited := make([]bool, len(s.conditions)) for i := 0; i < MaxPropagateColsCnt; i++ { @@ -104,7 +206,7 @@ func (s *propagateConstantSolver) propagateConstantEQ() { // TODO: remove redundancies later // // We maintain a unionSet representing the equivalent for every two columns. -func (s *propagateConstantSolver) propagateColumnEQ() { +func (s *propConstSolver) propagateColumnEQ() { visited := make([]bool, len(s.conditions)) s.unionSet = &multiEqualSet{} s.unionSet.init(len(s.columns)) @@ -135,11 +237,11 @@ func (s *propagateConstantSolver) propagateColumnEQ() { continue } cond := s.conditions[k] - replaced, _, newExpr := s.tryToReplaceCond(coli, colj, cond) + replaced, _, newExpr := tryToReplaceCond(s.ctx, coli, colj, cond) if replaced { s.conditions = append(s.conditions, newExpr) } - replaced, _, newExpr = s.tryToReplaceCond(colj, coli, cond) + replaced, _, newExpr = tryToReplaceCond(s.ctx, colj, coli, cond) if replaced { s.conditions = append(s.conditions, newExpr) } @@ -148,78 +250,7 @@ func (s *propagateConstantSolver) propagateColumnEQ() { } } -// validEqualCond checks if the cond is an expression like [column eq constant]. -func (s *propagateConstantSolver) validEqualCond(cond Expression) (*Column, *Constant) { - if eq, ok := cond.(*ScalarFunction); ok { - if eq.FuncName.L != ast.EQ { - return nil, nil - } - if col, colOk := eq.GetArgs()[0].(*Column); colOk { - if con, conOk := eq.GetArgs()[1].(*Constant); conOk { - return col, con - } - } - if col, colOk := eq.GetArgs()[1].(*Column); colOk { - if con, conOk := eq.GetArgs()[0].(*Constant); conOk { - return col, con - } - } - } - return nil, nil -} - -// tryToReplaceCond aims to replace all occurrences of column 'src' and try to replace it with 'tgt' in 'cond' -// It returns -// bool: if a replacement happened -// bool: if 'cond' contains non-deterministic expression -// Expression: the replaced expression, or original 'cond' if the replacement didn't happen -// -// For example: -// for 'a, b, a < 3', it returns 'true, false, b < 3' -// for 'a, b, sin(a) + cos(a) = 5', it returns 'true, false, returns sin(b) + cos(b) = 5' -// for 'a, b, cast(a) < rand()', it returns 'false, true, cast(a) < rand()' -func (s *propagateConstantSolver) tryToReplaceCond(src *Column, tgt *Column, cond Expression) (bool, bool, Expression) { - sf, ok := cond.(*ScalarFunction) - if !ok { - return false, false, cond - } - replaced := false - var args []Expression - if _, ok := unFoldableFunctions[sf.FuncName.L]; ok { - return false, true, cond - } - if _, ok := inequalFunctions[sf.FuncName.L]; ok { - return false, true, cond - } - for idx, expr := range sf.GetArgs() { - if src.Equal(nil, expr) { - replaced = true - if args == nil { - args = make([]Expression, len(sf.GetArgs())) - copy(args, sf.GetArgs()) - } - args[idx] = tgt - } else { - subReplaced, isNonDeterminisitic, subExpr := s.tryToReplaceCond(src, tgt, expr) - if isNonDeterminisitic { - return false, true, cond - } else if subReplaced { - replaced = true - if args == nil { - args = make([]Expression, len(sf.GetArgs())) - copy(args, sf.GetArgs()) - } - args[idx] = subExpr - } - } - } - if replaced { - return true, false, NewFunctionInternal(s.ctx, sf.FuncName.L, sf.GetType(), args...) - } - return false, false, cond -} - -func (s *propagateConstantSolver) setConds2ConstFalse() { +func (s *propConstSolver) setConds2ConstFalse() { s.conditions = []Expression{&Constant{ Value: types.NewDatum(false), RetType: types.NewFieldType(mysql.TypeTiny), @@ -227,19 +258,22 @@ func (s *propagateConstantSolver) setConds2ConstFalse() { } // pickNewEQConds tries to pick new equal conds and puts them to retMapper. -func (s *propagateConstantSolver) pickNewEQConds(visited []bool) (retMapper map[int]*Constant) { +func (s *propConstSolver) pickNewEQConds(visited []bool) (retMapper map[int]*Constant) { retMapper = make(map[int]*Constant) for i, cond := range s.conditions { if visited[i] { continue } - col, con := s.validEqualCond(cond) + col, con := validEqualCond(cond) // Then we check if this CNF item is a false constant. If so, we will set the whole condition to false. var ok bool if col == nil { if con, ok = cond.(*Constant); ok { value, err := EvalBool(s.ctx, []Expression{con}, chunk.Row{}) - terror.Log(errors.Trace(err)) + if err != nil { + terror.Log(errors.Trace(err)) + return nil + } if !value { s.setConds2ConstFalse() return nil @@ -260,22 +294,7 @@ func (s *propagateConstantSolver) pickNewEQConds(visited []bool) (retMapper map[ return } -// tryToUpdateEQList tries to update the eqList. When the eqList has store this column with a different constant, like -// a = 1 and a = 2, we set the second return value to false. -func (s *propagateConstantSolver) tryToUpdateEQList(col *Column, con *Constant) (bool, bool) { - if con.Value.IsNull() { - return false, true - } - id := s.getColID(col) - oldCon := s.eqList[id] - if oldCon != nil { - return false, !oldCon.Equal(s.ctx, con) - } - s.eqList[id] = con - return true, false -} - -func (s *propagateConstantSolver) solve(conditions []Expression) []Expression { +func (s *propConstSolver) solve(conditions []Expression) []Expression { cols := make([]*Column, 0, len(conditions)) for _, cond := range conditions { s.conditions = append(s.conditions, SplitCNFItems(cond)...) @@ -290,37 +309,269 @@ func (s *propagateConstantSolver) solve(conditions []Expression) []Expression { } s.propagateConstantEQ() s.propagateColumnEQ() - for i, cond := range s.conditions { - if dnf, ok := cond.(*ScalarFunction); ok && dnf.FuncName.L == ast.LogicOr { - dnfItems := SplitDNFItems(cond) - for j, item := range dnfItems { - dnfItems[j] = ComposeCNFCondition(s.ctx, PropagateConstant(s.ctx, []Expression{item})...) + s.conditions = propagateConstantDNF(s.ctx, s.conditions) + return s.conditions +} + +// PropagateConstant propagate constant values of deterministic predicates in a condition. +func PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression { + solver := &propConstSolver{} + solver.colMapper = make(map[int64]int) + solver.ctx = ctx + return solver.solve(conditions) +} + +type propOuterJoinConstSolver struct { + basePropConstSolver + joinConds []Expression + filterConds []Expression + outerSchema *Schema + innerSchema *Schema +} + +func (s *propOuterJoinConstSolver) setConds2ConstFalse(filterConds bool) { + s.joinConds = []Expression{&Constant{ + Value: types.NewDatum(false), + RetType: types.NewFieldType(mysql.TypeTiny), + }} + if filterConds { + s.filterConds = []Expression{&Constant{ + Value: types.NewDatum(false), + RetType: types.NewFieldType(mysql.TypeTiny), + }} + } +} + +// pickEQCondsOnOuterCol picks constant equal expression from specified conditions. +func (s *propOuterJoinConstSolver) pickEQCondsOnOuterCol(retMapper map[int]*Constant, visited []bool, filterConds bool) map[int]*Constant { + var conds []Expression + var condsOffset int + if filterConds { + conds = s.filterConds + } else { + conds = s.joinConds + condsOffset = len(s.filterConds) + } + for i, cond := range conds { + if visited[i+condsOffset] { + continue + } + col, con := validEqualCond(cond) + // Then we check if this CNF item is a false constant. If so, we will set the whole condition to false. + var ok bool + if col == nil { + if con, ok = cond.(*Constant); ok { + value, err := EvalBool(s.ctx, []Expression{con}, chunk.Row{}) + if err != nil { + terror.Log(errors.Trace(err)) + return nil + } + if !value { + s.setConds2ConstFalse(filterConds) + return nil + } } - s.conditions[i] = ComposeDNFCondition(s.ctx, dnfItems...) + continue + } + // Only extract `outerCol = const` expressions. + if !s.outerSchema.Contains(col) { + continue + } + visited[i+condsOffset] = true + updated, foreverFalse := s.tryToUpdateEQList(col, con) + if foreverFalse { + s.setConds2ConstFalse(filterConds) + return nil + } + if updated { + retMapper[s.getColID(col)] = con } } - return s.conditions + return retMapper } -func (s *propagateConstantSolver) getColID(col *Column) int { - code := col.HashCode(nil) - return s.colMapper[string(code)] +// pickNewEQConds picks constant equal expressions from join and filter conditions. +func (s *propOuterJoinConstSolver) pickNewEQConds(visited []bool) map[int]*Constant { + retMapper := make(map[int]*Constant) + retMapper = s.pickEQCondsOnOuterCol(retMapper, visited, true) + if retMapper == nil { + // Filter is constant false or error occured, enforce early termination. + return nil + } + retMapper = s.pickEQCondsOnOuterCol(retMapper, visited, false) + return retMapper } -func (s *propagateConstantSolver) insertCol(col *Column) { - code := col.HashCode(nil) - _, ok := s.colMapper[string(code)] - if !ok { - s.colMapper[string(code)] = len(s.colMapper) - s.columns = append(s.columns, col) +// propagateConstantEQ propagates expressions like `outerCol = const` by substituting `outerCol` in *JOIN* condition +// with `const`, the procedure repeats multiple times. +func (s *propOuterJoinConstSolver) propagateConstantEQ() { + s.eqList = make([]*Constant, len(s.columns)) + lenFilters := len(s.filterConds) + visited := make([]bool, lenFilters+len(s.joinConds)) + for i := 0; i < MaxPropagateColsCnt; i++ { + mapper := s.pickNewEQConds(visited) + if len(mapper) == 0 { + return + } + cols := make([]*Column, 0, len(mapper)) + cons := make([]Expression, 0, len(mapper)) + for id, con := range mapper { + cols = append(cols, s.columns[id]) + cons = append(cons, con) + } + for i, cond := range s.joinConds { + if !visited[i+lenFilters] { + s.joinConds[i] = ColumnSubstitute(cond, NewSchema(cols...), cons) + } + } } } -// PropagateConstant propagate constant values of deterministic predicates in a condition. -func PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression { - solver := &propagateConstantSolver{ - colMapper: make(map[string]int), - ctx: ctx, +func (s *propOuterJoinConstSolver) colsFromOuterAndInner(col1, col2 *Column) (*Column, *Column) { + if s.outerSchema.Contains(col1) && s.innerSchema.Contains(col2) { + return col1, col2 } - return solver.solve(conditions) + if s.outerSchema.Contains(col2) && s.innerSchema.Contains(col1) { + return col2, col1 + } + return nil, nil +} + +// validColEqualCond checks if expression is column equal condition that we can use for constant +// propagation over outer join. We only use expression like `outerCol = innerCol`, for expressions like +// `outerCol1 = outerCol2` or `innerCol1 = innerCol2`, they do not help deriving new inner table conditions +// which can be pushed down to children plan nodes, so we do not pick them. +func (s *propOuterJoinConstSolver) validColEqualCond(cond Expression) (*Column, *Column) { + if fun, ok := cond.(*ScalarFunction); ok && fun.FuncName.L == ast.EQ { + lCol, lOk := fun.GetArgs()[0].(*Column) + rCol, rOk := fun.GetArgs()[1].(*Column) + if lOk && rOk { + return s.colsFromOuterAndInner(lCol, rCol) + } + } + return nil, nil + +} + +// deriveConds given `outerCol = innerCol`, derive new expression for specified conditions. +func (s *propOuterJoinConstSolver) deriveConds(outerCol, innerCol *Column, schema *Schema, fCondsOffset int, visited []bool, filterConds bool) []bool { + var offset, condsLen int + var conds []Expression + if filterConds { + conds = s.filterConds + offset = fCondsOffset + condsLen = len(s.filterConds) + } else { + conds = s.joinConds + condsLen = fCondsOffset + } + for k := 0; k < condsLen; k++ { + if visited[k+offset] { + // condition has been used to retrieve equality relation or contains column beyond children schema. + continue + } + cond := conds[k] + if !ExprFromSchema(cond, schema) { + visited[k+offset] = true + continue + } + replaced, _, newExpr := tryToReplaceCond(s.ctx, outerCol, innerCol, cond) + if replaced { + s.joinConds = append(s.joinConds, newExpr) + } + } + return visited +} + +// propagateColumnEQ propagates expressions like 'outerCol = innerCol' by adding extra filters +// 'expression(..., innerCol, ...)' derived from 'expression(..., outerCol, ...)' as long as +// 'expression(..., outerCol, ...)' does not reference columns outside children schemas of join node. +// Derived new expressions must be appended into join condition, not filter condition. +func (s *propOuterJoinConstSolver) propagateColumnEQ() { + visited := make([]bool, len(s.joinConds)+len(s.filterConds)) + s.unionSet = &multiEqualSet{} + s.unionSet.init(len(s.columns)) + var outerCol, innerCol *Column + // Only consider column equal condition in joinConds. + // If we have column equal in filter condition, the outer join should have been simplified already. + for i := range s.joinConds { + outerCol, innerCol = s.validColEqualCond(s.joinConds[i]) + if outerCol != nil { + outerID := s.getColID(outerCol) + innerID := s.getColID(innerCol) + s.unionSet.addRelation(outerID, innerID) + visited[i] = true + } + } + lenJoinConds := len(s.joinConds) + mergedSchema := MergeSchema(s.outerSchema, s.innerSchema) + for i, coli := range s.columns { + for j := i + 1; j < len(s.columns); j++ { + // unionSet doesn't have iterate(), we use a two layer loop to iterate col_i = col_j relation. + if s.unionSet.findRoot(i) != s.unionSet.findRoot(j) { + continue + } + colj := s.columns[j] + outerCol, innerCol = s.colsFromOuterAndInner(coli, colj) + if outerCol == nil { + continue + } + visited = s.deriveConds(outerCol, innerCol, mergedSchema, lenJoinConds, visited, false) + visited = s.deriveConds(outerCol, innerCol, mergedSchema, lenJoinConds, visited, true) + } + } +} + +func (s *propOuterJoinConstSolver) solve(joinConds, filterConds []Expression) ([]Expression, []Expression) { + cols := make([]*Column, 0, len(joinConds)+len(filterConds)) + for _, cond := range joinConds { + s.joinConds = append(s.joinConds, SplitCNFItems(cond)...) + cols = append(cols, ExtractColumns(cond)...) + } + for _, cond := range filterConds { + s.filterConds = append(s.filterConds, SplitCNFItems(cond)...) + cols = append(cols, ExtractColumns(cond)...) + } + for _, col := range cols { + s.insertCol(col) + } + if len(s.columns) > MaxPropagateColsCnt { + log.Warnf("[const_propagation_over_outerjoin] Too many columns: column count is %d, max count is %d.", len(s.columns), MaxPropagateColsCnt) + return joinConds, filterConds + } + s.propagateConstantEQ() + s.propagateColumnEQ() + s.joinConds = propagateConstantDNF(s.ctx, s.joinConds) + s.filterConds = propagateConstantDNF(s.ctx, s.filterConds) + return s.joinConds, s.filterConds +} + +// propagateConstantDNF find DNF item from CNF, and propagate constant inside DNF. +func propagateConstantDNF(ctx sessionctx.Context, conds []Expression) []Expression { + for i, cond := range conds { + if dnf, ok := cond.(*ScalarFunction); ok && dnf.FuncName.L == ast.LogicOr { + dnfItems := SplitDNFItems(cond) + for j, item := range dnfItems { + dnfItems[j] = ComposeCNFCondition(ctx, PropagateConstant(ctx, []Expression{item})...) + } + conds[i] = ComposeDNFCondition(ctx, dnfItems...) + } + } + return conds +} + +// PropConstOverOuterJoin propagate constant equal and column equal conditions over outer join. +// First step is to extract `outerCol = const` from join conditions and filter conditions, +// and substitute `outerCol` in join conditions with `const`; +// Second step is to extract `outerCol = innerCol` from join conditions, and derive new join +// conditions based on this column equal condition and `outerCol` related +// expressions in join conditions and filter conditions; +func PropConstOverOuterJoin(ctx sessionctx.Context, joinConds, filterConds []Expression, outerSchema, innerSchema *Schema) ([]Expression, []Expression) { + solver := &propOuterJoinConstSolver{ + outerSchema: outerSchema, + innerSchema: innerSchema, + } + solver.colMapper = make(map[int64]int) + solver.ctx = ctx + return solver.solve(joinConds, filterConds) } diff --git a/expression/constant_propagation_test.go b/expression/constant_propagation_test.go new file mode 100644 index 0000000000000..a9d2b49ff9e19 --- /dev/null +++ b/expression/constant_propagation_test.go @@ -0,0 +1,258 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression_test + +import ( + "fmt" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/mock" + "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testleak" +) + +var _ = Suite(&testSuite{}) + +type testSuite struct { + store kv.Storage + dom *domain.Domain + ctx sessionctx.Context +} + +func (s *testSuite) cleanEnv(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + r := tk.MustQuery("show tables") + for _, tb := range r.Rows() { + tableName := tb[0] + tk.MustExec(fmt.Sprintf("drop table %v", tableName)) + } +} + +func (s *testSuite) SetUpSuite(c *C) { + var err error + testleak.BeforeTest() + s.store, s.dom, err = newStoreWithBootstrap() + c.Assert(err, IsNil) + s.ctx = mock.NewContext() +} + +func (s *testSuite) TearDownSuite(c *C) { + s.dom.Close() + s.store.Close() + testleak.AfterTest(c)() +} + +func (s *testSuite) TestOuterJoinPropConst(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2;") + tk.MustExec("create table t1(id bigint primary key, a int, b int);") + tk.MustExec("create table t2(id bigint primary key, a int, b int);") + + // Positive tests. + tk.MustQuery("explain select * from t1 left join t2 on t1.a > t2.a and t1.a = 1;").Check(testkit.Rows( + "HashLeftJoin_6 33233333.33 root left outer join, inner:TableReader_11, left cond:[eq(test.t1.a, 1)]", + "├─TableReader_8 10000.00 root data:TableScan_7", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_11 3323.33 root data:Selection_10", + " └─Selection_10 3323.33 cop gt(1, test.t2.a)", + " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 left join t2 on t1.a > t2.a where t1.a = 1;").Check(testkit.Rows( + "HashLeftJoin_7 33233.33 root left outer join, inner:TableReader_13", + "├─TableReader_10 10.00 root data:Selection_9", + "│ └─Selection_9 10.00 cop eq(test.t1.a, 1)", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_13 3323.33 root data:Selection_12", + " └─Selection_12 3323.33 cop gt(1, test.t2.a)", + " └─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 left join t2 on t1.a = t2.a and t1.a > 1;").Check(testkit.Rows( + "HashLeftJoin_6 10000.00 root left outer join, inner:TableReader_11, equal:[eq(test.t1.a, test.t2.a)], left cond:[gt(test.t1.a, 1)]", + "├─TableReader_8 10000.00 root data:TableScan_7", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_11 3333.33 root data:Selection_10", + " └─Selection_10 3333.33 cop gt(test.t2.a, 1)", + " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 left join t2 on t1.a = t2.a where t1.a > 1;").Check(testkit.Rows( + "HashLeftJoin_7 4166.67 root left outer join, inner:TableReader_13, equal:[eq(test.t1.a, test.t2.a)]", + "├─TableReader_10 3333.33 root data:Selection_9", + "│ └─Selection_9 3333.33 cop gt(test.t1.a, 1)", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_13 3333.33 root data:Selection_12", + " └─Selection_12 3333.33 cop gt(test.t2.a, 1)", + " └─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 right join t2 on t1.a > t2.a where t2.a = 1;").Check(testkit.Rows( + "HashRightJoin_7 33333.33 root right outer join, inner:TableReader_10", + "├─TableReader_10 3333.33 root data:Selection_9", + "│ └─Selection_9 3333.33 cop gt(test.t1.a, 1)", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_13 10.00 root data:Selection_12", + " └─Selection_12 10.00 cop eq(test.t2.a, 1)", + " └─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 right join t2 on t1.a = t2.a where t2.a > 1;").Check(testkit.Rows( + "HashRightJoin_7 4166.67 root right outer join, inner:TableReader_10, equal:[eq(test.t1.a, test.t2.a)]", + "├─TableReader_10 3333.33 root data:Selection_9", + "│ └─Selection_9 3333.33 cop gt(test.t1.a, 1)", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_13 3333.33 root data:Selection_12", + " └─Selection_12 3333.33 cop gt(test.t2.a, 1)", + " └─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 right join t2 on t1.a = t2.a and t2.a > 1;").Check(testkit.Rows( + "HashRightJoin_6 10000.00 root right outer join, inner:TableReader_9, equal:[eq(test.t1.a, test.t2.a)], right cond:gt(test.t2.a, 1)", + "├─TableReader_9 3333.33 root data:Selection_8", + "│ └─Selection_8 3333.33 cop gt(test.t1.a, 1)", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_11 10000.00 root data:TableScan_10", + " └─TableScan_10 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 right join t2 on t1.a > t2.a and t2.a = 1;").Check(testkit.Rows( + "HashRightJoin_6 33333333.33 root right outer join, inner:TableReader_9, right cond:eq(test.t2.a, 1)", + "├─TableReader_9 3333.33 root data:Selection_8", + "│ └─Selection_8 3333.33 cop gt(test.t1.a, 1)", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_11 10000.00 root data:TableScan_10", + " └─TableScan_10 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + // Negative tests. + tk.MustQuery("explain select * from t1 left join t2 on t1.a = t2.a and t2.a > 1;").Check(testkit.Rows( + "HashLeftJoin_6 10000.00 root left outer join, inner:TableReader_11, equal:[eq(test.t1.a, test.t2.a)]", + "├─TableReader_8 10000.00 root data:TableScan_7", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_11 3333.33 root data:Selection_10", + " └─Selection_10 3333.33 cop gt(test.t2.a, 1)", + " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 left join t2 on t1.a > t2.a and t2.a = 1;").Check(testkit.Rows( + "HashLeftJoin_6 100000.00 root left outer join, inner:TableReader_11, other cond:gt(test.t1.a, test.t2.a)", + "├─TableReader_8 10000.00 root data:TableScan_7", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_11 10.00 root data:Selection_10", + " └─Selection_10 10.00 cop eq(test.t2.a, 1)", + " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 right join t2 on t1.a > t2.a and t1.a = 1;").Check(testkit.Rows( + "HashRightJoin_6 100000.00 root right outer join, inner:TableReader_9, other cond:gt(test.t1.a, test.t2.a)", + "├─TableReader_9 10.00 root data:Selection_8", + "│ └─Selection_8 10.00 cop eq(test.t1.a, 1)", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_11 10000.00 root data:TableScan_10", + " └─TableScan_10 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 right join t2 on t1.a = t2.a and t1.a > 1;").Check(testkit.Rows( + "HashRightJoin_6 10000.00 root right outer join, inner:TableReader_9, equal:[eq(test.t1.a, test.t2.a)]", + "├─TableReader_9 3333.33 root data:Selection_8", + "│ └─Selection_8 3333.33 cop gt(test.t1.a, 1)", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_11 10000.00 root data:TableScan_10", + " └─TableScan_10 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 left join t2 on t1.a = t1.b and t1.a > 1;").Check(testkit.Rows( + "HashLeftJoin_6 100000000.00 root left outer join, inner:TableReader_10, left cond:[eq(test.t1.a, test.t1.b) gt(test.t1.a, 1)]", + "├─TableReader_8 10000.00 root data:TableScan_7", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_10 10000.00 root data:TableScan_9", + " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 left join t2 on t2.a = t2.b and t2.a > 1;").Check(testkit.Rows( + "HashLeftJoin_6 26666666.67 root left outer join, inner:TableReader_11", + "├─TableReader_8 10000.00 root data:TableScan_7", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_11 2666.67 root data:Selection_10", + " └─Selection_10 2666.67 cop eq(test.t2.a, test.t2.b), gt(test.t2.a, 1)", + " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + // Constant equal condition merge in outer join. + tk.MustQuery("explain select * from t1 left join t2 on true where t1.a = 1 and false;").Check(testkit.Rows( + "TableDual_8 0.00 root rows:0", + )) + tk.MustQuery("explain select * from t1 left join t2 on true where t1.a = 1 and null;").Check(testkit.Rows( + "TableDual_8 0.00 root rows:0", + )) + tk.MustQuery("explain select * from t1 left join t2 on true where t1.a = null;").Check(testkit.Rows( + "TableDual_8 0.00 root rows:0", + )) + tk.MustQuery("explain select * from t1 left join t2 on true where t1.a = 1 and t1.a = 2;").Check(testkit.Rows( + "TableDual_8 0.00 root rows:0", + )) + tk.MustQuery("explain select * from t1 left join t2 on true where t1.a = 1 and t1.a = 1;").Check(testkit.Rows( + "HashLeftJoin_7 80000.00 root left outer join, inner:TableReader_12", + "├─TableReader_10 10.00 root data:Selection_9", + "│ └─Selection_9 10.00 cop eq(test.t1.a, 1), eq(test.t1.a, 1)", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_12 10000.00 root data:TableScan_11", + " └─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 left join t2 on false;").Check(testkit.Rows( + "HashLeftJoin_6 80000000.00 root left outer join, inner:TableDual_9", + "├─TableReader_8 10000.00 root data:TableScan_7", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableDual_9 8000.00 root rows:0", + )) + tk.MustQuery("explain select * from t1 left join t2 on t1.a =1 and t1.a = 2;").Check(testkit.Rows( + "HashLeftJoin_6 80000000.00 root left outer join, inner:TableDual_9", + "├─TableReader_8 10000.00 root data:TableScan_7", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableDual_9 8000.00 root rows:0", + )) + tk.MustQuery("explain select * from t1 left join t2 on t1.a =1 where t1.a = 2;").Check(testkit.Rows( + "HashLeftJoin_7 80000.00 root left outer join, inner:TableDual_11", + "├─TableReader_10 10.00 root data:Selection_9", + "│ └─Selection_9 10.00 cop eq(test.t1.a, 2)", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableDual_11 8000.00 root rows:0", + )) + tk.MustQuery("explain select * from t1 left join t2 on t2.a = 1 and t2.a = 2;").Check(testkit.Rows( + "HashLeftJoin_6 0.00 root left outer join, inner:TableReader_11", + "├─TableReader_8 10000.00 root data:TableScan_7", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_11 0.00 root data:Selection_10", + " └─Selection_10 0.00 cop eq(test.t2.a, 1), eq(test.t2.a, 2)", + " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + // Constant propagation for DNF in outer join. + tk.MustQuery("explain select * from t1 left join t2 on t1.a = 1 or (t1.a = 2 and t1.a = 3);").Check(testkit.Rows( + "HashLeftJoin_6 100000000.00 root left outer join, inner:TableReader_10, left cond:[or(eq(test.t1.a, 1), 0)]", + "├─TableReader_8 10000.00 root data:TableScan_7", + "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_10 10000.00 root data:TableScan_9", + " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 left join t2 on true where t1.a = 1 or (t1.a = 2 and t1.a = 3);").Check(testkit.Rows( + "HashLeftJoin_7 80000.00 root left outer join, inner:TableReader_12", + "├─TableReader_10 10.00 root data:Selection_9", + "│ └─Selection_9 10.00 cop or(eq(test.t1.a, 1), 0)", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_12 10000.00 root data:TableScan_11", + " └─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + // Constant propagation over left outer semi join, filter with aux column should be be derived. + tk.MustQuery("explain select * from t1 where t1.b > 1 or t1.b in (select b from t2);").Check(testkit.Rows( + "Projection_7 8000.00 root test.t1.id, test.t1.a, test.t1.b", + "└─Selection_8 8000.00 root or(gt(test.t1.b, 1), 5_aux_0)", + " └─HashLeftJoin_9 10000.00 root left outer semi join, inner:TableReader_13, equal:[eq(test.t1.b, test.t2.b)]", + " ├─TableReader_11 10000.00 root data:TableScan_10", + " │ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + " └─TableReader_13 10000.00 root data:TableScan_12", + " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) +} diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 2b4a4c9687273..1d268cb28181a 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -335,7 +335,7 @@ func (s *testPlanSuite) TestPredicatePushDown(c *C) { }, { sql: "select * from t ta left outer join t tb on ta.d = tb.d and ta.a > 1 where ta.d = 0", - best: "Join{DataScan(ta)->DataScan(tb)}(ta.d,tb.d)->Projection", + best: "Join{DataScan(ta)->DataScan(tb)}->Projection", }, { sql: "select * from t ta left outer join t tb on ta.d = tb.d and ta.a > 1 where tb.d = 0", diff --git a/planner/core/rule_predicate_push_down.go b/planner/core/rule_predicate_push_down.go index a22bc51417837..ffcb7467920ca 100644 --- a/planner/core/rule_predicate_push_down.go +++ b/planner/core/rule_predicate_push_down.go @@ -109,6 +109,11 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret var leftPushCond, rightPushCond, otherCond, leftCond, rightCond []expression.Expression switch p.JoinType { case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: + predicates = p.outerJoinPropConst(predicates) + dual := conds2TableDual(p, predicates) + if dual != nil { + return ret, dual + } // Handle where conditions predicates = expression.ExtractFiltersFromDNFs(p.ctx, predicates) // Only derive left where condition, because right where condition cannot be pushed down @@ -121,6 +126,11 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) ret = append(ret, rightPushCond...) case RightOuterJoin: + predicates = p.outerJoinPropConst(predicates) + dual := conds2TableDual(p, predicates) + if dual != nil { + return ret, dual + } // Handle where conditions predicates = expression.ExtractFiltersFromDNFs(p.ctx, predicates) // Only derive right where condition, because left where condition cannot be pushed down @@ -413,3 +423,27 @@ func conds2TableDual(p LogicalPlan, conds []expression.Expression) LogicalPlan { } return nil } + +// outerJoinPropConst propagates constant equal and column equal conditions over outer join. +func (p *LogicalJoin) outerJoinPropConst(predicates []expression.Expression) []expression.Expression { + outerTable := p.children[0] + innerTable := p.children[1] + if p.JoinType == RightOuterJoin { + innerTable, outerTable = outerTable, innerTable + } + lenJoinConds := len(p.EqualConditions) + len(p.LeftConditions) + len(p.RightConditions) + len(p.OtherConditions) + joinConds := make([]expression.Expression, 0, lenJoinConds) + for _, equalCond := range p.EqualConditions { + joinConds = append(joinConds, equalCond) + } + joinConds = append(joinConds, p.LeftConditions...) + joinConds = append(joinConds, p.RightConditions...) + joinConds = append(joinConds, p.OtherConditions...) + p.EqualConditions = nil + p.LeftConditions = nil + p.RightConditions = nil + p.OtherConditions = nil + joinConds, predicates = expression.PropConstOverOuterJoin(p.ctx, joinConds, predicates, outerTable.Schema(), innerTable.Schema()) + p.attachOnConds(joinConds) + return predicates +}