From 6bb7504e4ae2d1efbcb717cafb8edb9fe0c07d88 Mon Sep 17 00:00:00 2001 From: Junjie Gao <43160897+JeyJeyGao@users.noreply.github.com> Date: Tue, 16 Aug 2022 16:58:06 +0800 Subject: [PATCH] add jws test Signed-off-by: Junjie Gao --- signature/errors.go | 12 +- signature/internal/base/envelope.go | 2 +- signature/jws/envelope.go | 94 +++--- signature/jws/envelope_test.go | 481 ++++++++++++++++++++++++++-- signature/jws/jws.go | 21 +- signature/jws/jwt.go | 127 +++++--- signature/jws/jwt_test.go | 126 ++++++++ 7 files changed, 718 insertions(+), 145 deletions(-) create mode 100644 signature/jws/jwt_test.go diff --git a/signature/errors.go b/signature/errors.go index 2428c754..c71b6034 100644 --- a/signature/errors.go +++ b/signature/errors.go @@ -122,5 +122,15 @@ type EnvelopeKeyRepeatedError struct { // Error returns the formatted error message. func (e *EnvelopeKeyRepeatedError) Error() string { - return fmt.Sprintf("repeated key: %q exists in the envelope.", e.Key) + return fmt.Sprintf("repeated key: `%s` exists in the both protected header and extended signed attributes.", e.Key) +} + +// RemoteSigningError is used when remote signer causes the error. +type RemoteSigningError struct { + Msg string +} + +// Error returns formated remote signing error +func (e *RemoteSigningError) Error() string { + return fmt.Sprintf("remote signing error. Error: %s", e.Msg) } diff --git a/signature/internal/base/envelope.go b/signature/internal/base/envelope.go index 4c73d7d8..9919698a 100644 --- a/signature/internal/base/envelope.go +++ b/signature/internal/base/envelope.go @@ -171,7 +171,7 @@ func validateSignerInfo(info *signature.SignerInfo) error { ) } -// validateSigningTime checks that sigining time is within the valid range of +// validateSigningTime checks that signing time is within the valid range of // time duration. func validateSigningTime(signingTime, expireTime time.Time) error { if signingTime.IsZero() { diff --git a/signature/jws/envelope.go b/signature/jws/envelope.go index 521ec4b6..c881990f 100644 --- a/signature/jws/envelope.go +++ b/signature/jws/envelope.go @@ -4,7 +4,6 @@ import ( "crypto/x509" "encoding/base64" "encoding/json" - "strings" "github.com/golang-jwt/jwt/v4" "github.com/notaryproject/notation-core-go/signature" @@ -46,24 +45,32 @@ func ParseEnvelope(envelopeBytes []byte) (signature.Envelope, error) { // Sign signs the envelope and return the encoded message func (e *envelope) Sign(req *signature.SignRequest) ([]byte, error) { - // get all attributes ready to be signed - signedAttrs, err := getSignedAttrs(req) - if err != nil { - return nil, err + // check signer type + var signingMethod SigningMethod + var err error + if localSigner, ok := req.Signer.(signature.LocalSigner); ok { + // for local signer + signingMethod, err = newLocalSigningMethod(localSigner) + } else { + // for remote signer + signingMethod, err = newRemoteSigningMethod(req.Signer) } - - // JWT sign - compact, err := sign(req.Payload.Content, signedAttrs, req.Signer) if err != nil { return nil, &signature.MalformedSignRequestError{Msg: err.Error()} } - // get certificate chain - certs, err := req.Signer.CertificateChain() + // get all attributes ready to be signed + signedAttrs, err := getSignedAttrs(req, signingMethod.Alg()) if err != nil { return nil, err } + // JWT sign and get certificate chain + compact, certs, err := sign(req.Payload.Content, signedAttrs, signingMethod) + if err != nil { + return nil, &signature.MalformedSignRequestError{Msg: err.Error()} + } + // generate envelope env, err := generateJWS(compact, req, certs) if err != nil { @@ -78,17 +85,6 @@ func (e *envelope) Sign(req *signature.SignRequest) ([]byte, error) { return encoded, nil } -// compactJWS converts Flattened JWS JSON Serialization Syntax (section-7.2.2) to -// JWS Compact Serialization (section-7.1) -// -// [RFC 7515]: https://www.rfc-editor.org/rfc/rfc7515.html -func compactJWS(envelope *jwsEnvelope) string { - return strings.Join([]string{ - envelope.Protected, - envelope.Payload, - envelope.Signature}, ".") -} - // Verify checks the validity of the envelope and returns the payload and signerInfo func (e *envelope) Verify() (*signature.Payload, *signature.SignerInfo, error) { if e.internalEnvelope == nil { @@ -96,7 +92,7 @@ func (e *envelope) Verify() (*signature.Payload, *signature.SignerInfo, error) { } if len(e.internalEnvelope.Header.CertChain) == 0 { - return nil, nil, &signature.MalformedSignatureError{Msg: "malformed leaf certificate"} + return nil, nil, &signature.MalformedSignatureError{Msg: "certificate chain is not set"} } cert, err := x509.ParseCertificate(e.internalEnvelope.Header.CertChain[0]) @@ -106,7 +102,7 @@ func (e *envelope) Verify() (*signature.Payload, *signature.SignerInfo, error) { // verify JWT compact := compactJWS(e.internalEnvelope) - if err = verifyJWT(compact, cert); err != nil { + if err = verifyJWT(compact, cert.PublicKey); err != nil { return nil, nil, err } @@ -127,7 +123,7 @@ func (e *envelope) Verify() (*signature.Payload, *signature.SignerInfo, error) { // Payload returns the payload of JWS envelope func (e *envelope) Payload() (*signature.Payload, error) { if e.internalEnvelope == nil { - return nil, &signature.MalformedSignatureError{Msg: "missing jws signature envelope"} + return nil, &signature.SignatureNotFoundError{} } // parse protected header to get payload context type protected, err := parseProtectedHeaders(e.internalEnvelope.Protected) @@ -147,7 +143,7 @@ func (e *envelope) Payload() (*signature.Payload, error) { var claims jwtPayload _, _, err = parser.ParseUnverified(tokenString, &claims) if err != nil { - return nil, err + return nil, &signature.MalformedSignatureError{Msg: err.Error()} } return &signature.Payload{ @@ -161,14 +157,14 @@ func (e *envelope) SignerInfo() (*signature.SignerInfo, error) { if e.internalEnvelope == nil { return nil, &signature.SignatureNotFoundError{} } - var signInfo signature.SignerInfo + var signerInfo signature.SignerInfo // parse protected headers protected, err := parseProtectedHeaders(e.internalEnvelope.Protected) if err != nil { return nil, err } - if err := populateProtectedHeaders(protected, &signInfo); err != nil { + if err := populateProtectedHeaders(protected, &signerInfo); err != nil { return nil, err } @@ -180,7 +176,7 @@ func (e *envelope) SignerInfo() (*signature.SignerInfo, error) { if len(sig) == 0 { return nil, &signature.MalformedSignatureError{Msg: "cose envelope missing signature"} } - signInfo.Signature = sig + signerInfo.Signature = sig // parse headers var certs []*x509.Certificate @@ -191,34 +187,28 @@ func (e *envelope) SignerInfo() (*signature.SignerInfo, error) { } certs = append(certs, cert) } - signInfo.CertificateChain = certs - signInfo.UnsignedAttributes.SigningAgent = e.internalEnvelope.Header.SigningAgent - signInfo.UnsignedAttributes.TimestampSignature = e.internalEnvelope.Header.TimestampSignature - - return &signInfo, nil + signerInfo.CertificateChain = certs + signerInfo.UnsignedAttributes.SigningAgent = e.internalEnvelope.Header.SigningAgent + signerInfo.UnsignedAttributes.TimestampSignature = e.internalEnvelope.Header.TimestampSignature + return &signerInfo, nil } -// sign the given payload and headers using the given signing method and signature provider -func sign(payload jwtPayload, headers map[string]interface{}, signer signature.Signer) (string, error) { - var privateKey interface{} - var signingMethod jwt.SigningMethod - if localSigner, ok := signer.(signature.LocalSigner); ok { - // local signer - alg, err := extractJwtAlgorithm(localSigner) - if err != nil { - return "", err - } - signingMethod = jwt.GetSigningMethod(alg) - - // sign with private key - privateKey = localSigner.PrivateKey() - } else { - // remote signer - signingMethod = newRemoteSigningMethod(signer) - } +// sign the given payload and headers using the given signature provider +func sign(payload jwtPayload, headers map[string]interface{}, signingMethod SigningMethod) (string, []*x509.Certificate, error) { // generate token token := jwt.NewWithClaims(signingMethod, payload) token.Header = headers - return token.SignedString(privateKey) + // sign and return compact JWS + compact, err := token.SignedString(signingMethod.PrivateKey()) + if err != nil { + return "", nil, err + } + + // access certificate chain after sign + certs, err := signingMethod.CertificateChain() + if err != nil { + return "", nil, err + } + return compact, certs, nil } diff --git a/signature/jws/envelope_test.go b/signature/jws/envelope_test.go index 61acfe63..e8a6e011 100644 --- a/signature/jws/envelope_test.go +++ b/signature/jws/envelope_test.go @@ -1,7 +1,17 @@ package jws import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" "crypto/x509" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" "testing" "time" @@ -9,12 +19,99 @@ import ( "github.com/notaryproject/notation-core-go/testhelper" ) +// remoteSignerMock is used to mock remote signer +type remoteSignerMock struct { + privateKey crypto.PrivateKey + certs []*x509.Certificate +} + +// Sign signs the digest and returns the raw signature +func (signer *remoteSignerMock) Sign(payload []byte) ([]byte, []*x509.Certificate, error) { + // calculate hash + keySpec, err := signer.KeySpec() + if err != nil { + return nil, nil, err + } + + // calculate hash + hasher := keySpec.SignatureAlgorithm().Hash().HashFunc() + h := hasher.New() + h.Write(payload) + hash := h.Sum(nil) + + // sign + switch key := signer.privateKey.(type) { + case *rsa.PrivateKey: + sig, err := rsa.SignPSS(rand.Reader, key, hasher.HashFunc(), hash, &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}) + if err != nil { + return nil, nil, err + } + return sig, signer.certs, nil + case *ecdsa.PrivateKey: + r, s, err := ecdsa.Sign(rand.Reader, key, hash) + if err != nil { + return nil, nil, err + } + + curveBits := key.Curve.Params().BitSize + keyBytes := curveBits / 8 + if curveBits%8 > 0 { + keyBytes += 1 + } + + out := make([]byte, 2*keyBytes) + r.FillBytes(out[0:keyBytes]) // r is assigned to the first half of output. + s.FillBytes(out[keyBytes:]) // s is assigned to the second half of output. + return out, signer.certs, nil + } + + return nil, nil, &signature.UnsupportedSigningKeyError{} +} + +// KeySpec returns the key specification +func (signer *remoteSignerMock) KeySpec() (signature.KeySpec, error) { + return signature.ExtractKeySpec(signer.certs[0]) +} + +func checkError(t *testing.T, err error) { + if err != nil { + t.Fatal(t) + } +} + +func cmpError(t *testing.T, got string, want string) { + if got != want { + t.Fatalf("want: %v, got: %v\n", want, got) + } +} + +var ( + extSignedAttr = []signature.Attribute{ + { + Key: "testKey", + Critical: true, + Value: "testValue", + }, + { + Key: "testKey2", + Critical: false, + Value: "testValue2", + }, + } + extSignedAttrFailed = []signature.Attribute{ + { + Key: "cty", + Critical: false, + Value: "testValue2", + }, + } +) + func getSigningCerts() []*x509.Certificate { return []*x509.Certificate{testhelper.GetRSALeafCertificate().Cert, testhelper.GetRSARootCertificate().Cert} } -func getSignReq(signingScheme signature.SigningScheme) (*signature.SignRequest, error) { - certs := getSigningCerts() +func getSignReq(signingScheme signature.SigningScheme, signer signature.Signer, extendedSignedAttribute []signature.Attribute) (*signature.SignRequest, error) { payloadBytes := []byte(`{ "subject": { "mediaType": "application/vnd.oci.image.manifest.v1+json", @@ -26,10 +123,6 @@ func getSignReq(signingScheme signature.SigningScheme) (*signature.SignRequest, } } `) - signer, err := signature.NewLocalSigner(certs, testhelper.GetRSALeafCertificate().PrivateKey) - if err != nil { - return nil, err - } return &signature.SignRequest{ Payload: signature.Payload{ ContentType: signature.MediaTypePayloadV1, @@ -38,22 +131,79 @@ func getSignReq(signingScheme signature.SigningScheme) (*signature.SignRequest, Signer: signer, SigningTime: time.Now(), Expiry: time.Now().Add(time.Hour), - ExtendedSignedAttributes: nil, + ExtendedSignedAttributes: extendedSignedAttribute, SigningAgent: "Notation/1.0.0", SigningScheme: signingScheme, }, nil +} + +func getSigner(isLocal bool, certs []*x509.Certificate, privateKey *rsa.PrivateKey) (signature.Signer, error) { + if certs == nil { + certs = getSigningCerts() + } + if privateKey == nil { + privateKey = testhelper.GetRSALeafCertificate().PrivateKey + } + if isLocal { + return signature.NewLocalSigner(certs, privateKey) + } + return &remoteSignerMock{ + certs: certs, + privateKey: privateKey, + }, nil } -func signCore(signingScheme signature.SigningScheme) ([]byte, error) { - signReq, err := getSignReq(signingScheme) +func getEnvelope(signingScheme signature.SigningScheme, isLocal bool, extendedSignedAttribute []signature.Attribute) (*jwsEnvelope, error) { + encoded, err := getEncodedMessage(signingScheme, isLocal, extendedSignedAttribute) + if err != nil { + return nil, err + } + var jwsEnv jwsEnvelope + err = json.Unmarshal(encoded, &jwsEnv) if err != nil { return nil, err } - e := NewEnvelope() + return &jwsEnv, nil +} + +func getEncodedMessage(signingScheme signature.SigningScheme, isLocal bool, extendedSignedAttribute []signature.Attribute) ([]byte, error) { + signer, err := getSigner(isLocal, nil, nil) + if err != nil { + return nil, err + } + + signReq, err := getSignReq(signingScheme, signer, extendedSignedAttribute) + if err != nil { + return nil, err + } + e := envelope{} return e.Sign(signReq) } +func getSignedEnvelope(signingScheme signature.SigningScheme, isLocal bool, extendedSignedAttribute []signature.Attribute) (*jwsEnvelope, error) { + encoded, err := getEncodedMessage(signingScheme, isLocal, extendedSignedAttribute) + if err != nil { + return nil, err + } + // + var env jwsEnvelope + err = json.Unmarshal(encoded, &env) + if err != nil { + return nil, err + } + return &env, nil +} + +func verifyEnvelope(env *jwsEnvelope) error { + newEncoded, err := json.Marshal(env) + if err != nil { + return err + } + _, _, err = verifyCore(newEncoded) + return err +} + func verifyCore(encoded []byte) (*signature.Payload, *signature.SignerInfo, error) { env, err := ParseEnvelope(encoded) if err != nil { @@ -62,41 +212,302 @@ func verifyCore(encoded []byte) (*signature.Payload, *signature.SignerInfo, erro return env.Verify() } -func Test_envelope_Verify_X509(t *testing.T) { - encoded, err := signCore(signature.SigningSchemeX509) - if err != nil { - t.Fatal(err) +func TestNewEnvelope(t *testing.T) { + env := NewEnvelope() + if env == nil { + t.Fatal("should get an JWS envelope") } - _, _, err = verifyCore(encoded) - if err != nil { - t.Fatal(err) +} + +// Test the same key exists both in extended signed attributes and protected header +func TestEnvelopeSignFailed(t *testing.T) { + _, err := getEncodedMessage(signature.SigningSchemeX509, true, extSignedAttrFailed) + if err == nil { + t.Fatal("should cause error") } } -func Test_envelope_Verify_X509SigningAuthority(t *testing.T) { - encoded, err := signCore(signature.SigningSchemeX509SigningAuthority) - if err != nil { - t.Fatal(err) +func TestEnvelopeVerify(t *testing.T) { + var signParams = []struct { + isLocal bool + signingScheme signature.SigningScheme + }{ + {true, signature.SigningSchemeX509}, + {true, signature.SigningSchemeX509SigningAuthority}, + {false, signature.SigningSchemeX509}, + {false, signature.SigningSchemeX509SigningAuthority}, } - _, _, err = verifyCore(encoded) - if err != nil { - t.Fatal(err) + + for _, tt := range signParams { + t.Run(fmt.Sprintf("verify_isLocal=%v_signingScheme=%v", tt.isLocal, tt.signingScheme), func(t *testing.T) { + encoded, err := getEncodedMessage(tt.signingScheme, tt.isLocal, extSignedAttr) + checkError(t, err) + + _, _, err = verifyCore(encoded) + checkError(t, err) + }) } } -func Test_envelope_Verify_failed(t *testing.T) { - encoded, err := signCore(signature.SigningSchemeX509) - if err != nil { - t.Fatal(t) +func TestVerify(t *testing.T) { + t.Run("break json format", func(t *testing.T) { + encoded, err := getEncodedMessage(signature.SigningSchemeX509, true, extSignedAttr) + checkError(t, err) + + encoded[0] = '}' + + _, _, err = verifyCore(encoded) + cmpError(t, err.Error(), "invalid character '}' looking for beginning of value") + }) + + t.Run("empty certificate", func(t *testing.T) { + // get envelope + env, err := getSignedEnvelope(signature.SigningSchemeX509, true, extSignedAttr) + checkError(t, err) + + // temper envelope + env.Header.CertChain = [][]byte{} + + err = verifyEnvelope(env) + cmpError(t, err.Error(), "certificate chain is not set") + }) + + t.Run("tamper certificate", func(t *testing.T) { + // get envelope + env, err := getSignedEnvelope(signature.SigningSchemeX509, true, extSignedAttr) + checkError(t, err) + + // temper envelope + env.Header.CertChain[0][0] = 'C' + + err = verifyEnvelope(env) + cmpError(t, err.Error(), "malformed leaf certificate") + }) + + t.Run("malformed protected header base64 encoded", func(t *testing.T) { + // get envelope + env, err := getSignedEnvelope(signature.SigningSchemeX509, true, extSignedAttr) + checkError(t, err) + + // temper envelope + env.Protected = "$" + env.Protected + + err = verifyEnvelope(env) + cmpError(t, err.Error(), "jws envelope protected header can't be decoded: illegal base64 data at input byte 0") + }) + t.Run("malformed protected header raw", func(t *testing.T) { + // get envelope + env, err := getSignedEnvelope(signature.SigningSchemeX509, true, extSignedAttr) + checkError(t, err) + + // temper envelope + rawProtected, err := base64.RawURLEncoding.DecodeString(env.Protected) + checkError(t, err) + + rawProtected[0] = '}' + env.Protected = base64.RawURLEncoding.EncodeToString(rawProtected) + + err = verifyEnvelope(env) + cmpError(t, err.Error(), "jws envelope protected header can't be decoded: invalid character '}' looking for beginning of value") + }) +} + +func TestSignerInfo(t *testing.T) { + getEnvelopeAndHeader := func(signingScheme signature.SigningScheme) (*jwsEnvelope, *jwsProtectedHeader) { + // get envelope + env, err := getSignedEnvelope(signingScheme, true, extSignedAttr) + checkError(t, err) + + // get protected header + header, err := parseProtectedHeaders(env.Protected) + checkError(t, err) + return env, header } - // manipulate envelope - encoded[len(encoded)-10] = 'C' + updateProtectedHeader := func(env *jwsEnvelope, protected *jwsProtectedHeader) { + // generate protected header + headerMap := make(map[string]interface{}) + valueOf := reflect.ValueOf(*protected) + for i := 0; i < valueOf.NumField(); i++ { + var key string + tags := strings.Split(valueOf.Type().Field(i).Tag.Get("json"), ",") + if len(tags) > 0 { + key = tags[0] + } + if key == "-" { + continue + } + headerMap[key] = valueOf.Field(i).Interface() + } + // extract extended attribute + for key, value := range protected.ExtendedAttributes { + headerMap[key] = value + } - // verify manipulated envelope - _, _, err = verifyCore(encoded) + // marshal and write back to envelope + rawProtected, err := json.Marshal(headerMap) + checkError(t, err) + env.Protected = base64.RawURLEncoding.EncodeToString(rawProtected) + } + getSignerInfo := func(env *jwsEnvelope, protected *jwsProtectedHeader) (*signature.SignerInfo, error) { + updateProtectedHeader(env, protected) + // marshal tampered envelope + newEncoded, err := json.Marshal(env) + checkError(t, err) - // should get an error - if err == nil { - t.Fatalf("should verify failed.") + // parse tampered envelope + newEnv, err := ParseEnvelope(newEncoded) + checkError(t, err) + + return newEnv.SignerInfo() } + + t.Run("tamper protected header signing scheme X509", func(t *testing.T) { + env, header := getEnvelopeAndHeader(signature.SigningSchemeX509) + + // temper protected header + signingTime := time.Now() + header.AuthenticSigningTime = &signingTime + + _, err := getSignerInfo(env, header) + cmpError(t, err.Error(), `signature envelope format is malformed. error: "io.cncf.notary.authenticSigningTime" header must not be present for notary.x509 signing scheme`) + }) + + t.Run("tamper protected header signing scheme X509 Signing Authority", func(t *testing.T) { + env, header := getEnvelopeAndHeader(signature.SigningSchemeX509SigningAuthority) + + // temper protected header + signingTime := time.Now() + header.SigningTime = &signingTime + + _, err := getSignerInfo(env, header) + cmpError(t, err.Error(), `signature envelope format is malformed. error: "io.cncf.notary.signingTime" header must not be present for notary.x509.signingAuthority signing scheme`) + }) + + t.Run("tamper protected header signing scheme X509 Signing Authority 2", func(t *testing.T) { + env, header := getEnvelopeAndHeader(signature.SigningSchemeX509SigningAuthority) + + // temper protected header + header.AuthenticSigningTime = nil + + _, err := getSignerInfo(env, header) + cmpError(t, err.Error(), `signature envelope format is malformed. error: "io.cncf.notary.authenticSigningTime" header must be present for notary.x509 signing scheme`) + }) + + t.Run("tamper protected header extended attributes", func(t *testing.T) { + env, header := getEnvelopeAndHeader(signature.SigningSchemeX509) + + // temper protected header + header.ExtendedAttributes = make(map[string]interface{}) + + _, err := getSignerInfo(env, header) + cmpError(t, err.Error(), `signature envelope format is malformed. error: "testKey" header is marked critical but not present`) + }) + + t.Run("add protected header critical key", func(t *testing.T) { + env, header := getEnvelopeAndHeader(signature.SigningSchemeX509) + + // temper protected header + header.Critical = header.Critical[:len(header.Critical)-2] + + _, err := getSignerInfo(env, header) + cmpError(t, err.Error(), `signature envelope format is malformed. error: these required headers are not marked as critical: [io.cncf.notary.expiry]`) + }) + t.Run("tamper raw protected header json format", func(t *testing.T) { + // get envelope + env, err := getSignedEnvelope(signature.SigningSchemeX509, true, extSignedAttr) + checkError(t, err) + + rawProtected, err := base64.RawURLEncoding.DecodeString(env.Protected) + checkError(t, err) + + // temper envelope + rawProtected[0] = '}' + env.Protected = base64.RawURLEncoding.EncodeToString(rawProtected) + + newEncoded, err := json.Marshal(env) + checkError(t, err) + + // parse tampered envelope + newEnv, err := ParseEnvelope(newEncoded) + checkError(t, err) + + _, err = newEnv.SignerInfo() + cmpError(t, err.Error(), "signature envelope format is malformed. error: jws envelope protected header can't be decoded: invalid character '}' looking for beginning of value") + }) + t.Run("tamper signature base64 encoding", func(t *testing.T) { + env, header := getEnvelopeAndHeader(signature.SigningSchemeX509) + + // temper protected header + env.Signature = "{" + env.Signature + + _, err := getSignerInfo(env, header) + cmpError(t, err.Error(), `signature envelope format is malformed. error: illegal base64 data at input byte 0`) + }) + t.Run("tamper empty signature", func(t *testing.T) { + env, header := getEnvelopeAndHeader(signature.SigningSchemeX509) + + // temper protected header + env.Signature = "" + + _, err := getSignerInfo(env, header) + cmpError(t, err.Error(), `signature envelope format is malformed. error: cose envelope missing signature`) + }) + t.Run("tamper cert chain", func(t *testing.T) { + env, header := getEnvelopeAndHeader(signature.SigningSchemeX509) + + // temper protected header + env.Header.CertChain[0] = append(env.Header.CertChain[0], 'v') + + _, err := getSignerInfo(env, header) + cmpError(t, err.Error(), `signature envelope format is malformed. error: x509: trailing data`) + }) +} + +func TestPayload(t *testing.T) { + t.Run("tamper envelope cause JWT parse failed", func(t *testing.T) { + // get envelope + env, err := getSignedEnvelope(signature.SigningSchemeX509, true, extSignedAttr) + checkError(t, err) + + // tamper payload + env.Payload = env.Payload[1:] + + // marshal tampered envelope + newEncoded, err := json.Marshal(env) + checkError(t, err) + + // parse tampered envelope + newEnv, err := ParseEnvelope(newEncoded) + checkError(t, err) + + _, err = newEnv.Payload() + cmpError(t, err.Error(), "illegal base64 data at input byte 476") + + }) +} + +func TestEmptyEnvelope(t *testing.T) { + wantErr := &signature.SignatureNotFoundError{} + env := envelope{} + + t.Run("Verify()_with_empty_envelope", func(t *testing.T) { + _, _, err := env.Verify() + if !errors.Is(err, wantErr) { + t.Fatalf("want: %v, got: %v", wantErr, err) + } + }) + + t.Run("Payload()_with_empty_envelope", func(t *testing.T) { + _, err := env.Payload() + if !errors.Is(err, wantErr) { + t.Fatalf("want: %v, got: %v", wantErr, err) + } + }) + + t.Run("SignerInfo()_with_empty_envelope", func(t *testing.T) { + _, err := env.SignerInfo() + if !errors.Is(err, wantErr) { + t.Fatalf("want: %v, got: %v", wantErr, err) + } + }) } diff --git a/signature/jws/jws.go b/signature/jws/jws.go index 1b2d9ad2..0e53238e 100644 --- a/signature/jws/jws.go +++ b/signature/jws/jws.go @@ -177,7 +177,7 @@ func generateJWS(compact string, req *signature.SignRequest, certs []*x509.Certi } // getSignerAttrs merge extended signed attributes and protected header to be signed attributes -func getSignedAttrs(req *signature.SignRequest) (map[string]interface{}, error) { +func getSignedAttrs(req *signature.SignRequest, algorithm string) (map[string]interface{}, error) { extAttrs := make(map[string]interface{}) crit := []string{headerKeySigningScheme} @@ -189,14 +189,8 @@ func getSignedAttrs(req *signature.SignRequest) (map[string]interface{}, error) } } - // extract JWT algorithm name from signer - jwtAlgorithm, err := extractJwtAlgorithm(req.Signer) - if err != nil { - return nil, err - } - jwsProtectedHeader := jwsProtectedHeader{ - Algorithm: jwtAlgorithm, + Algorithm: algorithm, ContentType: req.Payload.ContentType, SigningScheme: req.SigningScheme, } @@ -246,3 +240,14 @@ func mergeMaps(maps ...map[string]interface{}) (map[string]interface{}, error) { } return result, nil } + +// compactJWS converts Flattened JWS JSON Serialization Syntax (section-7.2.2) to +// JWS Compact Serialization (section-7.1) +// +// [RFC 7515]: https://www.rfc-editor.org/rfc/rfc7515.html +func compactJWS(envelope *jwsEnvelope) string { + return strings.Join([]string{ + envelope.Protected, + envelope.Payload, + envelope.Signature}, ".") +} diff --git a/signature/jws/jwt.go b/signature/jws/jwt.go index 56e04cda..ffcef42c 100644 --- a/signature/jws/jwt.go +++ b/signature/jws/jwt.go @@ -1,6 +1,7 @@ package jws import ( + "crypto" "crypto/x509" "encoding/base64" "fmt" @@ -9,13 +10,36 @@ import ( "github.com/notaryproject/notation-core-go/signature" ) -// remoteSigningMethod wraps the remote signer to be a jwt.SigningMethod +// SigningMethod is the interface for jwt.SigingMethod with additional method to +// access certificate chain after calling Sign() +type SigningMethod interface { + jwt.SigningMethod + + // CertificateChain returns the certificate chain. + // + // should be called after calling Sign() + CertificateChain() ([]*x509.Certificate, error) + + // PrivateKey returns the private key. + PrivateKey() crypto.PrivateKey +} + +// remoteSigningMethod wraps the remote signer to be a SigningMethod type remoteSigningMethod struct { - signer signature.Signer + signer signature.Signer + certs []*x509.Certificate + algorithm string } -func newRemoteSigningMethod(signer signature.Signer) jwt.SigningMethod { - return &remoteSigningMethod{signer: signer} +func newRemoteSigningMethod(signer signature.Signer) (SigningMethod, error) { + algorithm, err := extractJwtAlgorithm(signer) + if err != nil { + return nil, err + } + return &remoteSigningMethod{ + signer: signer, + algorithm: algorithm, + }, nil } // Verify doesn't need to be implemented. @@ -25,51 +49,66 @@ func (s *remoteSigningMethod) Verify(signingString, signature string, key interf // Sign hashes the signingString and call the remote signer to sign the digest. func (s *remoteSigningMethod) Sign(signingString string, key interface{}) (string, error) { - keySpec, err := s.signer.KeySpec() - if err != nil { - return "", err - } - - // get hasher - hasher := keySpec.SignatureAlgorithm().Hash() - if !hasher.Available() { - return "", &signature.SignatureAlgoNotSupportedError{Alg: hasher.String()} - } - - // calculate hash - h := hasher.New() - h.Write([]byte(signingString)) - hash := h.Sum(nil) - // sign by external signer - sig, err := s.signer.Sign(hash) + sig, certs, err := s.signer.Sign([]byte(signingString)) if err != nil { return "", err } + s.certs = certs return base64.RawURLEncoding.EncodeToString(sig), nil } -// Alg doesn't need to be implemented. +// Alg return the signing algorithm func (s *remoteSigningMethod) Alg() string { - alg, err := extractJwtAlgorithm(s.signer) - if err != nil { - panic(err) - } - return alg + return s.algorithm } -// verifyJWT verifies the JWT token against the specified verification key -func verifyJWT(tokenString string, cert *x509.Certificate) error { - keySpec, err := signature.ExtractKeySpec(cert) - if err != nil { - return err +// CertificateChain returns the certificate chain +// +// should be called after Sign() +func (s *remoteSigningMethod) CertificateChain() ([]*x509.Certificate, error) { + if s.certs == nil { + return nil, &signature.RemoteSigningError{Msg: "certificate chain is not set"} } - jwsAlg, err := convertAlgorithm(keySpec.SignatureAlgorithm()) + return s.certs, nil +} + +// PrivateKey returns nil for remote signer +func (s *remoteSigningMethod) PrivateKey() crypto.PrivateKey { + return nil +} + +// localSigningMethod wraps the local signer to be a SigningMethod +type localSigningMethod struct { + jwt.SigningMethod + signer signature.LocalSigner + certs []*x509.Certificate +} + +func newLocalSigningMethod(signer signature.LocalSigner) (SigningMethod, error) { + alg, err := extractJwtAlgorithm(signer) if err != nil { - return err + return nil, err } - signingMethod := jwt.GetSigningMethod(jwsAlg) + return &localSigningMethod{ + SigningMethod: jwt.GetSigningMethod(alg), + signer: signer, + }, nil +} + +// CertificateChain returns the certificate chain +func (s *localSigningMethod) CertificateChain() ([]*x509.Certificate, error) { + return s.signer.CertificateChain() +} + +// PrivateKey returns the private key +func (s *localSigningMethod) PrivateKey() crypto.PrivateKey { + return s.signer.PrivateKey() +} + +// verifyJWT verifies the JWT token against the specified verification key +func verifyJWT(tokenString string, publicKey interface{}) error { parser := jwt.NewParser( jwt.WithValidMethods(validMethods), jwt.WithJSONNumber(), @@ -77,14 +116,7 @@ func verifyJWT(tokenString string, cert *x509.Certificate) error { ) if _, err := parser.ParseWithClaims(tokenString, &jwtPayload{}, func(t *jwt.Token) (interface{}, error) { - if t.Method.Alg() != signingMethod.Alg() { - return nil, &signature.MalformedSignatureError{ - Msg: fmt.Sprintf("unexpected signing method: %v: require %v", t.Method.Alg(), signingMethod.Alg())} - } - - // override default signing method with key-specific method - t.Method = signingMethod - return cert.PublicKey, nil + return publicKey, nil }); err != nil { return &signature.SignatureIntegrityError{Err: err} } @@ -92,16 +124,15 @@ func verifyJWT(tokenString string, cert *x509.Certificate) error { } func extractJwtAlgorithm(signer signature.Signer) (string, error) { + // extract algorithm from signer keySpec, err := signer.KeySpec() if err != nil { return "", err } - return convertAlgorithm(keySpec.SignatureAlgorithm()) -} + alg := keySpec.SignatureAlgorithm() -// convertAlgorithm converts the signature.Algorithm to be jwt package defined -// algorithm name. -func convertAlgorithm(alg signature.Algorithm) (string, error) { + // converts the signature.Algorithm to be jwt package defined + // algorithm name. jwsAlg, ok := signatureAlgJWSAlgMap[alg] if !ok { return "", &signature.SignatureAlgoNotSupportedError{ diff --git a/signature/jws/jwt_test.go b/signature/jws/jwt_test.go new file mode 100644 index 00000000..1a10fbad --- /dev/null +++ b/signature/jws/jwt_test.go @@ -0,0 +1,126 @@ +package jws + +import ( + "crypto" + "crypto/x509" + "errors" + "testing" + + "github.com/notaryproject/notation-core-go/signature" + "github.com/notaryproject/notation-core-go/testhelper" +) + +type errorLocalSigner struct { + algType signature.KeyType + size int + keySpecError error +} + +// Sign returns error +func (s *errorLocalSigner) Sign(payload []byte) ([]byte, []*x509.Certificate, error) { + return nil, nil, errors.New("sign error") +} + +// KeySpec returns the key specification. +func (s *errorLocalSigner) KeySpec() (signature.KeySpec, error) { + return signature.KeySpec{ + Type: s.algType, + Size: s.size, + }, s.keySpecError +} + +// PrivateKey returns nil. +func (s *errorLocalSigner) PrivateKey() crypto.PrivateKey { + return nil +} + +// CertificateChain returns nil. +func (s *errorLocalSigner) CertificateChain() ([]*x509.Certificate, error) { + return nil, nil +} + +func Test_remoteSigningMethod_Verify(t *testing.T) { + defer func() { + if d := recover(); d == nil { + t.Fatal("should panic") + } + }() + s := &remoteSigningMethod{} // Sign signs the payload and returns the raw signature and certificates. + s.Verify("", "", nil) +} + +func Test_extractJwtAlgorithm(t *testing.T) { + _, err := extractJwtAlgorithm(&errorLocalSigner{}) + cmpError(t, err.Error(), `signature algorithm "#0" is not supported`) + + _, err = extractJwtAlgorithm(&errorLocalSigner{ + keySpecError: errors.New("get key spec error"), + }) + cmpError(t, err.Error(), `get key spec error`) +} + +func Test_verifyJWT(t *testing.T) { + type args struct { + tokenString string + publicKey interface{} + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "invalid signature", + args: args{ + tokenString: "eyJhbGciOiJQUzM4NCIsImNyaXQiOlsiaW8uY25jZi5ub3Rhcnkuc2lnbmluZ1NjaGVtZSIsInRlc3RLZXkiLCJpby5jbmNmLm5vdGFyeS5leHBpcnkiXSwiY3R5IjoiYXBwbGljYXRpb24vdm5kLmNuY2Yubm90YXJ5LnBheWxvYWQudjEranNvbiIsImlvLmNuY2Yubm90YXJ5LmV4cGlyeSI6IjIwMjItMDgtMjRUMTc6MTg6MTUuNDkxNzQ1ODQ1KzA4OjAwIiwiaW8uY25jZi5ub3Rhcnkuc2lnbmluZ1NjaGVtZSI6Im5vdGFyeS54NTA5IiwiaW8uY25jZi5ub3Rhcnkuc2lnbmluZ1RpbWUiOiIyMDIyLTA4LTI0VDE2OjE4OjE1LjQ5MTc0NTgwNCswODowMCIsInRlc3RLZXkiOiJ0ZXN0VmFsdWUiLCJ0ZXN0S2V5MiI6InRlc3RWYWx1ZTIifQ.ImV3b2dJQ0p6ZFdKcVpXTjBJam9nZXdvZ0lDQWdJbTFsWkdsaFZIbHdaU0k2SUNKaGNIQnNhV05oZEdsdmJpOTJibVF1YjJOcExtbHRZV2RsTG0xaGJtbG1aWE4wTG5ZeEsycHpiMjRpTEFvZ0lDQWdJbVJwWjJWemRDSTZJQ0p6YUdFeU5UWTZOek5qT0RBek9UTXdaV0V6WW1FeFpUVTBZbU15TldNeVltUmpOVE5sWkdRd01qZzBZell5WldRMk5URm1aVGRpTURBek5qbGtZVFV4T1dFell6TXpNeUlzQ2lBZ0lDQWljMmw2WlNJNklERTJOekkwTEFvZ0lDQWdJbUZ1Ym05MFlYUnBiMjV6SWpvZ2V3b2dJQ0FnSUNBZ0lDSnBieTUzWVdKaWFYUXRibVYwZDI5eWEzTXVZblZwYkdSSlpDSTZJQ0l4TWpNaUNpQWdJQ0I5Q2lBZ2ZRcDlDZ2s9Ig.YmF1_5dMW4YWK2fzct1dp25lTy8p0qdSmR-O2fZsf29ohiLYGUVXfvRjEgERzZvDd49aOYQvrEgGvoU9FfK2KIqHrJ8kliI00wd4kuK57aE83pszBMOOrZqAjqkdyoj7dswmwJSyjMC9fhwh_AwrrOnrBjw4U0vGTrImMQEwHfVq0MWLCuw9YpFkytLPeCl8n825EtqMzwYYTUzdQfQJO_ZZrS34n8tK0IRZrX2LjrYz9HqR_UFgVqf_G9qwJpekYyd9Aacl9y4x7zzI-R-bADFgztyAYeWRmE75qI26OgG-ss4wfG-ZbchEm6FYU8py64bsLmJtK9muPd9ZU7SXQOEVzxtXoQFnUhT9AgaNNoxnSnU25mMjAeuGDj0Xn_Gv7f24PyDk9ZEE3WjrguJyzaP6P4jYugXr6Afq10HXRpI_cE8B-6USGpiRH9iJLE04xumWpjWup9p5fv3Fnt3Au1dhbgaDvrSGMHmmCSW4dk7_87Q4LGkGcbn0zNINydcg", + publicKey: testhelper.GetRSALeafCertificate().Cert.PublicKey, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := verifyJWT(tt.args.tokenString, tt.args.publicKey); (err != nil) != tt.wantErr { + t.Errorf("verifyJWT() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_newLocalSigningMethod(t *testing.T) { + signer := errorLocalSigner{} + _, err := newLocalSigningMethod(&signer) + cmpError(t, err.Error(), `signature algorithm "#0" is not supported`) +} + +func Test_newRemoteSigningMethod(t *testing.T) { + _, err := newRemoteSigningMethod(&errorLocalSigner{}) + cmpError(t, err.Error(), `signature algorithm "#0" is not supported`) +} + +func Test_remoteSigningMethod_CertificateChain(t *testing.T) { + certs := []*x509.Certificate{ + testhelper.GetRSALeafCertificate().Cert, + } + signer, err := getSigner(false, certs, testhelper.GetRSALeafCertificate().PrivateKey) + signingScheme, err := newRemoteSigningMethod(signer) + if err != nil { + t.Fatal(err) + } + _, err = signingScheme.CertificateChain() + cmpError(t, err.Error(), "remote signing error. Error: certificate chain is not set") +} + +func Test_remoteSigningMethod_Sign(t *testing.T) { + signer := errorLocalSigner{ + algType: signature.KeyTypeRSA, + size: 2048, + keySpecError: nil, + } + signingScheme, err := newRemoteSigningMethod(&signer) + if err != nil { + t.Fatal(err) + } + _, err = signingScheme.Sign("", nil) + cmpError(t, err.Error(), "sign error") +}