diff --git a/internal/retry.go b/internal/retry.go index 7a7b4c2052db..2943a6d0b457 100644 --- a/internal/retry.go +++ b/internal/retry.go @@ -16,9 +16,11 @@ package internal import ( "context" + "fmt" "time" gax "github.com/googleapis/gax-go/v2" + "google.golang.org/grpc/status" ) // Retry calls the supplied function f repeatedly according to the provided @@ -44,11 +46,40 @@ func retry(ctx context.Context, bo gax.Backoff, f func() (stop bool, err error), lastErr = err } p := bo.Pause() - if cerr := sleep(ctx, p); cerr != nil { + if ctxErr := sleep(ctx, p); ctxErr != nil { if lastErr != nil { - return Annotatef(lastErr, "retry failed with %v; last error", cerr) + return wrappedCallErr{ctxErr: ctxErr, wrappedErr: lastErr} } - return cerr + return ctxErr } } } + +// Use this error type to return an error which allows introspection of both +// the context error and the error from the service. +type wrappedCallErr struct { + ctxErr error + wrappedErr error +} + +func (e wrappedCallErr) Error() string { + return fmt.Sprintf("retry failed with %v; last error: %v", e.ctxErr, e.wrappedErr) +} + +func (e wrappedCallErr) Unwrap() error { + return e.wrappedErr +} + +// Is allows errors.Is to match the error from the call as well as context +// sentinel errors. +func (e wrappedCallErr) Is(err error) bool { + return e.ctxErr == err || e.wrappedErr == err +} + +// GRPCStatus allows the wrapped error to be used with status.FromError. +func (e wrappedCallErr) GRPCStatus() *status.Status { + if s, ok := status.FromError(e.wrappedErr); ok { + return s + } + return nil +} diff --git a/internal/retry_test.go b/internal/retry_test.go index a01205689d0c..771cb26ffca4 100644 --- a/internal/retry_test.go +++ b/internal/retry_test.go @@ -17,7 +17,6 @@ package internal import ( "context" "errors" - "fmt" "testing" "time" @@ -75,6 +74,13 @@ func TestRetryPreserveError(t *testing.T) { func(context.Context, time.Duration) error { return context.DeadlineExceeded }) + if err == nil { + t.Fatalf("unexpectedly got nil error") + } + wantError := "retry failed with context deadline exceeded; last error: rpc error: code = NotFound desc = not found" + if g, w := err.Error(), wantError; g != w { + t.Errorf("got error %q, want %q", g, w) + } got, ok := status.FromError(err) if !ok { t.Fatalf("got %T, wanted a status", got) @@ -82,7 +88,7 @@ func TestRetryPreserveError(t *testing.T) { if g, w := got.Code(), codes.NotFound; g != w { t.Errorf("got code %v, want %v", g, w) } - wantMessage := fmt.Sprintf("retry failed with %v; last error: not found", context.DeadlineExceeded) + wantMessage := "not found" if g, w := got.Message(), wantMessage; g != w { t.Errorf("got message %q, want %q", g, w) }