Skip to content

Commit

Permalink
support TLS for components and downstream db (#931)
Browse files Browse the repository at this point in the history
* support TLS for components (#904)


[relate docs](https://pingcap.com/docs/stable/how-to/secure/enable-tls-between-components/) ([Chinese version](https://pingcap.com/docs-cn/stable/how-to/secure/enable-tls-between-components/))
This Commit:
- properly handle things about TLS when enabling TLS
- enable TLS in the integration tests
- log pump config at startup time
  • Loading branch information
july2993 authored Mar 17, 2020
1 parent e4f2ba2 commit b9a8759
Show file tree
Hide file tree
Showing 35 changed files with 484 additions and 210 deletions.
2 changes: 1 addition & 1 deletion arbiter/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func NewServer(cfg *Config) (srv *Server, err error) {
up := cfg.Up
down := cfg.Down

srv.downDB, err = createDB(down.User, down.Password, down.Host, down.Port)
srv.downDB, err = createDB(down.User, down.Password, down.Host, down.Port, nil)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
7 changes: 4 additions & 3 deletions arbiter/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package arbiter

import (
"context"
"crypto/tls"
"database/sql"
"fmt"
"sync"
Expand Down Expand Up @@ -55,7 +56,7 @@ func (l *dummyLoader) Close() {
type testNewServerSuite struct {
db *sql.DB
dbMock sqlmock.Sqlmock
origCreateDB func(string, string, string, int) (*sql.DB, error)
origCreateDB func(string, string, string, int, *tls.Config) (*sql.DB, error)
origNewReader func(*reader.Config) (*reader.Reader, error)
origNewLoader func(*sql.DB, ...loader.Option) (loader.Loader, error)
}
Expand All @@ -71,7 +72,7 @@ func (s *testNewServerSuite) SetUpTest(c *C) {
s.dbMock = mock

s.origCreateDB = createDB
createDB = func(user string, password string, host string, port int) (*sql.DB, error) {
createDB = func(user string, password string, host string, port int, _ *tls.Config) (*sql.DB, error) {
return s.db, nil
}

Expand Down Expand Up @@ -105,7 +106,7 @@ func (s *testNewServerSuite) TestRejectInvalidAddr(c *C) {
}

func (s *testNewServerSuite) TestStopIfFailedtoConnectDownStream(c *C) {
createDB = func(user string, password string, host string, port int) (*sql.DB, error) {
createDB = func(user string, password string, host string, port int, _ *tls.Config) (*sql.DB, error) {
return nil, fmt.Errorf("Can't create db")
}

Expand Down
30 changes: 15 additions & 15 deletions binlogctl/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,20 @@ const (

// Config holds the configuration of drainer
type Config struct {
*flag.FlagSet

Command string `toml:"cmd" json:"cmd"`
NodeID string `toml:"node-id" json:"node-id"`
DataDir string `toml:"data-dir" json:"data-dir"`
TimeZone string `toml:"time-zone" json:"time-zone"`
EtcdURLs string `toml:"pd-urls" json:"pd-urls"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
State string `toml:"state" json:"state"`
ShowOfflineNodes bool `toml:"state" json:"show-offline-nodes"`
Text string `toml:"text" json:"text"`
tls *tls.Config
*flag.FlagSet `toml:"-" json:"-"`

Command string `toml:"cmd" json:"cmd"`
NodeID string `toml:"node-id" json:"node-id"`
DataDir string `toml:"data-dir" json:"data-dir"`
TimeZone string `toml:"time-zone" json:"time-zone"`
EtcdURLs string `toml:"pd-urls" json:"pd-urls"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
State string `toml:"state" json:"state"`
ShowOfflineNodes bool `toml:"state" json:"show-offline-nodes"`
Text string `toml:"text" json:"text"`
TLS *tls.Config `toml:"-" json:"-"`
printVersion bool
}

Expand Down Expand Up @@ -134,7 +134,7 @@ func (cfg *Config) Parse(args []string) error {
SSLCert: cfg.SSLCert,
SSLKey: cfg.SSLKey,
}
cfg.tls, err = sCfg.ToTLSConfig()
cfg.TLS, err = sCfg.ToTLSConfig()
if err != nil {
return errors.Errorf("tls config error %v", err)
}
Expand Down
39 changes: 28 additions & 11 deletions binlogctl/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package binlogctl

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"time"
Expand All @@ -34,8 +35,8 @@ var (
)

// QueryNodesByKind returns specified nodes, like pumps/drainers
func QueryNodesByKind(urls string, kind string, showOffline bool) error {
registry, err := createRegistryFuc(urls)
func QueryNodesByKind(urls string, kind string, showOffline bool, tlsConfig *tls.Config) error {
registry, err := createRegistryFuc(urls, tlsConfig)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -56,12 +57,12 @@ func QueryNodesByKind(urls string, kind string, showOffline bool) error {
}

// UpdateNodeState update pump or drainer's state.
func UpdateNodeState(urls, kind, nodeID, state string) error {
func UpdateNodeState(urls, kind, nodeID, state string, tlsConfig *tls.Config) error {
/*
node's state can be online, pausing, paused, closing and offline.
if the state is one of them, will update the node's state saved in etcd directly.
*/
registry, err := createRegistryFuc(urls)
registry, err := createRegistryFuc(urls, tlsConfig)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -81,12 +82,12 @@ func UpdateNodeState(urls, kind, nodeID, state string) error {
}

// createRegistry returns an ectd registry
func createRegistry(urls string) (*node.EtcdRegistry, error) {
func createRegistry(urls string, tlsConfig *tls.Config) (*node.EtcdRegistry, error) {
ectdEndpoints, err := flags.ParseHostPortAddr(urls)
if err != nil {
return nil, errors.Trace(err)
}
cli, err := newEtcdClientFromCfgFunc(ectdEndpoints, etcdDialTimeout, node.DefaultRootPath, nil)
cli, err := newEtcdClientFromCfgFunc(ectdEndpoints, etcdDialTimeout, node.DefaultRootPath, tlsConfig)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -95,8 +96,8 @@ func createRegistry(urls string) (*node.EtcdRegistry, error) {
}

// ApplyAction applies action on pump or drainer
func ApplyAction(urls, kind, nodeID string, action string) error {
registry, err := createRegistryFuc(urls)
func ApplyAction(urls, kind, nodeID string, action string, tlsConfig *tls.Config) error {
registry, err := createRegistryFuc(urls, tlsConfig)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -106,18 +107,34 @@ func ApplyAction(urls, kind, nodeID string, action string) error {
return errors.Trace(err)
}

var client http.Client
url := fmt.Sprintf("http://%s/state/%s/%s", n.Addr, n.NodeID, action)
schema := "http"
if tlsConfig != nil {
schema = "https"
}

url := fmt.Sprintf("%s://%s/state/%s/%s", schema, n.Addr, n.NodeID, action)
log.Debug("send put http request", zap.String("url", url))
req, err := http.NewRequest("PUT", url, nil)
if err != nil {
return errors.Trace(err)
}
_, err = client.Do(req)
_, err = getClient(tlsConfig).Do(req)
if err == nil {
log.Info("Apply action on node success", zap.String("action", action), zap.String("NodeID", n.NodeID))
return nil
}

return errors.Trace(err)
}

func getClient(tlsConfig *tls.Config) *http.Client {
if tlsConfig == nil {
return &http.Client{}
}

return &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
}
}
16 changes: 8 additions & 8 deletions binlogctl/nodes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type testNodesSuite struct{}
func (s *testNodesSuite) SetUpTest(c *C) {
newEtcdClientFromCfgFunc = newFakeEtcdClientFromCfg
createRegistryFuc = createMockRegistry
_, err := createMockRegistry("127.0.0.1:2379")
_, err := createMockRegistry("127.0.0.1:2379", nil)
c.Assert(err, IsNil)
}

Expand All @@ -63,29 +63,29 @@ func (s *testNodesSuite) TestApplyAction(c *C) {

registerPumpForTest(c, "test", url)

err := ApplyAction("127.0.0.1:2379", "pumps", "test2", PausePump)
err := ApplyAction("127.0.0.1:2379", "pumps", "test2", PausePump, nil)
c.Assert(errors.IsNotFound(err), IsTrue)

// TODO: handle log information and add check
err = ApplyAction("127.0.0.1:2379", "pumps", "test", PausePump)
err = ApplyAction("127.0.0.1:2379", "pumps", "test", PausePump, nil)
c.Assert(err, IsNil)
}

func (s *testNodesSuite) TestQueryNodesByKind(c *C) {
registerPumpForTest(c, "test", "127.0.0.1:8255")

// TODO: handle log information and add check
err := QueryNodesByKind("127.0.0.1:2379", "pumps", false)
err := QueryNodesByKind("127.0.0.1:2379", "pumps", false, nil)
c.Assert(err, IsNil)
}

func (s *testNodesSuite) TestUpdateNodeState(c *C) {
registerPumpForTest(c, "test", "127.0.0.1:8255")

err := UpdateNodeState("127.0.0.1:2379", "pumps", "test2", node.Paused)
err := UpdateNodeState("127.0.0.1:2379", "pumps", "test2", node.Paused, nil)
c.Assert(err, ErrorMatches, ".*not found.*")

err = UpdateNodeState("127.0.0.1:2379", "pumps", "test", node.Paused)
err = UpdateNodeState("127.0.0.1:2379", "pumps", "test", node.Paused, nil)
c.Assert(err, IsNil)

// check node's state is changed to paused
Expand All @@ -104,7 +104,7 @@ func (s *testNodesSuite) TestUpdateNodeState(c *C) {

func (s *testNodesSuite) TestCreateRegistry(c *C) {
urls := "127.0.0.1:2379"
registry, err := createRegistry(urls)
registry, err := createRegistry(urls, nil)
c.Assert(err, IsNil)
c.Assert(registry, NotNil)

Expand All @@ -131,7 +131,7 @@ func (s *testNodesSuite) TestCreateRegistry(c *C) {

}

func createMockRegistry(urls string) (*node.EtcdRegistry, error) {
func createMockRegistry(urls string, _ *tls.Config) (*node.EtcdRegistry, error) {
if fakeRegistry != nil {
return fakeRegistry, nil
}
Expand Down
16 changes: 8 additions & 8 deletions cmd/binlogctl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,21 @@ func main() {
case ctl.GenerateMeta:
err = ctl.GenerateMetaInfo(cfg)
case ctl.QueryPumps:
err = ctl.QueryNodesByKind(cfg.EtcdURLs, node.PumpNode, cfg.ShowOfflineNodes)
err = ctl.QueryNodesByKind(cfg.EtcdURLs, node.PumpNode, cfg.ShowOfflineNodes, cfg.TLS)
case ctl.QueryDrainers:
err = ctl.QueryNodesByKind(cfg.EtcdURLs, node.DrainerNode, cfg.ShowOfflineNodes)
err = ctl.QueryNodesByKind(cfg.EtcdURLs, node.DrainerNode, cfg.ShowOfflineNodes, cfg.TLS)
case ctl.UpdatePump:
err = ctl.UpdateNodeState(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, cfg.State)
err = ctl.UpdateNodeState(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, cfg.State, cfg.TLS)
case ctl.UpdateDrainer:
err = ctl.UpdateNodeState(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, cfg.State)
err = ctl.UpdateNodeState(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, cfg.State, cfg.TLS)
case ctl.PausePump:
err = ctl.ApplyAction(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, pause)
err = ctl.ApplyAction(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, pause, cfg.TLS)
case ctl.PauseDrainer:
err = ctl.ApplyAction(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, pause)
err = ctl.ApplyAction(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, pause, cfg.TLS)
case ctl.OfflinePump:
err = ctl.ApplyAction(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, close)
err = ctl.ApplyAction(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, close, cfg.TLS)
case ctl.OfflineDrainer:
err = ctl.ApplyAction(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, close)
err = ctl.ApplyAction(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, close, cfg.TLS)
case ctl.Encrypt:
if len(cfg.Text) == 0 {
err = errors.New("need to specify the text to be encrypt")
Expand Down
10 changes: 10 additions & 0 deletions cmd/drainer/drainer.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,16 @@ port = 3306
# when setting SyncPartialColumn drainer will allow the downstream schema
# having more or less column numbers and relax sql mode by removing STRICT_TRANS_TABLES.
# sync-mode = 1
#
# Uncomment this part if you need TLS to connecting downstream MySQL/TiDB.
# You can only specified only `ssl-ca` if there is no client certificate and don't need server to authenticate client.
# [syncer.to.security]
# Path of file that contains list of trusted SSL CAs.
# ssl-ca = "/path/to/ca.pem"
# Path of file that contains X509 certificate in PEM format.
# ssl-cert = "/path/to/drainer.pem"
# Path of file that contains X509 key in PEM format.
# ssl-key = "/path/to/drainer-key.pem"

[syncer.to.checkpoint]
# only support mysql or tidb now, you can uncomment this to control where the checkpoint is saved.
Expand Down
1 change: 1 addition & 0 deletions cmd/pump/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func main() {
log.Fatal("Failed to initialize log", zap.Error(err))
}
version.PrintVersionInfo("Pump")
log.Info("start pump...", zap.Reflect("config", cfg))

p, err := pump.NewServer(cfg)
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion drainer/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package drainer

import (
"crypto/tls"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -49,6 +50,7 @@ type notifyResult struct {
// Collector collects binlog from all pump, and send binlog to syncer.
type Collector struct {
clusterID uint64
tls *tls.Config
interval time.Duration
reg *node.EtcdRegistry
tiStore kv.Storage
Expand Down Expand Up @@ -106,6 +108,7 @@ func NewCollector(cfg *Config, clusterID uint64, s *Syncer, cpt checkpoint.Check

c := &Collector{
clusterID: clusterID,
tls: cfg.tls,
interval: time.Duration(cfg.DetectInterval) * time.Second,
reg: node.NewEtcdRegistry(cli, cfg.EtcdTimeout),
pumps: make(map[string]*Pump),
Expand Down Expand Up @@ -308,7 +311,7 @@ func (c *Collector) handlePumpStatusUpdate(ctx context.Context, n *node.Status)
}

commitTS := c.merger.GetLatestTS()
p := NewPump(n.NodeID, n.Addr, c.clusterID, commitTS, c.errCh)
p := NewPump(n.NodeID, n.Addr, c.tls, c.clusterID, commitTS, c.errCh)
c.pumps[n.NodeID] = p
c.merger.AddSource(MergeSource{
ID: n.NodeID,
Expand Down
7 changes: 7 additions & 0 deletions drainer/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ func (cfg *Config) Parse(args []string) error {
return errors.Errorf("tls config %+v error %v", cfg.Security, err)
}

if cfg.SyncerCfg != nil && cfg.SyncerCfg.To != nil {
cfg.SyncerCfg.To.TLS, err = cfg.SyncerCfg.To.Security.ToTLSConfig()
if err != nil {
return errors.Errorf("tls config %+v error %v", cfg.SyncerCfg.To.Security, err)
}
}

if err = cfg.adjustConfig(); err != nil {
return errors.Trace(err)
}
Expand Down
Loading

0 comments on commit b9a8759

Please sign in to comment.