Skip to content

Commit

Permalink
Removed deciders; Cleaned up validators. (#554)
Browse files Browse the repository at this point in the history
Signed-off-by: bwplotka <bwplotka@gmail.com>
  • Loading branch information
bwplotka committed Apr 4, 2023
1 parent 8c53766 commit 0e1142d
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 110 deletions.
21 changes: 6 additions & 15 deletions interceptors/validator/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,12 @@ import (
// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
//
// Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers.
// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor
// returns ALL validation error as a wrapped multi-error.
// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
o := evaluateServerOpt(opts)
o := evaluateOpts(opts)
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if err := validate(req, o.shouldFailFast, o.level, o.logger); err != nil {
if err := validate(ctx, req, o.shouldFailFast, o.onValidationErrFunc); err != nil {
return nil, err
}
return handler(ctx, req)
Expand All @@ -30,15 +27,12 @@ func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
// UnaryClientInterceptor returns a new unary client interceptor that validates outgoing messages.
//
// Invalid messages will be rejected with `InvalidArgument` before sending the request to server.
// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor
// returns ALL validation error as a wrapped multi-error.
// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {
o := evaluateClientOpt(opts)
o := evaluateOpts(opts)
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if err := validate(req, o.shouldFailFast, o.level, o.logger); err != nil {
if err := validate(ctx, req, o.shouldFailFast, o.onValidationErrFunc); err != nil {
return err
}
return invoker(ctx, method, req, reply, cc, opts...)
Expand All @@ -47,17 +41,14 @@ func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {

// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages.
//
// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor
// returns ALL validation error as a wrapped multi-error.
// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
// The stage at which invalid messages will be rejected with `InvalidArgument` varies based on the
// type of the RPC. For `ServerStream` (1:m) requests, it will happen before reaching any userspace
// handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on
// calls to `stream.Recv()`.
func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
o := evaluateServerOpt(opts)
o := evaluateOpts(opts)
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
wrapper := &recvWrapper{
options: o,
Expand All @@ -77,7 +68,7 @@ func (s *recvWrapper) RecvMsg(m any) error {
if err := s.ServerStream.RecvMsg(m); err != nil {
return err
}
if err := validate(m, s.shouldFailFast, s.level, s.logger); err != nil {
if err := validate(s.Context(), m, s.shouldFailFast, s.onValidationErrFunc); err != nil {
return err
}
return nil
Expand Down
44 changes: 25 additions & 19 deletions interceptors/validator/interceptors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io"
"testing"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
"github.com/stretchr/testify/assert"
Expand All @@ -19,10 +18,6 @@ import (
"google.golang.org/grpc/status"
)

type TestLogger struct{}

func (l *TestLogger) Log(ctx context.Context, level logging.Level, msg string, fields ...any) {}

type ValidatorTestSuite struct {
*testpb.InterceptorTestSuite
}
Expand Down Expand Up @@ -104,35 +99,42 @@ func TestValidatorTestSuite(t *testing.T) {
}
suite.Run(t, sWithNoArgs)

sWithWithFailFastArgs := &ValidatorTestSuite{
sWithFailFastArgs := &ValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast())),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast())),
},
},
}
suite.Run(t, sWithWithFailFastArgs)
suite.Run(t, sWithFailFastArgs)

sWithWithLoggerArgs := &ValidatorTestSuite{
var gotErrMsgs []string
onErr := func(ctx context.Context, err error) {
gotErrMsgs = append(gotErrMsgs, err.Error())
}
sWithOnErrFuncArgs := &ValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithOnValidationErrFunc(onErr))),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithOnValidationErrFunc(onErr))),
},
},
}
suite.Run(t, sWithWithLoggerArgs)
suite.Run(t, sWithOnErrFuncArgs)
require.Equal(t, []string{"cannot sleep for more than 10s", "cannot sleep for more than 10s", "cannot sleep for more than 10s"}, gotErrMsgs)

