Skip to content

Commit

Permalink
expression: MySQL compatible current_user function (#7801)
Browse files Browse the repository at this point in the history
  • Loading branch information
morgo authored and lysu committed Oct 16, 2018
1 parent 1730798 commit 19e4e2f
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 24 deletions.
4 changes: 1 addition & 3 deletions expression/builtin_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,12 @@ func (b *builtinCurrentUserSig) Clone() builtinFunc {

// evalString evals a builtinCurrentUserSig.
// See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user
// TODO: The value of CURRENT_USER() can differ from the value of USER(). We will finish this after we support grant tables.
func (b *builtinCurrentUserSig) evalString(row chunk.Row) (string, bool, error) {
data := b.ctx.GetSessionVars()
if data == nil || data.User == nil {
return "", true, errors.Errorf("Missing session variable when eval builtin")
}

return data.User.String(), false, nil
return data.User.AuthIdentityString(), false, nil
}

type userFunctionClass struct {
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (s *testEvaluatorSuite) TestCurrentUser(c *C) {
defer testleak.AfterTest(c)()
ctx := mock.NewContext()
sessionVars := ctx.GetSessionVars()
sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost"}
sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost", AuthUsername: "root", AuthHostname: "localhost"}

fc := funcs[ast.CurrentUser]
f, err := fc.getFunction(ctx, nil)
Expand Down
6 changes: 3 additions & 3 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2387,13 +2387,13 @@ func (s *testIntegrationSuite) TestInfoBuiltin(c *C) {
// for current_user
sessionVars := tk.Se.GetSessionVars()
originUser := sessionVars.User
sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost"}
sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost", AuthUsername: "root", AuthHostname: "127.0.%%"}
result = tk.MustQuery("select current_user()")
result.Check(testkit.Rows("root@localhost"))
result.Check(testkit.Rows("root@127.0.%%"))
sessionVars.User = originUser

// for user
sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost"}
sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost", AuthUsername: "root", AuthHostname: "127.0.%%"}
result = tk.MustQuery("select user()")
result.Check(testkit.Rows("root@localhost"))
sessionVars.User = originUser
Expand Down
2 changes: 1 addition & 1 deletion privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type Manager interface {
// this means any privilege would be OK.
RequestVerification(db, table, column string, priv mysql.PrivilegeType) bool
// ConnectionVerification verifies user privilege for connection.
ConnectionVerification(user, host string, auth, salt []byte) bool
ConnectionVerification(user, host string, auth, salt []byte) (string, string, bool)

// DBIsVisible returns true is the database is visible to current user.
DBIsVisible(db string) bool
Expand Down
25 changes: 16 additions & 9 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,50 +70,57 @@ func (p *UserPrivileges) RequestVerification(db, table, column string, priv mysq
}

// ConnectionVerification implements the Manager interface.
func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte) bool {
func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte) (u string, h string, success bool) {

if SkipWithGrant {
p.user = user
p.host = host
return true
success = true
return
}

mysqlPriv := p.Handle.Get()
record := mysqlPriv.connectionVerification(user, host)
if record == nil {
log.Errorf("Get user privilege record fail: user %v, host %v", user, host)
return false
return
}

u = record.User
h = record.Host

pwd := record.Password
if len(pwd) != 0 && len(pwd) != mysql.PWDHashLen+1 {
log.Errorf("User [%s] password from SystemDB not like a sha1sum", user)
return false
return
}

// empty password
if len(pwd) == 0 && len(authentication) == 0 {
p.user = user
p.host = host
return true
success = true
return
}

if len(pwd) == 0 || len(authentication) == 0 {
return false
return
}

hpwd, err := auth.DecodePassword(pwd)
if err != nil {
log.Errorf("Decode password string error %v", err)
return false
return
}

if !auth.CheckScrambledPassword(salt, hpwd, authentication) {
return false
return
}

p.user = user
p.host = host
return true
success = true
return
}

// DBIsVisible implements the Manager interface.
Expand Down
13 changes: 9 additions & 4 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1002,17 +1002,22 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by
pm := privilege.GetPrivilegeManager(s)

// Check IP.
if pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt) {
var success bool
user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt)
if success {
s.sessionVars.User = user
return true
}

// Check Hostname.
for _, addr := range getHostByIP(user.Hostname) {
if pm.ConnectionVerification(user.Username, addr, authentication, salt) {
u, h, success := pm.ConnectionVerification(user.Username, addr, authentication, salt)
if success {
s.sessionVars.User = &auth.UserIdentity{
Username: user.Username,
Hostname: addr,
Username: user.Username,
Hostname: addr,
AuthUsername: u,
AuthHostname: h,
}
return true
}
Expand Down
14 changes: 11 additions & 3 deletions util/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ import (

// UserIdentity represents username and hostname.
type UserIdentity struct {
Username string
Hostname string
CurrentUser bool
Username string
Hostname string
CurrentUser bool
AuthUsername string // Username matched in privileges system
AuthHostname string // Match in privs system (i.e. could be a wildcard)
}

// String converts UserIdentity to the format user@host.
Expand All @@ -36,6 +38,12 @@ func (user *UserIdentity) String() string {
return fmt.Sprintf("%s@%s", user.Username, user.Hostname)
}

// AuthIdentityString returns matched identity in user@host format
func (user *UserIdentity) AuthIdentityString() string {
// TODO: Escape username and hostname.
return fmt.Sprintf("%s@%s", user.AuthUsername, user.AuthHostname)
}

// CheckScrambledPassword check scrambled password received from client.
// The new authentication is performed in following manner:
// SERVER: public_seed=create_random_string()
Expand Down

0 comments on commit 19e4e2f

Please sign in to comment.