Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: introduce hashEquals interface for expression.Expression #55793

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions pkg/expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (

// Compile-time assertions that both column kinds implement base.HashEquals.
var (
	_ base.HashEquals = (*Column)(nil)
	_ base.HashEquals = (*CorrelatedColumn)(nil)
)

// CorrelatedColumn stands for a column in a correlated sub query.
Expand Down Expand Up @@ -246,6 +247,31 @@ func (col *CorrelatedColumn) RemapColumn(m map[int64]*Column) (Expression, error
}, nil
}

// Hash64 implements HashEquals.<0th> interface.
// It writes the correlatedColumn tag byte into h and then the hash of the
// embedded Column.
func (col *CorrelatedColumn) Hash64(h base.Hasher) {
	// The tag byte is what tells a CorrelatedColumn apart from a plain Column.
	h.HashByte(correlatedColumn)
	// The datum is only filled in at runtime, so it cannot take part in the
	// hash here; tag + column is sufficient.
	inner := &col.Column
	inner.Hash64(h)
}

// Equals implements HashEquals.<1st> interface.
// It reports whether other is a CorrelatedColumn (value or pointer) whose
// embedded Column equals col's embedded Column. An untyped nil or any other
// type yields false.
func (col *CorrelatedColumn) Equals(other any) bool {
if other == nil {
return false
}
var col2 *CorrelatedColumn
// Accept both the value form and the pointer form of CorrelatedColumn.
switch x := other.(type) {
case CorrelatedColumn:
col2 = &x
case *CorrelatedColumn:
// NOTE(review): a typed-nil *CorrelatedColumn is not caught by the
// `other == nil` check above and would be dereferenced below — confirm
// callers never pass one.
col2 = x
default:
return false
}
// Only the embedded Column is compared; no other CorrelatedColumn field is
// consulted here.
return col.Column.Equals(&col2.Column)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't need to check Data in CorrelatedColumn?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

runtime bound data shouldn‘t be cared in planner phase

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

}