gotErrMsgs = gotErrMsgs[:0]
sAll := &ValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast(), validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast(), validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))),
},
},
}
suite.Run(t, sAll)
require.Equal(t, []string{"cannot sleep for more than 10s", "cannot sleep for more than 10s", "cannot sleep for more than 10s"}, gotErrMsgs)

csWithNoArgs := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
Expand All @@ -143,30 +145,34 @@ func TestValidatorTestSuite(t *testing.T) {
}
suite.Run(t, csWithNoArgs)

csWithWithFailFastArgs := &ClientValidatorTestSuite{
csWithFailFastArgs := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast())),
},
},
}
suite.Run(t, csWithWithFailFastArgs)
suite.Run(t, csWithFailFastArgs)

csWithWithLoggerArgs := &ClientValidatorTestSuite{
gotErrMsgs = gotErrMsgs[:0]
csWithOnErrFuncArgs := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithOnValidationErrFunc(onErr))),
},
},
}
suite.Run(t, csWithWithLoggerArgs)
suite.Run(t, csWithOnErrFuncArgs)
require.Equal(t, []string{"cannot sleep for more than 10s"}, gotErrMsgs)

gotErrMsgs = gotErrMsgs[:0]
csAll := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ClientOpts: []grpc.DialOption{
grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor(validator.WithFailFast())),
grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))),
},
},
}
suite.Run(t, csAll)
require.Equal(t, []string{"cannot sleep for more than 10s"}, gotErrMsgs)
}
27 changes: 11 additions & 16 deletions interceptors/validator/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,35 @@

package validator

import "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
import (
"context"
)

type options struct {
level logging.Level
logger logging.Logger
shouldFailFast bool
shouldFailFast bool
onValidationErrFunc OnValidationErr
}
type Option func(*options)

