Skip to content

Commit

Permalink
Add Unwrap() implementations for Error types (#128)
Browse files Browse the repository at this point in the history
* Add Unwrap() implementations for Error types
---------
Signed-off-by: Caleb Brown <calebbrown@google.com>
  • Loading branch information
calebbrown authored Feb 26, 2024
1 parent 8f1d49e commit 586bdee
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
16 changes: 16 additions & 0 deletions graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,13 +333,19 @@ type Error struct {
Column int `json:"column"`
} `json:"locations"`
Path []interface{} `json:"path"`
err error
}

// Error implements error interface.
func (e Error) Error() string {
return fmt.Sprintf("Message: %s, Locations: %+v, Extensions: %+v, Path: %+v", e.Message, e.Locations, e.Extensions, e.Path)
}

// Unwrap implement the unwrap interface.
func (e Error) Unwrap() error {
return e.err
}

// Error implements error interface.
func (e Errors) Error() string {
b := strings.Builder{}
Expand All @@ -349,6 +355,15 @@ func (e Errors) Error() string {
return b.String()
}

// Unwrap implements the error unwrap interface.
func (e Errors) Unwrap() []error {
var errs []error
for _, err := range e {
errs = append(errs, err.err)
}
return errs
}

func (e Error) getInternalExtension() map[string]interface{} {
if e.Extensions == nil {
return make(map[string]interface{})
Expand All @@ -367,6 +382,7 @@ func newError(code string, err error) Error {
Extensions: map[string]interface{}{
"code": code,
},
err: err,
}
}

Expand Down
67 changes: 67 additions & 0 deletions graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,63 @@ func TestClient_Query_errorStatusCode(t *testing.T) {
}
}

func TestClient_Query_requestError(t *testing.T) {
want := errors.New("bad error")
client := graphql.NewClient("/graphql", &http.Client{Transport: errorRoundTripper{err: want}})

var q struct {
User struct {
Name string
}
}
err := client.Query(context.Background(), &q, nil)
if err == nil {
t.Fatal("got error: nil, want: non-nil")
}
if got, want := err.Error(), `Message: Post "/graphql": bad error, Locations: [], Extensions: map[code:request_error], Path: []`; got != want {
t.Errorf("got error: %v, want: %v", got, want)
}
if q.User.Name != "" {
t.Errorf("got non-empty q.User.Name: %v", q.User.Name)
}
if got := err; !errors.Is(got, want) {
t.Errorf("got error: %v, want: %v", got, want)
}

gqlErr := err.(graphql.Errors)
if got, want := gqlErr[0].Extensions["code"], graphql.ErrRequestError; got != want {
t.Errorf("got error: %v, want: %v", got, want)
}
if _, ok := gqlErr[0].Extensions["internal"]; ok {
t.Errorf("expected empty internal error")
}
if got := gqlErr[0]; !errors.Is(err, want) {
t.Errorf("got error: %v, want %v", got, want)
}

// test internal error data
client = client.WithDebug(true)
err = client.Query(context.Background(), &q, nil)
if err == nil {
t.Fatal("got error: nil, want: non-nil")
}
if !errors.As(err, &graphql.Errors{}) {
t.Errorf("the error type should be graphql.Errors")
}
gqlErr = err.(graphql.Errors)
if got, want := gqlErr[0].Message, `Post "/graphql": bad error`; got != want {
t.Errorf("got error: %v, want: %v", got, want)
}
if got, want := gqlErr[0].Extensions["code"], graphql.ErrRequestError; got != want {
t.Errorf("got error: %v, want: %v", got, want)
}
interErr := gqlErr[0].Extensions["internal"].(map[string]interface{})

if got, want := interErr["request"].(map[string]interface{})["body"], "{\"query\":\"{user{name}}\"}\n"; got != want {
t.Errorf("got error: %v, want: %v", got, want)
}
}

// Test that an empty (but non-nil) variables map is
// handled no differently than a nil variables map.
func TestClient_Query_emptyVariables(t *testing.T) {
Expand Down Expand Up @@ -425,6 +482,16 @@ func (l localRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
return w.Result(), nil
}

// errorRoundTripper is an http.RoundTripper that always returns the supplied
// error.
type errorRoundTripper struct {
err error
}

func (e errorRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
return nil, e.err
}

func mustRead(r io.Reader) string {
b, err := io.ReadAll(r)
if err != nil {
Expand Down

0 comments on commit 586bdee

Please sign in to comment.