Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

errorsbp: Update for go 1.20 #591

Merged
merged 1 commit into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions errorsbp/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,23 @@ import (
var (
_ error = Batch{}
_ error = (*Batch)(nil)

_ batchUnwrapper = Batch{}
_ batchUnwrapper = (*Batch)(nil)
)

type batchUnwrapper interface {
Unwrap() []error
}

// Batch is an error that can contain multiple errors.
//
// The zero value of Batch is valid (with no errors) and ready to use.
//
// This type is not thread-safe.
// The same batch should not be operated on different goroutines concurrently.
//
// To be deprecated when we drop support for go 1.19.
type Batch struct {
errors []error
}
Expand Down Expand Up @@ -185,10 +194,18 @@ func (be Batch) GetErrors() []error {
return errors
}

// Unwrap implements the optional interface defined in go 1.20.
//
// It's an alias to GetErrors.
func (be Batch) Unwrap() []error {
return be.GetErrors()
}

// BatchSize returns the size of the batch for error err.
//
// If err is either errorsbp.Batch or *errorsbp.Batch,
// this function returns its Len().
// If err implements `Unwrap() []error` (optional interface defined in go 1.20),
// which includes errorsbp.Batch and *errorsbp.Batch,
// it returns the total size of Unwrap'd errors recursively.
// Otherwise, it returns 1 if err is non-nil, and 0 if err is nil.
//
// It's useful in tests,
Expand All @@ -198,9 +215,15 @@ func BatchSize(err error) int {
if err == nil {
return 0
}
var be Batch
if errors.As(err, &be) {
return be.Len()
if unwrapper, ok := err.(batchUnwrapper); ok {
// Since neither errors.Join nor fmt.Errorf tries to flatten the errors when
// combining them, do this recursively instead of simply return
// len(unwrapper.Unwrap()).
var total int
for _, e := range unwrapper.Unwrap() {
total += BatchSize(e)
}
return total
}
// single, non-batch error.
return 1
Expand Down
82 changes: 82 additions & 0 deletions errorsbp/batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package errorsbp_test

import (
"errors"
"fmt"
"reflect"
"testing"

Expand Down Expand Up @@ -265,3 +266,84 @@ func TestAddPrefix(t *testing.T) {
}
}
}

type simpleBatch []error

func (sb simpleBatch) Unwrap() []error {
return []error(sb)
}

func (sb simpleBatch) Error() string {
return fmt.Sprintf("simpleBatch-%d", len(sb))
}

func TestBatchSize(t *testing.T) {
for _, c := range []struct {
label string
err error
want int
}{
{
label: "nil",
err: nil,
want: 0,
},
{
label: "errors.New",
err: errors.New("foo"),
want: 1,
},
{
label: "fmt.Errorf-wrap-single",
err: fmt.Errorf("bar: %w", errors.New("foo")),
want: 1,
},
{
label: "batch-0",
err: new(errorsbp.Batch),
want: 0,
},
{
label: "batch-1",
want: 1,
err: func() error {
var batch errorsbp.Batch
batch.Add(errors.New("foo"))
return batch
}(),
},
{
label: "batch-2",
want: 2,
err: func() error {
var batch errorsbp.Batch
batch.Add(errors.New("foo"))
batch.Add(errors.New("bar"))
return batch
}(),
},
{
label: "recursion",
want: 4,
err: simpleBatch{
nil, // 0
errors.New("foo"), // 1
simpleBatch{errors.New("foo")}, // 1
simpleBatch{
nil, // 0
errors.New("foo"), // 1
errors.New("bar"), // 1
},
nil, // 0
},
},
// TODO: Add cases from errors.Join and fmt.Errorf once we drop support for

This comment was marked as resolved.

// go 1.19.
} {
t.Run(c.label, func(t *testing.T) {
if got := errorsbp.BatchSize(c.err); got != c.want {
t.Errorf("errorsbp.BatchSize(%#v) got %v want %v", c.err, got, c.want)
}
})
}
}