diff --git a/expression/builtin_info.go b/expression/builtin_info.go index b498e96f289e1..4bb90b8e9901b 100644 --- a/expression/builtin_info.go +++ b/expression/builtin_info.go @@ -148,14 +148,12 @@ func (b *builtinCurrentUserSig) Clone() builtinFunc { // evalString evals a builtinCurrentUserSig. // See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user -// TODO: The value of CURRENT_USER() can differ from the value of USER(). We will finish this after we support grant tables. func (b *builtinCurrentUserSig) evalString(row chunk.Row) (string, bool, error) { data := b.ctx.GetSessionVars() if data == nil || data.User == nil { return "", true, errors.Errorf("Missing session variable when eval builtin") } - - return data.User.String(), false, nil + return data.User.AuthIdentityString(), false, nil } type userFunctionClass struct { diff --git a/expression/builtin_info_test.go b/expression/builtin_info_test.go index c71ce15c88bd5..c4bae7f784ca0 100644 --- a/expression/builtin_info_test.go +++ b/expression/builtin_info_test.go @@ -84,7 +84,7 @@ func (s *testEvaluatorSuite) TestCurrentUser(c *C) { defer testleak.AfterTest(c)() ctx := mock.NewContext() sessionVars := ctx.GetSessionVars() - sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost"} + sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost", AuthUsername: "root", AuthHostname: "localhost"} fc := funcs[ast.CurrentUser] f, err := fc.getFunction(ctx, nil) diff --git a/expression/integration_test.go b/expression/integration_test.go index d2e0dc125767c..20e1a6cec3fcc 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2387,13 +2387,13 @@ func (s *testIntegrationSuite) TestInfoBuiltin(c *C) { // for current_user sessionVars := tk.Se.GetSessionVars() originUser := sessionVars.User - sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost"} + sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost", AuthUsername: "root", AuthHostname: "127.0.%%"} result = tk.MustQuery("select current_user()") - result.Check(testkit.Rows("root@localhost")) + result.Check(testkit.Rows("root@127.0.%%")) sessionVars.User = originUser // for user - sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost"} + sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost", AuthUsername: "root", AuthHostname: "127.0.%%"} result = tk.MustQuery("select user()") result.Check(testkit.Rows("root@localhost")) sessionVars.User = originUser diff --git a/privilege/privilege.go b/privilege/privilege.go index 34c210086c671..e14b95914af84 100644 --- a/privilege/privilege.go +++ b/privilege/privilege.go @@ -38,7 +38,7 @@ type Manager interface { // this means any privilege would be OK. RequestVerification(db, table, column string, priv mysql.PrivilegeType) bool // ConnectionVerification verifies user privilege for connection. - ConnectionVerification(user, host string, auth, salt []byte) bool + ConnectionVerification(user, host string, auth, salt []byte) (string, string, bool) // DBIsVisible returns true is the database is visible to current user. DBIsVisible(db string) bool diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index 2230a44d7be52..6105875910077 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -70,50 +70,57 @@ func (p *UserPrivileges) RequestVerification(db, table, column string, priv mysq } // ConnectionVerification implements the Manager interface. -func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte) bool { +func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte) (u string, h string, success bool) { + if SkipWithGrant { p.user = user p.host = host - return true + success = true + return } mysqlPriv := p.Handle.Get() record := mysqlPriv.connectionVerification(user, host) if record == nil { log.Errorf("Get user privilege record fail: user %v, host %v", user, host) - return false + return } + u = record.User + h = record.Host + pwd := record.Password if len(pwd) != 0 && len(pwd) != mysql.PWDHashLen+1 { log.Errorf("User [%s] password from SystemDB not like a sha1sum", user) - return false + return } // empty password if len(pwd) == 0 && len(authentication) == 0 { p.user = user p.host = host - return true + success = true + return } if len(pwd) == 0 || len(authentication) == 0 { - return false + return } hpwd, err := auth.DecodePassword(pwd) if err != nil { log.Errorf("Decode password string error %v", err) - return false + return } if !auth.CheckScrambledPassword(salt, hpwd, authentication) { - return false + return } p.user = user p.host = host - return true + success = true + return } // DBIsVisible implements the Manager interface. diff --git a/session/session.go b/session/session.go index 716eb8eda60b4..f71ef274f6855 100644 --- a/session/session.go +++ b/session/session.go @@ -1002,17 +1002,22 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by pm := privilege.GetPrivilegeManager(s) // Check IP. - if pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt) { + var success bool + user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt) + if success { s.sessionVars.User = user return true } // Check Hostname. for _, addr := range getHostByIP(user.Hostname) { - if pm.ConnectionVerification(user.Username, addr, authentication, salt) { + u, h, success := pm.ConnectionVerification(user.Username, addr, authentication, salt) + if success { s.sessionVars.User = &auth.UserIdentity{ - Username: user.Username, - Hostname: addr, + Username: user.Username, + Hostname: addr, + AuthUsername: u, + AuthHostname: h, } return true } diff --git a/util/auth/auth.go b/util/auth/auth.go index 0c60b7d67d892..74187d6e0efdb 100644 --- a/util/auth/auth.go +++ b/util/auth/auth.go @@ -25,9 +25,11 @@ import ( // UserIdentity represents username and hostname. type UserIdentity struct { - Username string - Hostname string - CurrentUser bool + Username string + Hostname string + CurrentUser bool + AuthUsername string // Username matched in privileges system + AuthHostname string // Match in privs system (i.e. could be a wildcard) } // String converts UserIdentity to the format user@host. @@ -36,6 +38,12 @@ func (user *UserIdentity) String() string { return fmt.Sprintf("%s@%s", user.Username, user.Hostname) } +// AuthIdentityString returns matched identity in user@host format +func (user *UserIdentity) AuthIdentityString() string { + // TODO: Escape username and hostname. + return fmt.Sprintf("%s@%s", user.AuthUsername, user.AuthHostname) +} + // CheckScrambledPassword check scrambled password received from client. // The new authentication is performed in following manner: // SERVER: public_seed=create_random_string()