Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed deciders; Cleaned up validators. #554

Merged
merged 1 commit into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions interceptors/validator/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,12 @@ import (
// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
//
// Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers.
// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor
// returns ALL validation error as a wrapped multi-error.
// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
o := evaluateServerOpt(opts)
o := evaluateOpts(opts)
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if err := validate(req, o.shouldFailFast, o.level, o.logger); err != nil {
if err := validate(ctx, req, o.shouldFailFast, o.onValidationErrFunc); err != nil {
return nil, err
}
return handler(ctx, req)
Expand All @@ -30,15 +27,12 @@ func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
// UnaryClientInterceptor returns a new unary client interceptor that validates outgoing messages.
//
// Invalid messages will be rejected with `InvalidArgument` before sending the request to server.
// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor
// returns ALL validation error as a wrapped multi-error.
// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {
o := evaluateClientOpt(opts)
o := evaluateOpts(opts)
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if err := validate(req, o.shouldFailFast, o.level, o.logger); err != nil {
if err := validate(ctx, req, o.shouldFailFast, o.onValidationErrFunc); err != nil {
return err
}
return invoker(ctx, method, req, reply, cc, opts...)
Expand All @@ -47,17 +41,14 @@ func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {

// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages.
//
// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor
// returns ALL validation error as a wrapped multi-error.
// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
// The stage at which invalid messages will be rejected with `InvalidArgument` varies based on the
// type of the RPC. For `ServerStream` (1:m) requests, it will happen before reaching any userspace
// handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on
// calls to `stream.Recv()`.
func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
o := evaluateServerOpt(opts)
o := evaluateOpts(opts)
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
wrapper := &recvWrapper{
options: o,
Expand All @@ -77,7 +68,7 @@ func (s *recvWrapper) RecvMsg(m any) error {
if err := s.ServerStream.RecvMsg(m); err != nil {
return err
}
if err := validate(m, s.shouldFailFast, s.level, s.logger); err != nil {
if err := validate(s.Context(), m, s.shouldFailFast, s.onValidationErrFunc); err != nil {
return err
}
return nil
Expand Down
44 changes: 25 additions & 19 deletions interceptors/validator/interceptors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io"
"testing"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
"github.com/stretchr/testify/assert"
Expand All @@ -19,10 +18,6 @@ import (
"google.golang.org/grpc/status"
)

type TestLogger struct{}

func (l *TestLogger) Log(ctx context.Context, level logging.Level, msg string, fields ...any) {}

type ValidatorTestSuite struct {
*testpb.InterceptorTestSuite
}
Expand Down Expand Up @@ -104,35 +99,42 @@ func TestValidatorTestSuite(t *testing.T) {
}
suite.Run(t, sWithNoArgs)

sWithWithFailFastArgs := &ValidatorTestSuite{
sWithFailFastArgs := &ValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast())),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast())),
},
},
}
suite.Run(t, sWithWithFailFastArgs)
suite.Run(t, sWithFailFastArgs)

sWithWithLoggerArgs := &ValidatorTestSuite{
var gotErrMsgs []string
onErr := func(ctx context.Context, err error) {
gotErrMsgs = append(gotErrMsgs, err.Error())
}
sWithOnErrFuncArgs := &ValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithOnValidationErrFunc(onErr))),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithOnValidationErrFunc(onErr))),
},
},
}
suite.Run(t, sWithWithLoggerArgs)
suite.Run(t, sWithOnErrFuncArgs)
require.Equal(t, []string{"cannot sleep for more than 10s", "cannot sleep for more than 10s", "cannot sleep for more than 10s"}, gotErrMsgs)

gotErrMsgs = gotErrMsgs[:0]
sAll := &ValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast(), validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast(), validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))),
},
},
}
suite.Run(t, sAll)
require.Equal(t, []string{"cannot sleep for more than 10s", "cannot sleep for more than 10s", "cannot sleep for more than 10s"}, gotErrMsgs)

csWithNoArgs := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
Expand All @@ -143,30 +145,34 @@ func TestValidatorTestSuite(t *testing.T) {
}
suite.Run(t, csWithNoArgs)

csWithWithFailFastArgs := &ClientValidatorTestSuite{
csWithFailFastArgs := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast())),
},
},
}
suite.Run(t, csWithWithFailFastArgs)
suite.Run(t, csWithFailFastArgs)

csWithWithLoggerArgs := &ClientValidatorTestSuite{
gotErrMsgs = gotErrMsgs[:0]
csWithOnErrFuncArgs := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithOnValidationErrFunc(onErr))),
},
},
}
suite.Run(t, csWithWithLoggerArgs)
suite.Run(t, csWithOnErrFuncArgs)
require.Equal(t, []string{"cannot sleep for more than 10s"}, gotErrMsgs)

gotErrMsgs = gotErrMsgs[:0]
csAll := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ClientOpts: []grpc.DialOption{
grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor(validator.WithFailFast())),
grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))),
},
},
}
suite.Run(t, csAll)
require.Equal(t, []string{"cannot sleep for more than 10s"}, gotErrMsgs)
}
27 changes: 11 additions & 16 deletions interceptors/validator/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,35 @@

package validator

import "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
import (
"context"
)

type options struct {
level logging.Level
logger logging.Logger
shouldFailFast bool
shouldFailFast bool
onValidationErrFunc OnValidationErr
}
type Option func(*options)

