Skip to content

Commit

Permalink
Support TLS to connect mysql/TiDB (#894)
Browse files Browse the repository at this point in the history
  • Loading branch information
july2993 committed Mar 15, 2020
1 parent 738f0ca commit 7e30098
Show file tree
Hide file tree
Showing 12 changed files with 61 additions and 17 deletions.
2 changes: 1 addition & 1 deletion arbiter/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func NewServer(cfg *Config) (srv *Server, err error) {
up := cfg.Up
down := cfg.Down

srv.downDB, err = createDB(down.User, down.Password, down.Host, down.Port)
srv.downDB, err = createDB(down.User, down.Password, down.Host, down.Port, nil)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
7 changes: 4 additions & 3 deletions arbiter/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package arbiter

import (
"context"
"crypto/tls"
"database/sql"
"fmt"
"sync"
Expand Down Expand Up @@ -55,7 +56,7 @@ func (l *dummyLoader) Close() {
type testNewServerSuite struct {
db *sql.DB
dbMock sqlmock.Sqlmock
origCreateDB func(string, string, string, int) (*sql.DB, error)
origCreateDB func(string, string, string, int, *tls.Config) (*sql.DB, error)
origNewReader func(*reader.Config) (*reader.Reader, error)
origNewLoader func(*sql.DB, ...loader.Option) (loader.Loader, error)
}
Expand All @@ -71,7 +72,7 @@ func (s *testNewServerSuite) SetUpTest(c *C) {
s.dbMock = mock

s.origCreateDB = createDB
createDB = func(user string, password string, host string, port int) (*sql.DB, error) {
createDB = func(user string, password string, host string, port int, _ *tls.Config) (*sql.DB, error) {
return s.db, nil
}

Expand Down Expand Up @@ -105,7 +106,7 @@ func (s *testNewServerSuite) TestRejectInvalidAddr(c *C) {
}

func (s *testNewServerSuite) TestStopIfFailedtoConnectDownStream(c *C) {
createDB = func(user string, password string, host string, port int) (*sql.DB, error) {
createDB = func(user string, password string, host string, port int, _ *tls.Config) (*sql.DB, error) {
return nil, fmt.Errorf("Can't create db")
}

Expand Down
10 changes: 10 additions & 0 deletions cmd/drainer/drainer.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,16 @@ port = 3306
# when setting SyncPartialColumn drainer will allow the downstream schema
# having more or less column numbers and relax sql mode by removing STRICT_TRANS_TABLES.
# sync-mode = 1
#
# Uncomment this part if you need TLS to connecting downstream MySQL/TiDB.
# You can only specified only `ssl-ca` if there is no client certificate and don't need server to authenticate client.
# [syncer.to.security]
# Path of file that contains list of trusted SSL CAs.
# ssl-ca = "/path/to/ca.pem"
# Path of file that contains X509 certificate in PEM format.
# ssl-cert = "/path/to/drainer.pem"
# Path of file that contains X509 key in PEM format.
# ssl-key = "/path/to/drainer-key.pem"

[syncer.to.checkpoint]
# only support mysql or tidb now, you can uncomment this to control where the checkpoint is saved.
Expand Down
7 changes: 7 additions & 0 deletions drainer/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ func (cfg *Config) Parse(args []string) error {
return errors.Errorf("tls config %+v error %v", cfg.Security, err)
}

if cfg.SyncerCfg != nil && cfg.SyncerCfg.To != nil {
cfg.SyncerCfg.To.TLS, err = cfg.SyncerCfg.To.Security.ToTLSConfig()
if err != nil {
return errors.Errorf("tls config %+v error %v", cfg.SyncerCfg.To.Security, err)
}
}

if err = cfg.adjustConfig(); err != nil {
return errors.Trace(err)
}
Expand Down
2 changes: 1 addition & 1 deletion drainer/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func feedByRelayLogIfNeed(cfg *Config) error {
return errors.Annotate(err, "failed to create reader")
}

db, err := loader.CreateDBWithSQLMode(scfg.To.User, scfg.To.Password, scfg.To.Host, scfg.To.Port, scfg.StrSQLMode)
db, err := loader.CreateDBWithSQLMode(scfg.To.User, scfg.To.Password, scfg.To.Host, scfg.To.Port, scfg.To.TLS, scfg.StrSQLMode)
if err != nil {
return errors.Annotate(err, "failed to create SQL db")
}
Expand Down
8 changes: 6 additions & 2 deletions drainer/sync/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ func NewMysqlSyncer(
relayer relay.Relayer,
info *loopbacksync.LoopBackSync,
) (*MysqlSyncer, error) {
db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, sqlMode)
if cfg.TLS != nil {
log.Info("enable TLS to connect downstream MySQL/TiDB")
}

db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.TLS, sqlMode)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -104,7 +108,7 @@ func NewMysqlSyncer(

if newMode != oldMode {
db.Close()
db, err = createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, &newMode)
db, err = createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.TLS, &newMode)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
3 changes: 2 additions & 1 deletion drainer/sync/syncer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package sync

import (
"crypto/tls"
"database/sql"
"reflect"
"sync/atomic"
Expand Down Expand Up @@ -57,7 +58,7 @@ func (s *syncerSuite) SetUpTest(c *check.C) {

// create mysql syncer
oldCreateDB := createDB
createDB = func(string, string, string, int, *string) (db *sql.DB, err error) {
createDB = func(string, string, string, int, *tls.Config, *string) (db *sql.DB, err error) {
db, s.mysqlMock, err = sqlmock.New()
return
}
Expand Down
11 changes: 8 additions & 3 deletions drainer/sync/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,20 @@
package sync

import (
"crypto/tls"

// mysql driver
_ "github.com/go-sql-driver/mysql"
"github.com/pingcap/tidb-binlog/pkg/security"
)

// DBConfig is the DB configuration.
type DBConfig struct {
Host string `toml:"host" json:"host"`
User string `toml:"user" json:"user"`
Password string `toml:"password" json:"password"`
Host string `toml:"host" json:"host"`
User string `toml:"user" json:"user"`
Password string `toml:"password" json:"password"`
Security security.Config `toml:"security" json:"security"`
TLS *tls.Config `toml:"-" json:"-"`
// if EncryptedPassword is not empty, Password will be ignore.
EncryptedPassword string `toml:"encrypted_password" json:"encrypted_password"`
SyncMode int `toml:"sync-mode" json:"sync-mode"`
Expand Down
2 changes: 1 addition & 1 deletion pkg/loader/example_loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import "log"

func Example() {
// create sql.DB
db, err := CreateDB("root", "", "localhost", 4000)
db, err := CreateDB("root", "", "localhost", 4000, nil /* *tls.Config */)
if err != nil {
log.Fatal(err)
}
Expand Down
21 changes: 18 additions & 3 deletions pkg/loader/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@
package loader

import (
"crypto/tls"
gosql "database/sql"
"fmt"
"hash/crc32"
"net/url"
"strconv"
"strings"
"sync/atomic"

"github.com/go-sql-driver/mysql"
"github.com/pingcap/errors"
)

Expand Down Expand Up @@ -77,14 +81,25 @@ func getTableInfo(db *gosql.DB, schema string, table string) (info *tableInfo, e
return
}

var customID int64

// CreateDBWithSQLMode return sql.DB
func CreateDBWithSQLMode(user string, password string, host string, port int, sqlMode *string) (db *gosql.DB, err error) {
func CreateDBWithSQLMode(user string, password string, host string, port int, tlsConfig *tls.Config, sqlMode *string) (db *gosql.DB, err error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4,utf8&interpolateParams=true&readTimeout=1m&multiStatements=true", user, password, host, port)
if sqlMode != nil {
// same as "set sql_mode = '<sqlMode>'"
dsn += "&sql_mode='" + url.QueryEscape(*sqlMode) + "'"
}

if tlsConfig != nil {
name := "custom_" + strconv.FormatInt(atomic.AddInt64(&customID, 1), 10)
err := mysql.RegisterTLSConfig(name, tlsConfig)
if err != nil {
return nil, errors.Annotate(err, "failed to RegisterTLSConfig")
}
dsn += "&tls=" + name
}

db, err = gosql.Open("mysql", dsn)
if err != nil {
return nil, errors.Trace(err)
Expand All @@ -93,8 +108,8 @@ func CreateDBWithSQLMode(user string, password string, host string, port int, sq
}

// CreateDB return sql.DB
func CreateDB(user string, password string, host string, port int) (db *gosql.DB, err error) {
return CreateDBWithSQLMode(user, password, host, port, nil)
func CreateDB(user string, password string, host string, port int, tls *tls.Config) (db *gosql.DB, err error) {
return CreateDBWithSQLMode(user, password, host, port, tls, nil)
}

func quoteSchema(schema string, table string) string {
Expand Down
2 changes: 1 addition & 1 deletion reparo/syncer/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ var (
var createDB = loader.CreateDB

func newMysqlSyncer(cfg *DBConfig, worker int, batchSize int, safemode bool) (*mysqlSyncer, error) {
db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port)
db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, nil)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
3 changes: 2 additions & 1 deletion reparo/syncer/mysql_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package syncer

import (
"crypto/tls"
"database/sql"
"time"

Expand All @@ -24,7 +25,7 @@ func (s *testMysqlSuite) testMysqlSyncer(c *check.C, safemode bool) {
)

oldCreateDB := createDB
createDB = func(string, string, string, int) (db *sql.DB, err error) {
createDB = func(string, string, string, int, *tls.Config) (db *sql.DB, err error) {
db, mock, err = sqlmock.New()
return
}
Expand Down

0 comments on commit 7e30098

Please sign in to comment.