Skip to content
This repository has been archived by the owner on Apr 22, 2024. It is now read-only.

Switch to standard lib error wrapping #22

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions algorithms.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ import (
"crypto/sha512"
"encoding/base64"
"encoding/json"
"errors"
"hash"
"strings"
"time"

"github.com/pkg/errors"
"fmt"
)

//Algorithm is used to sign and validate a token.
Expand Down Expand Up @@ -44,7 +45,7 @@ func (a *Algorithm) write(data []byte) (int, error) {
func (a *Algorithm) Sign(unsignedToken string) ([]byte, error) {
_, err := a.write([]byte(unsignedToken))
if err != nil {
return nil, errors.Wrap(err, "Unable to write to HMAC-SHA256")
return nil, fmt.Errorf("Unable to write to HMAC-SHA256: %w", err)
}

encodedToken := a.sum(nil)
Expand All @@ -59,14 +60,14 @@ func (a *Algorithm) Encode(payload *Claims) (string, error) {

jsonTokenHeader, err := json.Marshal(header)
if err != nil {
return "", errors.Wrap(err, "unable to marshal header")
return "", fmt.Errorf("unable to marshal header: %w", err)
}

b64TokenHeader := base64.RawURLEncoding.EncodeToString(jsonTokenHeader)

jsonTokenPayload, err := json.Marshal(payload.claimsMap)
if err != nil {
return "", errors.Wrap(err, "unable to marshal payload")
return "", fmt.Errorf("unable to marshal payload: %w", err)
}

b64TokenPayload := base64.RawURLEncoding.EncodeToString(jsonTokenPayload)
Expand All @@ -75,7 +76,7 @@ func (a *Algorithm) Encode(payload *Claims) (string, error) {

signature, err := a.Sign(unsignedSignature)
if err != nil {
return "", errors.Wrap(err, "unable to sign token")
return "", fmt.Errorf("unable to sign token: %w", err)
}
b64Signature := base64.RawURLEncoding.EncodeToString([]byte(signature))

Expand All @@ -96,11 +97,11 @@ func (a *Algorithm) Decode(encoded string) (*Claims, error) {
var claims map[string]interface{}
payload, err := base64.RawURLEncoding.DecodeString(b64Payload)
if err != nil {
return nil, errors.Wrap(err, "unable to decode base64 payload")
return nil, fmt.Errorf("unable to decode base64 payload: %w", err)
}

if err := json.Unmarshal(payload, &claims); err != nil {
return nil, errors.Wrap(err, "unable to unmarshal payload json")
return nil, fmt.Errorf("unable to unmarshal payload json: %w", err)
}

return &Claims{
Expand All @@ -122,17 +123,17 @@ func (a *Algorithm) DecodeAndValidate(encoded string) (claims *Claims, err error
}

if err = a.validateSignature(encoded); err != nil {
err = errors.Wrap(err, "failed to validate signature")
err = fmt.Errorf("failed to validate signature: %w", err)
return
}

if err = a.validateExp(claims); err != nil {
err = errors.Wrap(err, "failed to validate exp")
err = fmt.Errorf("failed to validate exp: %w", err)
return
}

if err = a.validateNbf(claims); err != nil {
err = errors.Wrap(err, "failed to validate nbf")
err = fmt.Errorf("failed to validate nbf: %w", err)
}

return
Expand All @@ -148,7 +149,7 @@ func (a *Algorithm) validateSignature(encoded string) error {
unsignedAttempt := b64Header + "." + b64Payload
signedAttempt, err := a.Sign(unsignedAttempt)
if err != nil {
return errors.Wrap(err, "unable to sign token for validation")
return fmt.Errorf("unable to sign token for validation: %w", err)
}

b64SignedAttempt := base64.RawURLEncoding.EncodeToString([]byte(signedAttempt))
Expand Down
9 changes: 4 additions & 5 deletions claims.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package jwt

import (
"fmt"
"time"

"github.com/pkg/errors"
)

// Claims contains the claims of a jwt.
Expand Down Expand Up @@ -45,7 +44,7 @@ func (c *Claims) SetTime(key string, value time.Time) {
func (c Claims) Get(key string) (interface{}, error) {
result, ok := c.claimsMap[key]
if !ok {
return "", errors.Errorf("claim (%s) doesn't exist", key)
return "", fmt.Errorf("claim (%s) doesn't exist", key)
}

return result, nil
Expand All @@ -55,12 +54,12 @@ func (c Claims) Get(key string) (interface{}, error) {
func (c *Claims) GetTime(key string) (time.Time, error) {
raw, err := c.Get(key)
if err != nil {
return time.Unix(0, 0), errors.Wrapf(err, "claim (%s) doesn't exist", key)
return time.Unix(0, 0), fmt.Errorf("claim (%s) doesn't exist: %w", key, err)
}

timeFloat, ok := raw.(float64)
if !ok {
return time.Unix(0, 0), errors.Wrap(err, "claim isn't a valid float")
return time.Unix(0, 0), fmt.Errorf("claim isn't a valid float: %w", err)
}

return time.Unix(int64(timeFloat), 0), nil
Expand Down