// Column represents a column.
type Column struct {
RetType *types.FieldType `plan-cache-clone:"shallow"`
Expand Down Expand Up @@ -458,11 +484,11 @@ func (col *Column) Hash64(h base.Hasher) {
h.HashInt64(col.ID)
h.HashInt64(col.UniqueID)
h.HashInt(col.Index)
if col.VirtualExpr != nil {
if col.VirtualExpr == nil {
h.HashByte(base.NilFlag)
} else {
h.HashByte(base.NotNilFlag)
//col.VirtualExpr.Hash64(h)
col.VirtualExpr.Hash64(h)
}
h.HashString(col.OrigName)
h.HashBool(col.IsHidden)
Expand All @@ -488,12 +514,12 @@ func (col *Column) Equals(other any) bool {
}
// when step into here, we could ensure that col1.RetType and col2.RetType are same type.
// and we should ensure col1.RetType and col2.RetType is not nil ourselves.
ftEqual := col.RetType == nil && col2.RetType == nil || col.RetType != nil && col2.RetType != nil && col.RetType.Equal(col2.RetType)
return ftEqual &&
ok := col.RetType == nil && col2.RetType == nil || col.RetType != nil && col2.RetType != nil && col.RetType.Equal(col2.RetType)
ok = ok && (col.VirtualExpr == nil && col2.VirtualExpr == nil || col.VirtualExpr != nil && col2.VirtualExpr != nil && col.VirtualExpr.Equals(col2.VirtualExpr))
return ok &&
col.ID == col2.ID &&
col.UniqueID == col2.UniqueID &&
col.Index == col2.Index &&
//col.VirtualExpr.Equals(col2.VirtualExpr) &&
col.OrigName == col2.OrigName &&
col.IsHidden == col2.IsHidden &&
col.IsPrefix == col2.IsPrefix &&
Expand Down
29 changes: 27 additions & 2 deletions pkg/expression/column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,7 @@ func TestColumnHashEquals(t *testing.T) {
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
require.False(t, col1.Equals(col2))

// diff VirtualExpr
// TODO: add HashEquals for VirtualExpr
// diff VirtualExpr see TestColumnHashEuqals4VirtualExpr

// diff OrigName
col2.Index = col1.Index
Expand Down Expand Up @@ -467,3 +466,29 @@ func TestColumnHashEquals(t *testing.T) {
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
require.False(t, col1.Equals(col2))
}

// TestColumnHashEuqals4VirtualExpr checks that Column.Hash64 and Column.Equals
// react to VirtualExpr in three situations: set on one side only, set on both
// sides, and nil on both sides.
func TestColumnHashEuqals4VirtualExpr(t *testing.T) {
	left := &Column{UniqueID: 1, VirtualExpr: NewZero()}
	right := &Column{UniqueID: 1, VirtualExpr: nil}
	leftHasher, rightHasher := base.NewHashEqualer(), base.NewHashEqualer()

	// Only the left column carries a virtual expression.
	left.Hash64(leftHasher)
	right.Hash64(rightHasher)
	require.NotEqual(t, leftHasher.Sum64(), rightHasher.Sum64())
	require.False(t, left.Equals(right))

	// Both columns carry a zero-constant virtual expression.
	right.VirtualExpr = NewZero()
	rightHasher.Reset()
	right.Hash64(rightHasher)
	require.Equal(t, leftHasher.Sum64(), rightHasher.Sum64())
	require.True(t, left.Equals(right))

	// Neither column carries a virtual expression.
	left.VirtualExpr, right.VirtualExpr = nil, nil
	leftHasher.Reset()
	rightHasher.Reset()
	left.Hash64(leftHasher)
	right.Hash64(rightHasher)
	require.Equal(t, leftHasher.Sum64(), rightHasher.Sum64())
	require.True(t, left.Equals(right))
}
47 changes: 47 additions & 0 deletions pkg/expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

perrors "github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/codec"
Expand All @@ -29,6 +30,8 @@ import (
"go.uber.org/zap"
)

// Compile-time assertion that *Constant implements base.HashEquals.
var _ base.HashEquals = (*Constant)(nil)

// NewOne stands for a number 1.
func NewOne() *Constant {
retT := types.NewFieldType(mysql.TypeTiny)
Expand Down Expand Up @@ -502,6 +505,50 @@ func (c *Constant) CanonicalHashCode() []byte {
return c.getHashCode(true)
}

// Hash64 implements HashEquals.<0th> interface.
// The return type (or a nil marker) and the collation info are always hashed.
// After that exactly one of three payloads is hashed, in priority order:
// the deferred expression, the parameter marker's order, or the constant value.
func (c *Constant) Hash64(h base.Hasher) {
	if c.RetType != nil {
		h.HashByte(base.NotNilFlag)
		c.RetType.Hash64(h)
	} else {
		h.HashByte(base.NilFlag)
	}
	c.collationInfo.Hash64(h)

	switch {
	case c.DeferredExpr != nil:
		// A deferred constant is identified by the expression that produces it.
		c.DeferredExpr.Hash64(h)
	case c.ParamMarker != nil:
		// A parameter is identified by its flag plus its order.
		h.HashByte(parameterFlag)
		h.HashInt64(int64(c.ParamMarker.order))
	default:
		intest.Assert(c.DeferredExpr == nil && c.ParamMarker == nil)
		h.HashByte(constantFlag)
		c.Value.Hash64(h)
	}
}

// Equals implements HashEquals.<1st> interface.
// It reports whether other is a Constant (value or non-nil pointer) that
// matches c on, in order: return type, collation info, deferred expression,
// parameter-marker order and value. An untyped nil, a typed-nil *Constant or
// any other type yields false.
func (c *Constant) Equals(other any) bool {
	if other == nil {
		return false
	}
	var c2 *Constant
	switch x := other.(type) {
	case *Constant:
		// An interface holding a nil *Constant is not == nil, so it passes the
		// check above; reject it here instead of dereferencing it below.
		if x == nil {
			return false
		}
		c2 = x
	case Constant:
		c2 = &x
	default:
		return false
	}

	// Return types: both nil, or both set and equal.
	if (c.RetType == nil) != (c2.RetType == nil) {
		return false
	}
	if c.RetType != nil && !c.RetType.Equals(c2.RetType) {
		return false
	}
	if !c.collationInfo.Equals(c2.collationInfo) {
		return false
	}
	// Deferred expressions: both nil, or both set and equal.
	if (c.DeferredExpr == nil) != (c2.DeferredExpr == nil) {
		return false
	}
	if c.DeferredExpr != nil && !c.DeferredExpr.Equals(c2.DeferredExpr) {
		return false
	}
	// Parameter markers: both nil, or both set with the same order.
	if (c.ParamMarker == nil) != (c2.ParamMarker == nil) {
		return false
	}
	if c.ParamMarker != nil && c.ParamMarker.order != c2.ParamMarker.order {
		return false
	}
	return c.Value.Equals(c2.Value)
}

func (c *Constant) getHashCode(canonical bool) []byte {
if len(c.hashcode) > 0 {
return c.hashcode
Expand Down
28 changes: 28 additions & 0 deletions pkg/expression/constant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
exprctx "github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/mock"
Expand Down Expand Up @@ -545,3 +546,30 @@ func TestSpecificConstant(t *testing.T) {
require.Equal(t, null.RetType.GetFlen(), 1)
require.Equal(t, null.RetType.GetDecimal(), 0)
}

// TestConstantHashEquals checks that Constant.Hash64 and Constant.Equals agree
// for identical constants and both change when the datum or the return type
// differs.
func TestConstantHashEquals(t *testing.T) {
	// Two constants built the same way must hash and compare equal.
	base1 := &Constant{Value: types.NewIntDatum(2333), RetType: newIntFieldType()}
	probe := &Constant{Value: types.NewIntDatum(2333), RetType: newIntFieldType()}
	baseHasher, probeHasher := base.NewHashEqualer(), base.NewHashEqualer()
	base1.Hash64(baseHasher)
	probe.Hash64(probeHasher)
	require.Equal(t, baseHasher.Sum64(), probeHasher.Sum64())
	require.True(t, base1.Equals(probe))

	// A different datum must be detected.
	probe.Value = types.NewIntDatum(2334)
	probeHasher.Reset()
	probe.Hash64(probeHasher)
	require.NotEqual(t, baseHasher.Sum64(), probeHasher.Sum64())
	require.False(t, base1.Equals(probe))

	// Same datum again, but a different return type must be detected.
	probe.Value = types.NewIntDatum(2333)
	probe.RetType = newStringFieldType()
	probeHasher.Reset()
	probe.Hash64(probeHasher)
	require.NotEqual(t, baseHasher.Sum64(), probeHasher.Sum64())
	require.False(t, base1.Equals(probe))
}
3 changes: 3 additions & 0 deletions pkg/expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/opcode"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/generatedexpr"
Expand All @@ -41,6 +42,7 @@ const (
scalarFunctionFlag byte = 3
parameterFlag byte = 4
ScalarSubQFlag byte = 5
correlatedColumn byte = 6
)

// EvalSimpleAst evaluates a simple ast expression directly.
Expand Down Expand Up @@ -169,6 +171,7 @@ const (
type Expression interface {
VecExpr
CollationInfo
base.HashEquals

Traverse(TraverseAction) Expression

Expand Down
48 changes: 48 additions & 0 deletions pkg/expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand All @@ -35,6 +36,8 @@ import (
"github.com/pingcap/tidb/pkg/util/intest"
)

// Compile-time assertion that *ScalarFunction implements base.HashEquals.
var _ base.HashEquals = (*ScalarFunction)(nil)

// ScalarFunction is the function that returns a value.
type ScalarFunction struct {
FuncName model.CIStr
Expand Down Expand Up @@ -673,6 +676,51 @@ func simpleCanonicalizedHashCode(sf *ScalarFunction) {
}
}

// Hash64 implements HashEquals.<0th> interface.
// It hashes, in order: the scalar-function flag, the lower-cased function
// name, the return type (or a nil marker), the argument count, and every
// argument.
func (sf *ScalarFunction) Hash64(h base.Hasher) {
	h.HashByte(scalarFunctionFlag)
	h.HashString(sf.FuncName.L)
	if sf.RetType != nil {
		h.HashByte(base.NotNilFlag)
		sf.RetType.Hash64(h)
	} else {
		h.HashByte(base.NilFlag)
	}
	args := sf.GetArgs()
	// The argument count goes in first to avoid hash collisions.
	h.HashInt(len(args))
	for i := range args {
		args[i].Hash64(h)
	}
}

// Equals implements HashEquals.<1st> interface.
// It reports whether other is a ScalarFunction (value or pointer) with the
// same lower-cased function name, an equal return type, and pairwise-equal
// arguments. An untyped nil or any other type yields false.
func (sf *ScalarFunction) Equals(other any) bool {
if other == nil {
return false
}
var sf2 *ScalarFunction
// Accept both the pointer form and the value form of ScalarFunction.
switch x := other.(type) {
case *ScalarFunction:
// NOTE(review): a typed-nil *ScalarFunction is not caught by the
// `other == nil` check above and would be dereferenced below — confirm
// callers never pass one.
sf2 = x
case ScalarFunction:
sf2 = &x
default:
return false
}
ok := sf.FuncName.L == sf2.FuncName.L
// Return types match when both are nil, or both are set and Equals agrees.
ok = ok && (sf.RetType == nil && sf2.RetType == nil || sf.RetType != nil && sf2.RetType != nil && sf.RetType.Equals(sf2.RetType))
// Differing argument counts can never be equal; this also makes the indexed
// access into sf2's arguments below safe.
if len(sf.GetArgs()) != len(sf2.GetArgs()) {
return false
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so if ok is false, we can return first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sense

// If ok is already false, the && short-circuits and the first iteration
// returns false without comparing any argument.
for i, arg := range sf.GetArgs() {
ok = ok && arg.Equals(sf2.GetArgs()[i])
if !ok {
return false
}
}
return ok
}

// ReHashCode is used after we change the argument in place.
func ReHashCode(sf *ScalarFunction) {
sf.hashcode = sf.hashcode[:0]
Expand Down
38 changes: 38 additions & 0 deletions pkg/expression/scalar_function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/mock"
Expand Down Expand Up @@ -147,3 +148,40 @@ func TestScalarFuncs2Exprs(t *testing.T) {
require.True(t, exprs[i].Equal(ctx, funcs[i]))
}
}

// TestScalarFunctionHash64Equals checks that ScalarFunction.Hash64 and
// ScalarFunction.Equals agree for identically built functions and both change
// when the function name, an argument, or the return type differs.
func TestScalarFunctionHash64Equals(t *testing.T) {
	a := &Column{
		UniqueID: 1,
		RetType:  types.NewFieldType(mysql.TypeDouble),
	}
	// Each type assertion is checked so that an unexpected expression kind
	// fails the test here instead of surfacing later as a nil dereference.
	sf0, ok := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction)
	require.True(t, ok)
	sf1, ok := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction)
	require.True(t, ok)
	hasher1 := base.NewHashEqualer()
	hasher2 := base.NewHashEqualer()
	sf0.Hash64(hasher1)
	sf1.Hash64(hasher2)
	require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
	require.True(t, sf0.Equals(sf1))

	// change the func name
	sf2, ok := newFunctionWithMockCtx(ast.GT, a, NewZero()).(*ScalarFunction)
	require.True(t, ok)
	hasher2.Reset()
	sf2.Hash64(hasher2)
	require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
	require.False(t, sf0.Equals(sf2))

	// change the args
	sf3, ok := newFunctionWithMockCtx(ast.LT, a, NewOne()).(*ScalarFunction)
	require.True(t, ok)
	hasher2.Reset()
	sf3.Hash64(hasher2)
	require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
	require.False(t, sf0.Equals(sf3))

	// change the ret type
	sf4, ok := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction)
	require.True(t, ok)
	sf4.RetType = types.NewFieldType(mysql.TypeLong)
	hasher2.Reset()
	sf4.Hash64(hasher2)
	require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
	require.False(t, sf0.Equals(sf4))
}
3 changes: 3 additions & 0 deletions pkg/expression/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand Down Expand Up @@ -661,3 +662,5 @@ func (m *MockExpr) MemoryUsage() (sum int64) {
func (m *MockExpr) Traverse(action TraverseAction) Expression {
return action.Transform(m)
}
// Hash64 implements base.HashEquals for the mock; it writes nothing to the hasher.
func (m *MockExpr) Hash64(_ base.Hasher) {
}
// Equals implements base.HashEquals for the mock; it unconditionally reports false.
func (m *MockExpr) Equals(_ any) bool {
	return false
}
1 change: 1 addition & 0 deletions pkg/planner/core/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ go_library(
"//pkg/parser/terror",
"//pkg/parser/types",
"//pkg/planner/cardinality",
"//pkg/planner/cascades/base",
"//pkg/planner/context",
"//pkg/planner/core/base",
"//pkg/planner/core/constraint",
Expand Down
Loading