diff --git a/pkg/expression/column.go b/pkg/expression/column.go index 68f3eb900d416..b9d883bc61388 100644 --- a/pkg/expression/column.go +++ b/pkg/expression/column.go @@ -36,6 +36,7 @@ import ( var ( _ base.HashEquals = &Column{} + _ base.HashEquals = &CorrelatedColumn{} ) // CorrelatedColumn stands for a column in a correlated sub query. @@ -246,6 +247,31 @@ func (col *CorrelatedColumn) RemapColumn(m map[int64]*Column) (Expression, error }, nil } +// Hash64 implements HashEquals.<0th> interface. +func (col *CorrelatedColumn) Hash64(h base.Hasher) { + // correlatedColumn flag here is used to distinguish correlatedColumn and Column. + h.HashByte(correlatedColumn) + col.Column.Hash64(h) + // since col.Datum is filled in the runtime, we can't use it to calculate hash now, correlatedColumn flag + column is enough. +} + +// Equals implements HashEquals.<1st> interface. +func (col *CorrelatedColumn) Equals(other any) bool { + if other == nil { + return false + } + var col2 *CorrelatedColumn + switch x := other.(type) { + case CorrelatedColumn: + col2 = &x + case *CorrelatedColumn: + col2 = x + default: + return false + } + return col.Column.Equals(&col2.Column) +} + // Column represents a column. type Column struct { RetType *types.FieldType `plan-cache-clone:"shallow"` @@ -458,11 +484,11 @@ func (col *Column) Hash64(h base.Hasher) { h.HashInt64(col.ID) h.HashInt64(col.UniqueID) h.HashInt(col.Index) - if col.VirtualExpr != nil { + if col.VirtualExpr == nil { h.HashByte(base.NilFlag) } else { h.HashByte(base.NotNilFlag) - //col.VirtualExpr.Hash64(h) + col.VirtualExpr.Hash64(h) } h.HashString(col.OrigName) h.HashBool(col.IsHidden) @@ -488,12 +514,12 @@ func (col *Column) Equals(other any) bool { } // when step into here, we could ensure that col1.RetType and col2.RetType are same type. // and we should ensure col1.RetType and col2.RetType is not nil ourselves. - ftEqual := col.RetType == nil && col2.RetType == nil || col.RetType != nil && col2.RetType != nil && col.RetType.Equal(col2.RetType) - return ftEqual && + ok := col.RetType == nil && col2.RetType == nil || col.RetType != nil && col2.RetType != nil && col.RetType.Equal(col2.RetType) + ok = ok && (col.VirtualExpr == nil && col2.VirtualExpr == nil || col.VirtualExpr != nil && col2.VirtualExpr != nil && col.VirtualExpr.Equals(col2.VirtualExpr)) + return ok && col.ID == col2.ID && col.UniqueID == col2.UniqueID && col.Index == col2.Index && - //col.VirtualExpr.Equals(col2.VirtualExpr) && col.OrigName == col2.OrigName && col.IsHidden == col2.IsHidden && col.IsPrefix == col2.IsPrefix && diff --git a/pkg/expression/column_test.go b/pkg/expression/column_test.go index 78dcd8c061443..153cd6f929b43 100644 --- a/pkg/expression/column_test.go +++ b/pkg/expression/column_test.go @@ -415,8 +415,7 @@ func TestColumnHashEquals(t *testing.T) { require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) require.False(t, col1.Equals(col2)) - // diff VirtualExpr - // TODO: add HashEquals for VirtualExpr + // diff VirtualExpr see TestColumnHashEuqals4VirtualExpr // diff OrigName col2.Index = col1.Index @@ -468,3 +467,29 @@ func TestColumnHashEquals(t *testing.T) { require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) require.False(t, col1.Equals(col2)) } + +func TestColumnHashEuqals4VirtualExpr(t *testing.T) { + col1 := &Column{UniqueID: 1, VirtualExpr: NewZero()} + col2 := &Column{UniqueID: 1, VirtualExpr: nil} + hasher1 := base.NewHashEqualer() + hasher2 := base.NewHashEqualer() + col1.Hash64(hasher1) + col2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, col1.Equals(col2)) + + col2.VirtualExpr = NewZero() + hasher2.Reset() + col2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, col1.Equals(col2)) + + col1.VirtualExpr = nil + col2.VirtualExpr = nil + hasher1.Reset() + hasher2.Reset() + col1.Hash64(hasher1) + col2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, col1.Equals(col2)) +} diff --git a/pkg/expression/constant.go b/pkg/expression/constant.go index 210b55186eff8..7b2599247f9c2 100644 --- a/pkg/expression/constant.go +++ b/pkg/expression/constant.go @@ -20,6 +20,7 @@ import ( perrors "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/codec" @@ -29,6 +30,8 @@ import ( "go.uber.org/zap" ) +var _ base.HashEquals = &Constant{} + // NewOne stands for a number 1. func NewOne() *Constant { retT := types.NewFieldType(mysql.TypeTiny) @@ -502,6 +505,50 @@ func (c *Constant) CanonicalHashCode() []byte { return c.getHashCode(true) } +// Hash64 implements HashEquals.<0th> interface. +func (c *Constant) Hash64(h base.Hasher) { + if c.RetType == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + c.RetType.Hash64(h) + } + c.collationInfo.Hash64(h) + if c.DeferredExpr != nil { + c.DeferredExpr.Hash64(h) + return + } + if c.ParamMarker != nil { + h.HashByte(parameterFlag) + h.HashInt64(int64(c.ParamMarker.order)) + return + } + intest.Assert(c.DeferredExpr == nil && c.ParamMarker == nil) + h.HashByte(constantFlag) + c.Value.Hash64(h) +} + +// Equals implements HashEquals.<1st> interface. +func (c *Constant) Equals(other any) bool { + if other == nil { + return false + } + var c2 *Constant + switch x := other.(type) { + case *Constant: + c2 = x + case Constant: + c2 = &x + default: + return false + } + ok := c.RetType == nil && c2.RetType == nil || c.RetType != nil && c2.RetType != nil && c.RetType.Equals(c2.RetType) + ok = ok && c.collationInfo.Equals(c2.collationInfo) + ok = ok && (c.DeferredExpr == nil && c2.DeferredExpr == nil || c.DeferredExpr != nil && c2.DeferredExpr != nil && c.DeferredExpr.Equals(c2.DeferredExpr)) + ok = ok && (c.ParamMarker == nil && c2.ParamMarker == nil || c.ParamMarker != nil && c2.ParamMarker != nil && c.ParamMarker.order == c2.ParamMarker.order) + return ok && c.Value.Equals(c2.Value) +} + func (c *Constant) getHashCode(canonical bool) []byte { if len(c.hashcode) > 0 { return c.hashcode diff --git a/pkg/expression/constant_test.go b/pkg/expression/constant_test.go index 8be33601bae5d..b338d6238c1a2 100644 --- a/pkg/expression/constant_test.go +++ b/pkg/expression/constant_test.go @@ -25,6 +25,7 @@ import ( exprctx "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/mock" @@ -545,3 +546,30 @@ func TestSpecificConstant(t *testing.T) { require.Equal(t, null.RetType.GetFlen(), 1) require.Equal(t, null.RetType.GetDecimal(), 0) } + +func TestConstantHashEquals(t *testing.T) { + // Test for Hash64 interface + cst1 := &Constant{Value: types.NewIntDatum(2333), RetType: newIntFieldType()} + cst2 := &Constant{Value: types.NewIntDatum(2333), RetType: newIntFieldType()} + hasher1 := base.NewHashEqualer() + hasher2 := base.NewHashEqualer() + cst1.Hash64(hasher1) + cst2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, cst1.Equals(cst2)) + + // test cst2 datum changes. + cst2.Value = types.NewIntDatum(2334) + hasher2.Reset() + cst2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, cst1.Equals(cst2)) + + // test cst2 type changes. + cst2.Value = types.NewIntDatum(2333) + cst2.RetType = newStringFieldType() + hasher2.Reset() + cst2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, cst1.Equals(cst2)) +} diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index 07f83c1f06609..799dd58da850c 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/opcode" "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/generatedexpr" @@ -42,6 +43,7 @@ const ( scalarFunctionFlag byte = 3 parameterFlag byte = 4 ScalarSubQFlag byte = 5 + correlatedColumn byte = 6 ) // EvalSimpleAst evaluates a simple ast expression directly. @@ -170,6 +172,7 @@ const ( type Expression interface { VecExpr CollationInfo + base.HashEquals Traverse(TraverseAction) Expression diff --git a/pkg/expression/scalar_function.go b/pkg/expression/scalar_function.go index 3ec1c6cb7f8cd..f36581eabbcd1 100644 --- a/pkg/expression/scalar_function.go +++ b/pkg/expression/scalar_function.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -35,6 +36,8 @@ import ( "github.com/pingcap/tidb/pkg/util/intest" ) +var _ base.HashEquals = &ScalarFunction{} + // ScalarFunction is the function that returns a value. type ScalarFunction struct { FuncName model.CIStr @@ -673,6 +676,51 @@ func simpleCanonicalizedHashCode(sf *ScalarFunction) { } } +// Hash64 implements HashEquals.<0th> interface. +func (sf *ScalarFunction) Hash64(h base.Hasher) { + h.HashByte(scalarFunctionFlag) + h.HashString(sf.FuncName.L) + if sf.RetType == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + sf.RetType.Hash64(h) + } + // hash the arg length to avoid hash collision. + h.HashInt(len(sf.GetArgs())) + for _, arg := range sf.GetArgs() { + arg.Hash64(h) + } +} + +// Equals implements HashEquals.<1th> interface. +func (sf *ScalarFunction) Equals(other any) bool { + if other == nil { + return false + } + var sf2 *ScalarFunction + switch x := other.(type) { + case *ScalarFunction: + sf2 = x + case ScalarFunction: + sf2 = &x + default: + return false + } + ok := sf.FuncName.L == sf2.FuncName.L + ok = ok && (sf.RetType == nil && sf2.RetType == nil || sf.RetType != nil && sf2.RetType != nil && sf.RetType.Equals(sf2.RetType)) + if len(sf.GetArgs()) != len(sf2.GetArgs()) { + return false + } + for i, arg := range sf.GetArgs() { + ok = ok && arg.Equals(sf2.GetArgs()[i]) + if !ok { + return false + } + } + return ok +} + // ReHashCode is used after we change the argument in place. func ReHashCode(sf *ScalarFunction) { sf.hashcode = sf.hashcode[:0] diff --git a/pkg/expression/scalar_function_test.go b/pkg/expression/scalar_function_test.go index 810adbb3b569a..eb1920eacb6a2 100644 --- a/pkg/expression/scalar_function_test.go +++ b/pkg/expression/scalar_function_test.go @@ -19,6 +19,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/mock" @@ -147,3 +148,40 @@ func TestScalarFuncs2Exprs(t *testing.T) { require.True(t, exprs[i].Equal(ctx, funcs[i])) } } + +func TestScalarFunctionHash64Equals(t *testing.T) { + a := &Column{ + UniqueID: 1, + RetType: types.NewFieldType(mysql.TypeDouble), + } + sf0, _ := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction) + sf1, _ := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction) + hasher1 := base.NewHashEqualer() + hasher2 := base.NewHashEqualer() + sf0.Hash64(hasher1) + sf1.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, sf0.Equals(sf1)) + + // change the func name + sf2, _ := newFunctionWithMockCtx(ast.GT, a, NewZero()).(*ScalarFunction) + hasher2.Reset() + sf2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, sf0.Equals(sf2)) + + // change the args + sf3, _ := newFunctionWithMockCtx(ast.LT, a, NewOne()).(*ScalarFunction) + hasher2.Reset() + sf3.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, sf0.Equals(sf3)) + + // change the ret type + sf4, _ := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction) + sf4.RetType = types.NewFieldType(mysql.TypeLong) + hasher2.Reset() + sf4.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, sf0.Equals(sf4)) +} diff --git a/pkg/expression/util_test.go b/pkg/expression/util_test.go index f236e943a9d0b..1329149b60b3a 100644 --- a/pkg/expression/util_test.go +++ b/pkg/expression/util_test.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -661,3 +662,5 @@ func (m *MockExpr) MemoryUsage() (sum int64) { func (m *MockExpr) Traverse(action TraverseAction) Expression { return action.Transform(m) } +func (m *MockExpr) Hash64(_ base.Hasher) {} +func (m *MockExpr) Equals(_ any) bool { return false } diff --git a/pkg/planner/core/BUILD.bazel b/pkg/planner/core/BUILD.bazel index 6a7fa154f4ae5..b511fdfbde465 100644 --- a/pkg/planner/core/BUILD.bazel +++ b/pkg/planner/core/BUILD.bazel @@ -126,6 +126,7 @@ go_library( "//pkg/parser/terror", "//pkg/parser/types", "//pkg/planner/cardinality", + "//pkg/planner/cascades/base", "//pkg/planner/context", "//pkg/planner/core/base", "//pkg/planner/core/constraint", diff --git a/pkg/planner/core/scalar_subq_expression.go b/pkg/planner/core/scalar_subq_expression.go index 17dd07fdc1f1e..2a7f2bf16ecee 100644 --- a/pkg/planner/core/scalar_subq_expression.go +++ b/pkg/planner/core/scalar_subq_expression.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/infoschema" + base2 "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/planner/core/base" "github.com/pingcap/tidb/pkg/planner/core/operator/baseimpl" "github.com/pingcap/tidb/pkg/types" @@ -224,6 +225,29 @@ func (s *ScalarSubQueryExpr) ExplainNormalizedInfo() string { return s.String() } +// Hash64 implements the HashEquals.<0th> interface. +func (s *ScalarSubQueryExpr) Hash64(h base2.Hasher) { + h.HashByte(expression.ScalarSubQFlag) + h.HashInt64(s.scalarSubqueryColID) +} + +// Equals implements the HashEquals.<1st> interface. +func (s *ScalarSubQueryExpr) Equals(other any) bool { + if other == nil { + return false + } + var s2 *ScalarSubQueryExpr + switch x := other.(type) { + case *ScalarSubQueryExpr: + s2 = x + case ScalarSubQueryExpr: + s2 = &x + default: + return false + } + return s.scalarSubqueryColID == s2.scalarSubqueryColID +} + // HashCode implements the Expression interface. func (s *ScalarSubQueryExpr) HashCode() []byte { if len(s.hashcode) != 0 {