diff --git a/interceptors/validator/interceptors.go b/interceptors/validator/interceptors.go index 69c8a054a..46c5ed0bd 100644 --- a/interceptors/validator/interceptors.go +++ b/interceptors/validator/interceptors.go @@ -12,15 +12,12 @@ import ( // UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages. // // Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers. -// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor -// returns ALL validation error as a wrapped multi-error. -// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging. // Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation // interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { - o := evaluateServerOpt(opts) + o := evaluateOpts(opts) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - if err := validate(req, o.shouldFailFast, o.level, o.logger); err != nil { + if err := validate(ctx, req, o.shouldFailFast, o.onValidationErrFunc); err != nil { return nil, err } return handler(ctx, req) @@ -30,15 +27,12 @@ func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { // UnaryClientInterceptor returns a new unary client interceptor that validates outgoing messages. // // Invalid messages will be rejected with `InvalidArgument` before sending the request to server. -// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor -// returns ALL validation error as a wrapped multi-error. -// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging. // Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation // interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { - o := evaluateClientOpt(opts) + o := evaluateOpts(opts) return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - if err := validate(req, o.shouldFailFast, o.level, o.logger); err != nil { + if err := validate(ctx, req, o.shouldFailFast, o.onValidationErrFunc); err != nil { return err } return invoker(ctx, method, req, reply, cc, opts...) @@ -47,9 +41,6 @@ func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { // StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages. // -// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor -// returns ALL validation error as a wrapped multi-error. -// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging. // Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation // interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. // The stage at which invalid messages will be rejected with `InvalidArgument` varies based on the @@ -57,7 +48,7 @@ func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { // handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on // calls to `stream.Recv()`. func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { - o := evaluateServerOpt(opts) + o := evaluateOpts(opts) return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { wrapper := &recvWrapper{ options: o, @@ -77,7 +68,7 @@ func (s *recvWrapper) RecvMsg(m any) error { if err := s.ServerStream.RecvMsg(m); err != nil { return err } - if err := validate(m, s.shouldFailFast, s.level, s.logger); err != nil { + if err := validate(s.Context(), m, s.shouldFailFast, s.onValidationErrFunc); err != nil { return err } return nil diff --git a/interceptors/validator/interceptors_test.go b/interceptors/validator/interceptors_test.go index b29d3fd4b..43f8462ee 100644 --- a/interceptors/validator/interceptors_test.go +++ b/interceptors/validator/interceptors_test.go @@ -8,7 +8,6 @@ import ( "io" "testing" - "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb" "github.com/stretchr/testify/assert" @@ -19,10 +18,6 @@ import ( "google.golang.org/grpc/status" ) -type TestLogger struct{} - -func (l *TestLogger) Log(ctx context.Context, level logging.Level, msg string, fields ...any) {} - type ValidatorTestSuite struct { *testpb.InterceptorTestSuite } @@ -104,7 +99,7 @@ func TestValidatorTestSuite(t *testing.T) { } suite.Run(t, sWithNoArgs) - sWithWithFailFastArgs := &ValidatorTestSuite{ + sWithFailFastArgs := &ValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ServerOpts: []grpc.ServerOption{ grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast())), @@ -112,27 +107,34 @@ func TestValidatorTestSuite(t *testing.T) { }, }, } - suite.Run(t, sWithWithFailFastArgs) + suite.Run(t, sWithFailFastArgs) - sWithWithLoggerArgs := &ValidatorTestSuite{ + var gotErrMsgs []string + onErr := func(ctx context.Context, err error) { + gotErrMsgs = append(gotErrMsgs, err.Error()) + } + sWithOnErrFuncArgs := &ValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ServerOpts: []grpc.ServerOption{ - grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithLogger(logging.LevelDebug, &TestLogger{}))), - grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithLogger(logging.LevelDebug, &TestLogger{}))), + grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithOnValidationErrFunc(onErr))), + grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithOnValidationErrFunc(onErr))), }, }, } - suite.Run(t, sWithWithLoggerArgs) + suite.Run(t, sWithOnErrFuncArgs) + require.Equal(t, []string{"cannot sleep for more than 10s", "cannot sleep for more than 10s", "cannot sleep for more than 10s"}, gotErrMsgs) + gotErrMsgs = gotErrMsgs[:0] sAll := &ValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ServerOpts: []grpc.ServerOption{ - grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast(), validator.WithLogger(logging.LevelDebug, &TestLogger{}))), - grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast(), validator.WithLogger(logging.LevelDebug, &TestLogger{}))), + grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))), + grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))), }, }, } suite.Run(t, sAll) + require.Equal(t, []string{"cannot sleep for more than 10s", "cannot sleep for more than 10s", "cannot sleep for more than 10s"}, gotErrMsgs) csWithNoArgs := &ClientValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ @@ -143,30 +145,34 @@ func TestValidatorTestSuite(t *testing.T) { } suite.Run(t, csWithNoArgs) - csWithWithFailFastArgs := &ClientValidatorTestSuite{ + csWithFailFastArgs := &ClientValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ServerOpts: []grpc.ServerOption{ grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast())), }, }, } - suite.Run(t, csWithWithFailFastArgs) + suite.Run(t, csWithFailFastArgs) - csWithWithLoggerArgs := &ClientValidatorTestSuite{ + gotErrMsgs = gotErrMsgs[:0] + csWithOnErrFuncArgs := &ClientValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ServerOpts: []grpc.ServerOption{ - grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithLogger(logging.LevelDebug, &TestLogger{}))), + grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithOnValidationErrFunc(onErr))), }, }, } - suite.Run(t, csWithWithLoggerArgs) + suite.Run(t, csWithOnErrFuncArgs) + require.Equal(t, []string{"cannot sleep for more than 10s"}, gotErrMsgs) + gotErrMsgs = gotErrMsgs[:0] csAll := &ClientValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ClientOpts: []grpc.DialOption{ - grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor(validator.WithFailFast())), + grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))), }, }, } suite.Run(t, csAll) + require.Equal(t, []string{"cannot sleep for more than 10s"}, gotErrMsgs) } diff --git a/interceptors/validator/options.go b/interceptors/validator/options.go index 357d0d62a..3f4cc946b 100644 --- a/interceptors/validator/options.go +++ b/interceptors/validator/options.go @@ -3,16 +3,17 @@ package validator -import "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" +import ( + "context" +) type options struct { - level logging.Level - logger logging.Logger - shouldFailFast bool + shouldFailFast bool + onValidationErrFunc OnValidationErr } type Option func(*options) -func evaluateServerOpt(opts []Option) *options { +func evaluateOpts(opts []Option) *options { optCopy := &options{} for _, o := range opts { o(optCopy) @@ -20,23 +21,17 @@ func evaluateServerOpt(opts []Option) *options { return optCopy } -func evaluateClientOpt(opts []Option) *options { - optCopy := &options{} - for _, o := range opts { - o(optCopy) - } - return optCopy -} +type OnValidationErr func(ctx context.Context, err error) -// WithLogger tells validator to log all the validation errors with the given log level. -func WithLogger(level logging.Level, logger logging.Logger) Option { +// WithOnValidationErrFunc registers function that will be invoked on validation error(s). +func WithOnValidationErrFunc(onValidationErrFunc OnValidationErr) Option { return func(o *options) { - o.level = level - o.logger = logger + o.onValidationErrFunc = onValidationErrFunc } } // WithFailFast tells validator to immediately stop doing further validation after first validation error. +// This option is ignored if message is only supporting validator.validatorLegacy interface. func WithFailFast() Option { return func(o *options) { o.shouldFailFast = true diff --git a/interceptors/validator/validator.go b/interceptors/validator/validator.go index c56f4e399..d6d72558c 100644 --- a/interceptors/validator/validator.go +++ b/interceptors/validator/validator.go @@ -6,7 +6,6 @@ package validator import ( "context" - "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -28,51 +27,31 @@ type validatorLegacy interface { Validate() error } -func log(level logging.Level, logger logging.Logger, msg string) { - if logger != nil { - // TODO(bwplotka): Fix in separate PR. - logger.Log(context.TODO(), level, msg) - } -} - -func validate(req interface{}, shouldFailFast bool, level logging.Level, logger logging.Logger) error { - // shouldFailFast tells validator to immediately stop doing further validation after first validation error. +func validate(ctx context.Context, reqOrRes interface{}, shouldFailFast bool, onValidationErrFunc OnValidationErr) (err error) { if shouldFailFast { - switch v := req.(type) { + switch v := reqOrRes.(type) { case validatorLegacy: - if err := v.Validate(); err != nil { - log(level, logger, err.Error()) - return status.Error(codes.InvalidArgument, err.Error()) - } + err = v.Validate() case validator: - if err := v.Validate(false); err != nil { - log(level, logger, err.Error()) - return status.Error(codes.InvalidArgument, err.Error()) - } + err = v.Validate(false) + } + } else { + switch v := reqOrRes.(type) { + case validateAller: + err = v.ValidateAll() + case validator: + err = v.Validate(true) + case validatorLegacy: + err = v.Validate() } + } + if err == nil { return nil } - // shouldNotFailFast tells validator to continue doing further validation even if after a validation error. - switch v := req.(type) { - case validateAller: - if err := v.ValidateAll(); err != nil { - log(level, logger, err.Error()) - return status.Error(codes.InvalidArgument, err.Error()) - } - case validator: - if err := v.Validate(true); err != nil { - log(level, logger, err.Error()) - return status.Error(codes.InvalidArgument, err.Error()) - } - case validatorLegacy: - // Fallback to legacy validator - if err := v.Validate(); err != nil { - log(level, logger, err.Error()) - return status.Error(codes.InvalidArgument, err.Error()) - } + if onValidationErrFunc != nil { + onValidationErrFunc(ctx, err) } - - return nil + return status.Error(codes.InvalidArgument, err.Error()) } diff --git a/interceptors/validator/validator_test.go b/interceptors/validator/validator_test.go index a88c9006e..6d1891144 100644 --- a/interceptors/validator/validator_test.go +++ b/interceptors/validator/validator_test.go @@ -7,28 +7,25 @@ import ( "context" "testing" - "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb" "github.com/stretchr/testify/assert" ) -type TestLogger struct{} - -func (l *TestLogger) Log(ctx context.Context, level logging.Level, msg string, fields ...any) {} - func TestValidateWrapper(t *testing.T) { - assert.NoError(t, validate(testpb.GoodPing, false, logging.LevelError, &TestLogger{})) - assert.Error(t, validate(testpb.BadPing, false, logging.LevelError, &TestLogger{})) - assert.NoError(t, validate(testpb.GoodPing, true, logging.LevelError, &TestLogger{})) - assert.Error(t, validate(testpb.BadPing, true, logging.LevelError, &TestLogger{})) - - assert.NoError(t, validate(testpb.GoodPingError, false, logging.LevelError, &TestLogger{})) - assert.Error(t, validate(testpb.BadPingError, false, logging.LevelError, &TestLogger{})) - assert.NoError(t, validate(testpb.GoodPingError, true, logging.LevelError, &TestLogger{})) - assert.Error(t, validate(testpb.BadPingError, true, logging.LevelError, &TestLogger{})) - - assert.NoError(t, validate(testpb.GoodPingResponse, false, logging.LevelError, &TestLogger{})) - assert.NoError(t, validate(testpb.GoodPingResponse, true, logging.LevelError, &TestLogger{})) - assert.Error(t, validate(testpb.BadPingResponse, false, logging.LevelError, &TestLogger{})) - assert.Error(t, validate(testpb.BadPingResponse, true, logging.LevelError, &TestLogger{})) + ctx := context.Background() + + assert.NoError(t, validate(ctx, testpb.GoodPing, false, nil)) + assert.Error(t, validate(ctx, testpb.BadPing, false, nil)) + assert.NoError(t, validate(ctx, testpb.GoodPing, true, nil)) + assert.Error(t, validate(ctx, testpb.BadPing, true, nil)) + + assert.NoError(t, validate(ctx, testpb.GoodPingError, false, nil)) + assert.Error(t, validate(ctx, testpb.BadPingError, false, nil)) + assert.NoError(t, validate(ctx, testpb.GoodPingError, true, nil)) + assert.Error(t, validate(ctx, testpb.BadPingError, true, nil)) + + assert.NoError(t, validate(ctx, testpb.GoodPingResponse, false, nil)) + assert.NoError(t, validate(ctx, testpb.GoodPingResponse, true, nil)) + assert.Error(t, validate(ctx, testpb.BadPingResponse, false, nil)) + assert.Error(t, validate(ctx, testpb.BadPingResponse, true, nil)) } diff --git a/testing/testpb/test.manual_validator.pb.go b/testing/testpb/test.manual_validator.pb.go index ec1e8f639..6f7acf0fb 100644 --- a/testing/testpb/test.manual_validator.pb.go +++ b/testing/testpb/test.manual_validator.pb.go @@ -36,7 +36,7 @@ func (x *PingStreamRequest) Validate(bool) error { return nil } -// Implements the legacy validation interface from protoc-gen-validate. +// Validate implements the legacy validation interface from protoc-gen-validate. func (x *PingResponse) Validate() error { if x.Counter > math.MaxInt16 { return errors.New("ping allocation exceeded") @@ -44,7 +44,7 @@ func (x *PingResponse) Validate() error { return nil } -// Implements the new ValidateAll interface from protoc-gen-validate. +// ValidateAll implements the new ValidateAll interface from protoc-gen-validate. func (x *PingResponse) ValidateAll() error { if x.Counter > math.MaxInt16 { return errors.New("ping allocation exceeded")