func evaluateServerOpt(opts []Option) *options {
func evaluateOpts(opts []Option) *options {
optCopy := &options{}
for _, o := range opts {
o(optCopy)
}
return optCopy
}

func evaluateClientOpt(opts []Option) *options {
optCopy := &options{}
for _, o := range opts {
o(optCopy)
}
return optCopy
}
type OnValidationErr func(ctx context.Context, err error)

// WithLogger tells validator to log all the validation errors with the given log level.
func WithLogger(level logging.Level, logger logging.Logger) Option {
// WithOnValidationErrFunc registers function that will be invoked on validation error(s).
func WithOnValidationErrFunc(onValidationErrFunc OnValidationErr) Option {
return func(o *options) {
o.level = level
o.logger = logger
o.onValidationErrFunc = onValidationErrFunc
}
}

// WithFailFast tells validator to immediately stop doing further validation after first validation error.
// This option is ignored if message is only supporting validator.validatorLegacy interface.
func WithFailFast() Option {
return func(o *options) {
o.shouldFailFast = true
Expand Down
57 changes: 18 additions & 39 deletions interceptors/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package validator
import (
"context"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
Expand All @@ -28,51 +27,31 @@ type validatorLegacy interface {
Validate() error
}

func log(level logging.Level, logger logging.Logger, msg string) {
if logger != nil {
// TODO(bwplotka): Fix in separate PR.
logger.Log(context.TODO(), level, msg)
}
}

func validate(req interface{}, shouldFailFast bool, level logging.Level, logger logging.Logger) error {
// shouldFailFast tells validator to immediately stop doing further validation after first validation error.
func validate(ctx context.Context, reqOrRes interface{}, shouldFailFast bool, onValidationErrFunc OnValidationErr) (err error) {
if shouldFailFast {
switch v := req.(type) {
switch v := reqOrRes.(type) {
case validatorLegacy:
if err := v.Validate(); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
err = v.Validate()
case validator:
if err := v.Validate(false); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
err = v.Validate(false)
}
} else {
switch v := reqOrRes.(type) {
case validateAller:
err = v.ValidateAll()
case validator:
err = v.Validate(true)
case validatorLegacy:
err = v.Validate()
}
}

if err == nil {
return nil
}

// shouldNotFailFast tells validator to continue doing further validation even if after a validation error.
switch v := req.(type) {
case validateAller:
if err := v.ValidateAll(); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
case validator:
if err := v.Validate(true); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
case validatorLegacy:
// Fallback to legacy validator
if err := v.Validate(); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
if onValidationErrFunc != nil {
onValidationErrFunc(ctx, err)
}

return nil
return status.Error(codes.InvalidArgument, err.Error())
}
35 changes: 16 additions & 19 deletions interceptors/validator/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,25 @@ import (
"context"
"testing"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
"github.com/stretchr/testify/assert"
)

type TestLogger struct{}

func (l *TestLogger) Log(ctx context.Context, level logging.Level, msg string, fields ...any) {}

func TestValidateWrapper(t *testing.T) {
assert.NoError(t, validate(testpb.GoodPing, false, logging.LevelError, &TestLogger{}))
assert.Error(t, validate(testpb.BadPing, false, logging.LevelError, &TestLogger{}))
assert.NoError(t, validate(testpb.GoodPing, true, logging.LevelError, &TestLogger{}))
assert.Error(t, validate(testpb.BadPing, true, logging.LevelError, &TestLogger{}))

assert.NoError(t, validate(testpb.GoodPingError, false, logging.LevelError, &TestLogger{}))
assert.Error(t, validate(testpb.BadPingError, false, logging.LevelError, &TestLogger{}))
assert.NoError(t, validate(testpb.GoodPingError, true, logging.LevelError, &TestLogger{}))
assert.Error(t, validate(testpb.BadPingError, true, logging.LevelError, &TestLogger{}))

assert.NoError(t, validate(testpb.GoodPingResponse, false, logging.LevelError, &TestLogger{}))
assert.NoError(t, validate(testpb.GoodPingResponse, true, logging.LevelError, &TestLogger{}))
assert.Error(t, validate(testpb.BadPingResponse, false, logging.LevelError, &TestLogger{}))
assert.Error(t, validate(testpb.BadPingResponse, true, logging.LevelError, &TestLogger{}))
ctx := context.Background()

assert.NoError(t, validate(ctx, testpb.GoodPing, false, nil))
assert.Error(t, validate(ctx, testpb.BadPing, false, nil))
assert.NoError(t, validate(ctx, testpb.GoodPing, true, nil))
assert.Error(t, validate(ctx, testpb.BadPing, true, nil))

assert.NoError(t, validate(ctx, testpb.GoodPingError, false, nil))
assert.Error(t, validate(ctx, testpb.BadPingError, false, nil))
assert.NoError(t, validate(ctx, testpb.GoodPingError, true, nil))
assert.Error(t, validate(ctx, testpb.BadPingError, true, nil))

assert.NoError(t, validate(ctx, testpb.GoodPingResponse, false, nil))
assert.NoError(t, validate(ctx, testpb.GoodPingResponse, true, nil))
assert.Error(t, validate(ctx, testpb.BadPingResponse, false, nil))
assert.Error(t, validate(ctx, testpb.BadPingResponse, true, nil))
}
4 changes: 2 additions & 2 deletions testing/testpb/test.manual_validator.pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ func (x *PingStreamRequest) Validate(bool) error {
return nil
}

// Implements the legacy validation interface from protoc-gen-validate.
// Validate implements the legacy validation interface from protoc-gen-validate.
func (x *PingResponse) Validate() error {
if x.Counter > math.MaxInt16 {
return errors.New("ping allocation exceeded")
}
return nil
}

// Implements the new ValidateAll interface from protoc-gen-validate.
// ValidateAll implements the new ValidateAll interface from protoc-gen-validate.
func (x *PingResponse) ValidateAll() error {
if x.Counter > math.MaxInt16 {
return errors.New("ping allocation exceeded")
Expand Down