Skip to content

Commit

Permalink
secrets: Fix vault CSI support
Browse files Browse the repository at this point in the history
In Vault CSI the secret files are actually written under `..data`
subdirectory which is also a symlink, and fs.WalkDir will skip symlinks
unless they are the root.

So walk under `..data` instead, and also strip the prefix from the path
key. Also add a test to test the rotation case and some improvements on
existing tests.
  • Loading branch information
fishy committed Dec 19, 2022
1 parent 2dd3ab8 commit 646bca6
Show file tree
Hide file tree
Showing 10 changed files with 348 additions and 130 deletions.
8 changes: 7 additions & 1 deletion filewatcher/filewatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/fsnotify/fsnotify"

"github.com/reddit/baseplate.go/errorsbp"
"github.com/reddit/baseplate.go/internal/limitopen"
"github.com/reddit/baseplate.go/log"
)
Expand Down Expand Up @@ -363,16 +364,21 @@ func New(ctx context.Context, cfg Config) (*Result, error) {
var mtime time.Time
var files []string

var lastErr error
for {
select {
default:
case <-ctx.Done():
return nil, fmt.Errorf("filewatcher: context cancelled while waiting for file under %q to load. %w", cfg.Path, ctx.Err())
var batch errorsbp.Batch
batch.Add(ctx.Err())
batch.AddPrefix("last error", lastErr)
return nil, fmt.Errorf("filewatcher: context canceled while waiting for file(s) under %q to load: %w", cfg.Path, batch.Compile())
}

var err error
data, mtime, files, err = openAndParse(cfg.Path, cfg.Parser, limit, hardLimit)
if errors.Is(err, fs.ErrNotExist) {
lastErr = err
time.Sleep(InitialReadInterval)
continue
}
Expand Down
26 changes: 14 additions & 12 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ require (
github.com/prometheus/client_model v0.2.0
github.com/sony/gobreaker v0.4.1
go.uber.org/automaxprocs v1.5.1
go.uber.org/zap v1.15.0
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a
google.golang.org/grpc v1.41.0
go.uber.org/zap v1.21.0
golang.org/x/sys v0.3.0
google.golang.org/grpc v1.47.0
gopkg.in/yaml.v2 v2.4.0
sigs.k8s.io/secrets-store-csi-driver v1.3.0
)

require (
Expand All @@ -40,6 +41,7 @@ require (
github.com/eapache/queue v1.1.0 // indirect
github.com/garyburd/redigo v1.6.2 // indirect
github.com/go-logfmt/logfmt v0.5.1 // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/golang/snappy v0.0.3 // indirect
github.com/hashicorp/go-uuid v1.0.2 // indirect
Expand All @@ -49,7 +51,7 @@ require (
github.com/jcmturner/gokrb5/v8 v8.4.2 // indirect
github.com/jcmturner/rpc/v2 v2.0.3 // indirect
github.com/klauspost/compress v1.12.2 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
github.com/mediocregopher/radix.v2 v0.0.0-20181115013041-b67df6e626f9 // indirect
github.com/pierrec/lz4 v2.6.0+incompatible // indirect
github.com/prometheus/common v0.37.0 // indirect
Expand All @@ -59,15 +61,15 @@ require (
go.opentelemetry.io/otel v0.20.0 // indirect
go.opentelemetry.io/otel/metric v0.20.0 // indirect
go.opentelemetry.io/otel/trace v0.20.0 // indirect
go.uber.org/atomic v1.6.0 // indirect
go.uber.org/multierr v1.5.0 // indirect
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e // indirect
golang.org/x/net v0.0.0-20220225172249-27dd8689420f // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/tools v0.1.3-0.20210608163600-9ed039809d4c // indirect
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987 // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.6.0 // indirect
golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect
golang.org/x/net v0.4.0 // indirect
golang.org/x/text v0.5.0 // indirect
google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect
google.golang.org/protobuf v1.28.1 // indirect
honnef.co/go/tools v0.2.0 // indirect
k8s.io/apimachinery v0.25.0 // indirect
k8s.io/klog/v2 v2.80.0 // indirect
)

// Please use v0.9.3 or later versions instead.
Expand Down
104 changes: 70 additions & 34 deletions go.sum

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion grpcbp/client_middlewares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
pb "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
"github.com/opentracing/opentracing-go"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"

"github.com/reddit/baseplate.go/mqsend"
Expand Down Expand Up @@ -116,7 +117,7 @@ func setupClient(t *testing.T, l *bufconn.Listener, opts ...grpc.DialOption) *gr
}
opts = append([]grpc.DialOption{
grpc.WithContextDialer(bufDialer),
grpc.WithInsecure(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
}, opts...)

// create connection to be used by gRPC client
Expand Down
2 changes: 1 addition & 1 deletion kafkabp/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func setupPartitionConsumers(t *testing.T, kc *consumer) (*mocks.PartitionConsum
c := *kc.consumer.Load()
mc, ok := c.(*mocks.Consumer)
if !ok {
t.Fatalf("kc.consumer is not *mocks.Consumer. %#v", kc.consumer)
t.Fatalf("kc.consumer is not *mocks.Consumer. %#v", kc.consumer.Load())
}
pc := mc.ExpectConsumePartition(kc.cfg.Topic, (*kc.partitions.Load())[0], kc.offset)
pc1 := mc.ExpectConsumePartition(kc.cfg.Topic, (*kc.partitions.Load())[1], kc.offset)
Expand Down
8 changes: 6 additions & 2 deletions secrets/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ var (
//
// Can be deserialized from YAML.
type Config struct {
// Path is the path to the secrets.json file file to load your service's
// secrets from.
// Path is the path to the secrets.json file or Vault CSI directory to load
// your service's secrets from.
//
// Examples:
// - /var/local/secrets/secrets.json
// - /mnt/secrets
Path string `yaml:"path"`
}

Expand Down
34 changes: 31 additions & 3 deletions secrets/secrets.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package secrets

import (
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"path/filepath"

"github.com/reddit/baseplate.go/errorsbp"
)
Expand Down Expand Up @@ -321,21 +323,47 @@ func secretsValidate(secretsDocument Document) (*Secrets, error) {
return secrets, nil
}

type notCSIError struct {
err error
}

func (e notCSIError) Error() string {
return fmt.Sprintf("configured directory does not appear to be the root of a Vault CSI mount point: %v", e.err)
}

func (e notCSIError) Unwrap() error {
return e.err
}

// walkCSIDirectory parses a directory for vault secrets and merges them into one object
func walkCSIDirectory(dir fs.FS) (Document, error) {
const (
// This is where k8s actually writes the content,
// ref: https://pkg.go.dev/sigs.k8s.io/secrets-store-csi-driver/pkg/util/fileutil#AtomicWriter.Write
k8sSubdirectory = "..data"
)
secretsDocument := Document{
Secrets: make(map[string]GenericSecret),
}
err := fs.WalkDir(
dir,
".",
k8sSubdirectory,
func(path string, d fs.DirEntry, err error) error {
if err != nil {
if path == k8sSubdirectory && errors.Is(err, fs.ErrNotExist) {
return notCSIError{err: err}
}
return err
}
if d.IsDir() {
return nil
}
relPath, err := filepath.Rel(k8sSubdirectory, path)
if err != nil {
// Should not happen as this means path is not under k8sSubdirectory,
// but just in case
return nil
}
file, err := dir.Open(path)
if err != nil {
return err
Expand All @@ -347,12 +375,12 @@ func walkCSIDirectory(dir fs.FS) (Document, error) {
if err != nil {
return fmt.Errorf("decoding %q: %w", path, err)
}
secretsDocument.Secrets[path] = secretFile.Secret
secretsDocument.Secrets[relPath] = secretFile.Secret
return nil
},
)
if err != nil {
return Document{}, fmt.Errorf("Error during walkCSIDirectory: %w", err)
return Document{}, fmt.Errorf("secrets.walkCSIDirectory: %w", err)
}
return secretsDocument, nil
}
8 changes: 8 additions & 0 deletions secrets/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"io/fs"
"os"
"time"

"github.com/reddit/baseplate.go/filewatcher"
"github.com/reddit/baseplate.go/log"
Expand Down Expand Up @@ -41,6 +42,11 @@ type Store struct {
// Context should come with a timeout otherwise this might block forever, i.e.
// if the path never becomes available.
func NewStore(ctx context.Context, path string, logger log.Wrapper, middlewares ...SecretMiddleware) (*Store, error) {
return newStore(ctx, 0 /* use default fsEventsDelay */, path, logger, middlewares...)
}

// Used in tests to override FSEventsDelay
func newStore(ctx context.Context, fsEventsDelay time.Duration, path string, logger log.Wrapper, middlewares ...SecretMiddleware) (*Store, error) {
store := &Store{
secretHandlerFunc: nopSecretHandlerFunc,
}
Expand All @@ -61,6 +67,8 @@ func NewStore(ctx context.Context, path string, logger log.Wrapper, middlewares
Path: path,
Parser: parser,
Logger: logger,

FSEventsDelay: fsEventsDelay,
},
)
if err != nil {
Expand Down
126 changes: 121 additions & 5 deletions secrets/store_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@ package secrets

import (
"context"
"errors"
"io/fs"
"os"
"reflect"
"testing"
"time"

"sigs.k8s.io/secrets-store-csi-driver/pkg/util/fileutil"

"github.com/reddit/baseplate.go/log"
)
Expand Down Expand Up @@ -44,11 +50,7 @@ func TestNewStore(t *testing.T) {
},
}

dir, err := os.MkdirTemp("", "secret_test_")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
dir := t.TempDir()

for _, tt := range tests {
t.Run(
Expand Down Expand Up @@ -77,3 +79,117 @@ func TestNewStore(t *testing.T) {
)
}
}

func TestDirRotation(t *testing.T) {
const (
delay = 50 * time.Millisecond
sleep = delay * 5
)
const (
key1 = "secrets/foo"
key2 = "secrets/bar"

apiKey = `
{
"request_id": "1afc3036-2282-d483-c2d4-6d483efdf16c",
"lease_id": "",
"lease_duration": 2764800,
"renewable": false,
"data": {
"type": "simple",
"value": "Y2RvVXhNMVdsTXJma3BDaHRGZ0dPYkVGSg==",
"encoding": "base64"
},
"warnings": null
}
`
)
var wantSecret = SimpleSecret{Value: Secret("cdoUxM1WlMrfkpChtFgGObEFJ")}

dir := t.TempDir()
writer, err := fileutil.NewAtomicWriter(dir, "")
if err != nil {
t.Fatalf("Failed to create k8s atomic writer: %v", err)
}
content := fileutil.FileProjection{
Data: []byte(apiKey),
Mode: 0777,
}
if err := writer.Write(map[string]fileutil.FileProjection{
key1: content,
}); err != nil {
t.Fatalf("Failed to write initial payload: %v", err)
}

store, err := newStore(context.Background(), delay, dir, log.TestWrapper(t))
if err != nil {
t.Fatalf("Failed to create secrets store: %v", err)
}
t.Cleanup(func() { store.Close() })

t.Run("initial-payload", func(t *testing.T) {
const (
correctKey = key1
wrongKey = key2
)
_, err := store.GetSimpleSecret(wrongKey)
if err == nil {
t.Errorf("Expected error when getting %q, got nil", wrongKey)
}

secret, err := store.GetSimpleSecret(correctKey)
if err != nil {
t.Fatalf("Expected no error when getting %q, got %v", correctKey, err)
}
if !reflect.DeepEqual(secret, wantSecret) {
t.Errorf("Got secret %+v, want %+v", secret, wantSecret)
}
})

t.Run("rotated-payload", func(t *testing.T) {
const (
correctKey = key2
wrongKey = key1
)
if err := writer.Write(map[string]fileutil.FileProjection{
correctKey: content,
}); err != nil {
t.Fatalf("Failed to write rotated payload: %v", err)
}
time.Sleep(sleep)

_, err := store.GetSimpleSecret(wrongKey)
if err == nil {
t.Errorf("Expected error when getting %q, got nil", wrongKey)
}

secret, err := store.GetSimpleSecret(correctKey)
if err != nil {
t.Fatalf("Expected no error when getting %q, got %v", correctKey, err)
}
if !reflect.DeepEqual(secret, wantSecret) {
t.Errorf("Got secret %+v, want %+v", secret, wantSecret)
}
})
}

func TestDirectoryError(t *testing.T) {
dir := t.TempDir()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
t.Cleanup(cancel)
store, err := NewStore(ctx, dir, log.NopWrapper)
if err == nil {
store.Close()
t.Fatal("Expected NewStore to return an error on an empty directory, got nil")
}
t.Logf("NewStore returned error: %v", err)
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("Error is not %v", fs.ErrNotExist)
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("Error is not %v", context.Canceled)
}
if !errors.As(err, new(notCSIError)) {
t.Error("Error is not of type notCSIError")
}
}
Loading

0 comments on commit 646bca6

Please sign in to comment.