diff --git a/backend/adapter_utils.go b/backend/adapter_utils.go index 57864035f..0955d1e17 100644 --- a/backend/adapter_utils.go +++ b/backend/adapter_utils.go @@ -35,7 +35,7 @@ func wrapHandler(ctx context.Context, pluginCtx PluginContext, next handlerWrapp } func setupHandlerContext(ctx context.Context, pluginCtx PluginContext) context.Context { - ctx = initErrorSource(ctx) + ctx = InitErrorSource(ctx) ctx = WithGrafanaConfig(ctx, pluginCtx.GrafanaConfig) ctx = WithPluginContext(ctx, pluginCtx) ctx = WithUser(ctx, pluginCtx.User) @@ -75,7 +75,7 @@ func metricWrapper(next handlerWrapperFunc) handlerWrapperFunc { endpoint := EndpointFromContext(ctx) status, err := next(ctx) - pluginRequestCounter.WithLabelValues(endpoint.String(), status.String(), string(errorSourceFromContext(ctx))).Inc() + pluginRequestCounter.WithLabelValues(endpoint.String(), status.String(), string(ErrorSourceFromContext(ctx))).Inc() return status, err } @@ -106,7 +106,7 @@ func tracingWrapper(next handlerWrapperFunc) handlerWrapperFunc { span.SetAttributes( attribute.String("request_status", status.String()), - attribute.String("status_source", string(errorSourceFromContext(ctx))), + attribute.String("status_source", string(ErrorSourceFromContext(ctx))), ) if err != nil { @@ -136,7 +136,7 @@ func logWrapper(next handlerWrapperFunc) handlerWrapperFunc { logParams = append(logParams, "error", err) } - logParams = append(logParams, "statusSource", string(errorSourceFromContext(ctx))) + logParams = append(logParams, "statusSource", string(ErrorSourceFromContext(ctx))) if status > RequestStatusCancelled { logFunc = ctxLogger.Error diff --git a/backend/adapter_utils_test.go b/backend/adapter_utils_test.go index 15223ce48..0a3610714 100644 --- a/backend/adapter_utils_test.go +++ b/backend/adapter_utils_test.go @@ -10,7 +10,7 @@ import ( func TestErrorWrapper(t *testing.T) { t.Run("No downstream error should not set downstream error source in context", func(t *testing.T) { - ctx := initErrorSource(context.Background()) + ctx := InitErrorSource(context.Background()) actualErr := errors.New("BOOM") wrapper := errorWrapper(func(_ context.Context) (RequestStatus, error) { @@ -19,11 +19,11 @@ func TestErrorWrapper(t *testing.T) { status, err := wrapper(ctx) require.ErrorIs(t, err, actualErr) require.Equal(t, RequestStatusError, status) - require.Equal(t, DefaultErrorSource, errorSourceFromContext(ctx)) + require.Equal(t, DefaultErrorSource, ErrorSourceFromContext(ctx)) }) t.Run("Downstream error should set downstream error source in context", func(t *testing.T) { - ctx := initErrorSource(context.Background()) + ctx := InitErrorSource(context.Background()) actualErr := errors.New("BOOM") wrapper := errorWrapper(func(_ context.Context) (RequestStatus, error) { @@ -32,6 +32,6 @@ func TestErrorWrapper(t *testing.T) { status, err := wrapper(ctx) require.ErrorIs(t, err, actualErr) require.Equal(t, RequestStatusError, status) - require.Equal(t, ErrorSourceDownstream, errorSourceFromContext(ctx)) + require.Equal(t, ErrorSourceDownstream, ErrorSourceFromContext(ctx)) }) } diff --git a/backend/data_adapter_test.go b/backend/data_adapter_test.go index f20ef284a..b78cd26b7 100644 --- a/backend/data_adapter_test.go +++ b/backend/data_adapter_test.go @@ -224,7 +224,7 @@ func TestQueryData(t *testing.T) { require.NoError(t, err) } - ss := errorSourceFromContext(actualCtx) + ss := ErrorSourceFromContext(actualCtx) require.Equal(t, tc.expErrorSource, ss) }) } diff --git a/backend/error_source.go b/backend/error_source.go index 8c157cf30..c92039399 100644 --- a/backend/error_source.go +++ b/backend/error_source.go @@ -107,9 +107,9 @@ func (e errorWithSourceImpl) Unwrap() error { type errorSourceCtxKey struct{} -// errorSourceFromContext returns the error source stored in the context. +// ErrorSourceFromContext returns the error source stored in the context. // If no error source is stored in the context, [DefaultErrorSource] is returned. -func errorSourceFromContext(ctx context.Context) ErrorSource { +func ErrorSourceFromContext(ctx context.Context) ErrorSource { value, ok := ctx.Value(errorSourceCtxKey{}).(*ErrorSource) if ok { return *value @@ -117,8 +117,8 @@ func errorSourceFromContext(ctx context.Context) ErrorSource { return DefaultErrorSource } -// initErrorSource initialize the status source for the context. -func initErrorSource(ctx context.Context) context.Context { +// InitErrorSource initialize the error source for the context. +func InitErrorSource(ctx context.Context) context.Context { s := DefaultErrorSource return context.WithValue(ctx, errorSourceCtxKey{}, &s) } diff --git a/backend/error_source_middleware.go b/backend/error_source_middleware.go new file mode 100644 index 000000000..764306da0 --- /dev/null +++ b/backend/error_source_middleware.go @@ -0,0 +1,128 @@ +package backend + +import ( + "context" + "errors" + "fmt" +) + +// NewErrorSourceMiddleware returns a new backend.HandlerMiddleware that sets the error source in the +// context.Context, based on returned errors or query data response errors. +// If at least one query data response has a "downstream" error source and there isn't one with a "plugin" error source, +// the error source in the context is set to "downstream". +func NewErrorSourceMiddleware() HandlerMiddleware { + return HandlerMiddlewareFunc(func(next Handler) Handler { + return &ErrorSourceMiddleware{ + BaseHandler: NewBaseHandler(next), + } + }) +} + +type ErrorSourceMiddleware struct { + BaseHandler +} + +func (m *ErrorSourceMiddleware) handleDownstreamError(ctx context.Context, err error) error { + if err == nil { + return nil + } + + if IsDownstreamError(err) { + if innerErr := WithDownstreamErrorSource(ctx); innerErr != nil { + return fmt.Errorf("failed to set downstream error source: %w", errors.Join(innerErr, err)) + } + } + + return err +} + +func (m *ErrorSourceMiddleware) QueryData(ctx context.Context, req *QueryDataRequest) (*QueryDataResponse, error) { + resp, err := m.BaseHandler.QueryData(ctx, req) + err = m.handleDownstreamError(ctx, err) + + if err != nil { + return resp, err + } else if resp == nil || len(resp.Responses) == 0 { + return nil, errors.New("both response and error are nil, but one must be provided") + } + + // Set downstream error source in the context if there's at least one response with downstream error source, + // and if there's no plugin error + var hasPluginError bool + var hasDownstreamError bool + for _, r := range resp.Responses { + if r.Error == nil { + continue + } + + // if error source not set and the error is a downstream error, set error source to downstream. + if !r.ErrorSource.IsValid() && IsDownstreamError(r.Error) { + r.ErrorSource = ErrorSourceDownstream + } + + if !r.Status.IsValid() { + r.Status = statusFromError(r.Error) + } + + if r.ErrorSource == ErrorSourceDownstream { + hasDownstreamError = true + } else { + hasPluginError = true + } + } + + // A plugin error has higher priority than a downstream error, + // so set to downstream only if there's no plugin error + if hasDownstreamError && !hasPluginError { + if err := WithDownstreamErrorSource(ctx); err != nil { + return resp, fmt.Errorf("failed to set downstream status source: %w", err) + } + } + + return resp, err +} + +func (m *ErrorSourceMiddleware) CallResource(ctx context.Context, req *CallResourceRequest, sender CallResourceResponseSender) error { + err := m.BaseHandler.CallResource(ctx, req, sender) + return m.handleDownstreamError(ctx, err) +} + +func (m *ErrorSourceMiddleware) CheckHealth(ctx context.Context, req *CheckHealthRequest) (*CheckHealthResult, error) { + resp, err := m.BaseHandler.CheckHealth(ctx, req) + return resp, m.handleDownstreamError(ctx, err) +} + +func (m *ErrorSourceMiddleware) CollectMetrics(ctx context.Context, req *CollectMetricsRequest) (*CollectMetricsResult, error) { + resp, err := m.BaseHandler.CollectMetrics(ctx, req) + return resp, m.handleDownstreamError(ctx, err) +} + +func (m *ErrorSourceMiddleware) SubscribeStream(ctx context.Context, req *SubscribeStreamRequest) (*SubscribeStreamResponse, error) { + resp, err := m.BaseHandler.SubscribeStream(ctx, req) + return resp, m.handleDownstreamError(ctx, err) +} + +func (m *ErrorSourceMiddleware) PublishStream(ctx context.Context, req *PublishStreamRequest) (*PublishStreamResponse, error) { + resp, err := m.BaseHandler.PublishStream(ctx, req) + return resp, m.handleDownstreamError(ctx, err) +} + +func (m *ErrorSourceMiddleware) RunStream(ctx context.Context, req *RunStreamRequest, sender *StreamSender) error { + err := m.BaseHandler.RunStream(ctx, req, sender) + return m.handleDownstreamError(ctx, err) +} + +func (m *ErrorSourceMiddleware) ValidateAdmission(ctx context.Context, req *AdmissionRequest) (*ValidationResponse, error) { + resp, err := m.BaseHandler.ValidateAdmission(ctx, req) + return resp, m.handleDownstreamError(ctx, err) +} + +func (m *ErrorSourceMiddleware) MutateAdmission(ctx context.Context, req *AdmissionRequest) (*MutationResponse, error) { + resp, err := m.BaseHandler.MutateAdmission(ctx, req) + return resp, m.handleDownstreamError(ctx, err) +} + +func (m *ErrorSourceMiddleware) ConvertObjects(ctx context.Context, req *ConversionRequest) (*ConversionResponse, error) { + resp, err := m.BaseHandler.ConvertObjects(ctx, req) + return resp, m.handleDownstreamError(ctx, err) +} diff --git a/backend/error_source_middleware_test.go b/backend/error_source_middleware_test.go new file mode 100644 index 000000000..681a7dadb --- /dev/null +++ b/backend/error_source_middleware_test.go @@ -0,0 +1,207 @@ +package backend_test + +import ( + "context" + "errors" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/handlertest" + "github.com/stretchr/testify/require" +) + +func TestErrorSourceMiddleware(t *testing.T) { + someErr := errors.New("oops") + downstreamErr := backend.DownstreamError(someErr) + + t.Run("Handlers return errors", func(t *testing.T) { + for _, tc := range []struct { + name string + err error + expErrorSource backend.ErrorSource + }{ + { + name: `no downstream error`, + err: someErr, + expErrorSource: backend.ErrorSourcePlugin, + }, + { + name: `downstream error`, + err: downstreamErr, + expErrorSource: backend.ErrorSourceDownstream, + }, + } { + t.Run(tc.name, func(t *testing.T) { + cdt := handlertest.NewHandlerMiddlewareTest(t, + handlertest.WithMiddlewares( + backend.NewErrorSourceMiddleware(), + ), + ) + setupHandlersWithError(cdt, tc.err) + + _, err := cdt.MiddlewareHandler.QueryData(context.Background(), &backend.QueryDataRequest{}) + require.Error(t, err) + ss := backend.ErrorSourceFromContext(cdt.QueryDataCtx) + require.Equal(t, tc.expErrorSource, ss) + + _, err = cdt.MiddlewareHandler.CheckHealth(context.Background(), &backend.CheckHealthRequest{}) + require.Error(t, err) + ss = backend.ErrorSourceFromContext(cdt.CheckHealthCtx) + require.Equal(t, tc.expErrorSource, ss) + + err = cdt.MiddlewareHandler.CallResource(context.Background(), &backend.CallResourceRequest{}, backend.CallResourceResponseSenderFunc(func(_ *backend.CallResourceResponse) error { return nil })) + require.Error(t, err) + ss = backend.ErrorSourceFromContext(cdt.CallResourceCtx) + require.Equal(t, tc.expErrorSource, ss) + + _, err = cdt.MiddlewareHandler.CollectMetrics(context.Background(), &backend.CollectMetricsRequest{}) + require.Error(t, err) + ss = backend.ErrorSourceFromContext(cdt.CollectMetricsCtx) + require.Equal(t, tc.expErrorSource, ss) + + _, err = cdt.MiddlewareHandler.SubscribeStream(context.Background(), &backend.SubscribeStreamRequest{}) + require.Error(t, err) + ss = backend.ErrorSourceFromContext(cdt.SubscribeStreamCtx) + require.Equal(t, tc.expErrorSource, ss) + + err = cdt.MiddlewareHandler.RunStream(context.Background(), &backend.RunStreamRequest{}, backend.NewStreamSender(nil)) + require.Error(t, err) + ss = backend.ErrorSourceFromContext(cdt.RunStreamCtx) + require.Equal(t, tc.expErrorSource, ss) + + _, err = cdt.MiddlewareHandler.PublishStream(context.Background(), &backend.PublishStreamRequest{}) + require.Error(t, err) + ss = backend.ErrorSourceFromContext(cdt.PublishStreamCtx) + require.Equal(t, tc.expErrorSource, ss) + + _, err = cdt.MiddlewareHandler.ValidateAdmission(context.Background(), &backend.AdmissionRequest{}) + require.Error(t, err) + ss = backend.ErrorSourceFromContext(cdt.ValidateAdmissionCtx) + require.Equal(t, tc.expErrorSource, ss) + + _, err = cdt.MiddlewareHandler.MutateAdmission(context.Background(), &backend.AdmissionRequest{}) + require.Error(t, err) + ss = backend.ErrorSourceFromContext(cdt.MutateAdmissionCtx) + require.Equal(t, tc.expErrorSource, ss) + + _, err = cdt.MiddlewareHandler.ConvertObjects(context.Background(), &backend.ConversionRequest{}) + require.Error(t, err) + ss = backend.ErrorSourceFromContext(cdt.ConvertObjectCtx) + require.Equal(t, tc.expErrorSource, ss) + }) + } + }) + + t.Run("QueryData response with errors", func(t *testing.T) { + for _, tc := range []struct { + name string + queryDataResponse *backend.QueryDataResponse + expErrorSource backend.ErrorSource + }{ + { + name: `no error should be "plugin" error source`, + queryDataResponse: nil, + expErrorSource: backend.ErrorSourcePlugin, + }, + { + name: `single downstream error should be "downstream" error source`, + queryDataResponse: &backend.QueryDataResponse{ + Responses: map[string]backend.DataResponse{ + "A": {Error: someErr, ErrorSource: backend.ErrorSourceDownstream}, + }, + }, + expErrorSource: backend.ErrorSourceDownstream, + }, + { + name: `single plugin error should be "plugin" error source`, + queryDataResponse: &backend.QueryDataResponse{ + Responses: map[string]backend.DataResponse{ + "A": {Error: someErr, ErrorSource: backend.ErrorSourcePlugin}, + }, + }, + expErrorSource: backend.ErrorSourcePlugin, + }, + { + name: `multiple downstream errors should be "downstream" error source`, + queryDataResponse: &backend.QueryDataResponse{ + Responses: map[string]backend.DataResponse{ + "A": {Error: someErr, ErrorSource: backend.ErrorSourceDownstream}, + "B": {Error: someErr, ErrorSource: backend.ErrorSourceDownstream}, + }, + }, + expErrorSource: backend.ErrorSourceDownstream, + }, + { + name: `single plugin error mixed with downstream errors should be "plugin" error source`, + queryDataResponse: &backend.QueryDataResponse{ + Responses: map[string]backend.DataResponse{ + "A": {Error: someErr, ErrorSource: backend.ErrorSourceDownstream}, + "B": {Error: someErr, ErrorSource: backend.ErrorSourcePlugin}, + "C": {Error: someErr, ErrorSource: backend.ErrorSourceDownstream}, + }, + }, + expErrorSource: backend.ErrorSourcePlugin, + }, + } { + t.Run(tc.name, func(t *testing.T) { + cdt := handlertest.NewHandlerMiddlewareTest(t, + handlertest.WithMiddlewares( + backend.NewErrorSourceMiddleware(), + ), + ) + cdt.TestHandler.QueryDataFunc = func(ctx context.Context, _ *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + cdt.QueryDataCtx = ctx + return tc.queryDataResponse, nil + } + + _, _ = cdt.MiddlewareHandler.QueryData(context.Background(), &backend.QueryDataRequest{}) + + ss := backend.ErrorSourceFromContext(cdt.QueryDataCtx) + require.Equal(t, tc.expErrorSource, ss) + }) + } + }) +} + +func setupHandlersWithError(cdt *handlertest.HandlerMiddlewareTest, err error) { + cdt.TestHandler.QueryDataFunc = func(ctx context.Context, _ *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + cdt.QueryDataCtx = ctx + return nil, err + } + cdt.TestHandler.CheckHealthFunc = func(ctx context.Context, _ *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + cdt.CheckHealthCtx = ctx + return nil, err + } + cdt.TestHandler.CallResourceFunc = func(ctx context.Context, _ *backend.CallResourceRequest, _ backend.CallResourceResponseSender) error { + cdt.CallResourceCtx = ctx + return err + } + cdt.TestHandler.CollectMetricsFunc = func(ctx context.Context, _ *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) { + cdt.CollectMetricsCtx = ctx + return nil, err + } + cdt.TestHandler.SubscribeStreamFunc = func(ctx context.Context, _ *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + cdt.SubscribeStreamCtx = ctx + return nil, err + } + cdt.TestHandler.PublishStreamFunc = func(ctx context.Context, _ *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + cdt.PublishStreamCtx = ctx + return nil, err + } + cdt.TestHandler.RunStreamFunc = func(ctx context.Context, _ *backend.RunStreamRequest, _ *backend.StreamSender) error { + cdt.RunStreamCtx = ctx + return err + } + cdt.TestHandler.ValidateAdmissionFunc = func(ctx context.Context, _ *backend.AdmissionRequest) (*backend.ValidationResponse, error) { + cdt.ValidateAdmissionCtx = ctx + return nil, err + } + cdt.TestHandler.MutateAdmissionFunc = func(ctx context.Context, _ *backend.AdmissionRequest) (*backend.MutationResponse, error) { + cdt.MutateAdmissionCtx = ctx + return nil, err + } + cdt.TestHandler.ConvertObjectsFunc = func(ctx context.Context, _ *backend.ConversionRequest) (*backend.ConversionResponse, error) { + cdt.ConvertObjectCtx = ctx + return nil, err + } +} diff --git a/backend/handler_middleware.go b/backend/handler_middleware.go index c7d3502d8..0de70deaf 100644 --- a/backend/handler_middleware.go +++ b/backend/handler_middleware.go @@ -47,7 +47,7 @@ func HandlerFromMiddlewares(finalHandler Handler, middlewares ...HandlerMiddlewa } func (h *MiddlewareHandler) setupContext(ctx context.Context, pluginCtx PluginContext, endpoint Endpoint) context.Context { - ctx = initErrorSource(ctx) + ctx = InitErrorSource(ctx) ctx = WithEndpoint(ctx, endpoint) ctx = WithPluginContext(ctx, pluginCtx) ctx = WithGrafanaConfig(ctx, pluginCtx.GrafanaConfig)