func evaluateServerOpt(opts []Option) *options {
func evaluateOpts(opts []Option) *options {
optCopy := &options{}
for _, o := range opts {
o(optCopy)
}
return optCopy
}

func evaluateClientOpt(opts []Option) *options {
optCopy := &options{}
for _, o := range opts {
o(optCopy)
}
return optCopy
}
type OnValidationErr func(ctx context.Context, err error)

// WithLogger tells validator to log all the validation errors with the given log level.
func WithLogger(level logging.Level, logger logging.Logger) Option {
// WithOnValidationErrFunc registers function that will be invoked on validation error(s).
func WithOnValidationErrFunc(onValidationErrFunc OnValidationErr) Option {
return func(o *options) {
o.level = level
o.logger = logger
o.onValidationErrFunc = onValidationErrFunc
}
}

// WithFailFast tells validator to immediately stop doing further validation after first validation error.
// This option is ignored if message is only supporting validator.validatorLegacy interface.
func WithFailFast() Option {
return func(o *options) {
o.shouldFailFast = true
Expand Down
57 changes: 18 additions & 39 deletions interceptors/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package validator
import (
"context"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
Expand All @@ -28,51 +27,31 @@ type validatorLegacy interface {
Validate() error
}

func log(level logging.Level, logger logging.Logger, msg string) {
if logger != nil {
// TODO(bwplotka): Fix in separate PR.
logger.Log(context.TODO(), level, msg)
}
}

func validate(req interface{}, shouldFailFast bool, level logging.Level, logger logging.Logger) error {
// shouldFailFast tells validator to immediately stop doing further validation after first validation error.
func validate(ctx context.Context, reqOrRes interface{}, shouldFailFast bool, onValidationErrFunc OnValidationErr) (err error) {
if shouldFailFast {
switch v := req.(type) {
switch v := reqOrRes.(type) {
case validatorLegacy:
if err := v.Validate(); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
err = v.Validate()
case validator:
if err := v.Validate(false); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
err = v.Validate(false)
}
} else {
switch v := reqOrRes.(type) {
case validateAller:
err = v.ValidateAll()
case validator:
err = v.Validate(true)
case validatorLegacy:
err = v.Validate()
}
}

if err == nil {
return nil
}

// shouldNotFailFast tells validator to continue doing further validation even if after a validation error.
switch v := req.(type) {
case validateAller:
if err := v.ValidateAll(); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
case validator:
if err := v.Validate(true); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
case validatorLegacy:
// Fallback to legacy validator
if err := v.Validate(); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
if onValidationErrFunc != nil {
onValidationErrFunc(ctx, err)
}

return nil
return status.Error(codes.InvalidArgument, err.Error())
}
35 changes: 16 additions & 19 deletions interceptors/validator/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,25 @@ import (
"context"
"testing"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
"github.com/stretchr/testify/assert"
)

type TestLogger struct{}

func (l *TestLogger) Log(ctx context.Context, level logging.Level, msg string, fields ...any) {}

func TestValidateWrapper(t *testing.T) {
assert.NoError(t, validate(testpb.GoodPing, false, logging.LevelError, &TestLogger{}))
assert.Error(t, validate(testpb.BadPing, false, logging.LevelError, &TestLogger{}))
assert.NoError(t, validate(testpb.GoodPing, true, logging.LevelError, &TestLogger{}))
assert.Error(t, validate(testpb.BadPing, true, logging.LevelError, &TestLogger{}))

assert.NoError(t, validate(testpb.GoodPingError, false, logging.LevelError, &TestLogger{}))
assert.Error(t, validate(testpb.BadPingError, false, logging.LevelError, &TestLogger{}))
assert.NoError(t, validate(testpb.GoodPingError, true, logging.LevelError, &TestLogger{}))
assert.Error(t, validate(testpb.BadPingError, true, logging.LevelError, &TestLogger{}))

assert.NoError(t, validate(testpb.GoodPingResponse, false, logging.LevelError, &TestLogger{}))
assert.NoError(t, validate(testpb.GoodPingResponse, true, logging.LevelError, &TestLogger{}))
assert.Error(t, validate(testpb.BadPingResponse, false, logging.LevelError, &TestLogger{}))
assert.Error(t, validate(testpb.BadPingResponse, true, logging.LevelError, &TestLogger{}))
ctx := context.Background()

assert.NoError(t, validate(ctx, testpb.GoodPing, false, nil))
assert.Error(t, validate(ctx, testpb.BadPing, false, nil))
assert.NoError(t, validate(ctx, testpb.GoodPing, true, nil))
assert.Error(t, validate(ctx, testpb.BadPing, true, nil))

assert.NoError(t, validate(ctx, testpb.GoodPingError, false, nil))
assert.Error(t, validate(ctx, testpb.BadPingError, false, nil))
assert.NoError(t, validate(ctx, testpb.GoodPingError, true, nil))
assert.Error(t, validate(ctx, testpb.BadPingError, true, nil))

assert.NoError(t, validate(ctx, testpb.GoodPingResponse, false, nil))
assert.NoError(t, validate(ctx, testpb.GoodPingResponse, true, nil))
assert.Error(t, validate(ctx, testpb.BadPingResponse, false, nil))
assert.Error(t, validate(ctx, testpb.BadPingResponse, true, nil))
}
4 changes: 2 additions & 2 deletions testing/testpb/test.manual_validator.pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ func (x *PingStreamRequest) Validate(bool) error {
return nil
}

// Implements the legacy validation interface from protoc-gen-validate.
// Validate implements the legacy validation interface from protoc-gen-validate.
func (x *PingResponse) Validate() error {
if x.Counter > math.MaxInt16 {
return errors.New("ping allocation exceeded")
}
return nil
}

// Implements the new ValidateAll interface from protoc-gen-validate.
// ValidateAll implements the new ValidateAll interface from protoc-gen-validate.
func (x *PingResponse) ValidateAll() error {
if x.Counter > math.MaxInt16 {
return errors.New("ping allocation exceeded")
Expand Down

0 comments on commit 0e1142d

Please sign in to comment.