diff --git a/errors.go b/errors.go index de601e5..770dc9e 100644 --- a/errors.go +++ b/errors.go @@ -13,11 +13,11 @@ var ( ErrNoSignatures = errors.New("no signatures attached") ErrUnavailableHashFunc = errors.New("hash function is not available") ErrVerification = errors.New("verification error") + ErrInvalidKey = errors.New("invalid key") ErrInvalidPubKey = errors.New("invalid public key") ErrInvalidPrivKey = errors.New("invalid private key") ErrNotPrivKey = errors.New("not a private key") - ErrSignOpNotSupported = errors.New("sign key_op not supported by key") - ErrVerifyOpNotSupported = errors.New("verify key_op not supported by key") - ErrEC2NoPub = errors.New("cannot create PrivateKey from EC2 key: missing X or Y") - ErrOKPNoPub = errors.New("cannot create PrivateKey from OKP key: missing X") + ErrOpNotSupported = errors.New("key_op not supported by key") + ErrEC2NoPub = errors.New("cannot create PrivateKey from EC2 key: missing x or y") + ErrOKPNoPub = errors.New("cannot create PrivateKey from OKP key: missing x") ) diff --git a/key.go b/key.go index fd1a074..37afc24 100644 --- a/key.go +++ b/key.go @@ -276,7 +276,7 @@ func NewOKPKey(alg Algorithm, x, d []byte) (*Key, error) { X: x, D: d, } - return key, key.Validate() + return key, key.validate(KeyOpInvalid) } // NewEC2Key returns a Key created using the provided elliptic curve key @@ -303,17 +303,17 @@ func NewEC2Key(alg Algorithm, x, y, d []byte) (*Key, error) { Y: y, D: d, } - return key, key.Validate() + return key, key.validate(KeyOpInvalid) } // NewSymmetricKey returns a Key created using the provided Symmetric key // bytes. -func NewSymmetricKey(k []byte) (*Key, error) { +func NewSymmetricKey(k []byte) *Key { key := &Key{ KeyType: KeyTypeSymmetric, K: k, } - return key, key.Validate() + return key } // NewKeyFromPublic returns a Key created using the provided crypto.PublicKey. @@ -355,10 +355,24 @@ func NewKeyFromPrivate(priv crypto.PrivateKey) (*Key, error) { } // Validate ensures that the parameters set inside the Key are internally -// consistent (e.g., that the key type is appropriate to the curve.) -func (k Key) Validate() error { +// consistent (e.g., that the key type is appropriate to the curve). +// It also checks that the key is valid for the requested operation. +func (k Key) validate(op KeyOp) error { switch k.KeyType { case KeyTypeEC2: + switch op { + case KeyOpVerify: + if len(k.X) == 0 || len(k.Y) == 0 { + return ErrEC2NoPub + } + case KeyOpSign: + if len(k.D) == 0 { + return ErrNotPrivKey + } + } + if k.Curve == CurveInvalid || (len(k.X) == 0 && len(k.Y) == 0 && len(k.D) == 0) { + return ErrInvalidKey + } switch k.Curve { case CurveX25519, CurveX448, CurveEd25519, CurveEd448: return fmt.Errorf( @@ -370,6 +384,19 @@ func (k Key) Validate() error { // see https://www.rfc-editor.org/rfc/rfc8152#section-13.1.1 } case KeyTypeOKP: + switch op { + case KeyOpVerify: + if len(k.X) == 0 { + return ErrOKPNoPub + } + case KeyOpSign: + if len(k.D) == 0 { + return ErrNotPrivKey + } + } + if k.Curve == CurveInvalid || (len(k.X) == 0 && len(k.D) == 0) { + return ErrInvalidKey + } switch k.Curve { case CurveP256, CurveP384, CurveP521: return fmt.Errorf( @@ -381,8 +408,22 @@ func (k Key) Validate() error { // see https://www.rfc-editor.org/rfc/rfc8152#section-13.2 } case KeyTypeSymmetric: + // Nothing to validate default: - return errors.New(k.KeyType.String()) + // Unknown key type, we can't validate custom parameters. + } + + if op != KeyOpInvalid && k.KeyOps != nil { + found := false + for _, kop := range k.KeyOps { + if kop == op { + found = true + break + } + } + if !found { + return ErrOpNotSupported + } } // If Algorithm is set, it must match the specified key parameters. @@ -483,11 +524,14 @@ func (k *Key) UnmarshalCBOR(data []byte) error { return fmt.Errorf("unexpected key type %q", k.KeyType.String()) } - return k.Validate() + return k.validate(KeyOpInvalid) } // PublicKey returns a crypto.PublicKey generated using Key's parameters. func (k *Key) PublicKey() (crypto.PublicKey, error) { + if err := k.validate(KeyOpVerify); err != nil { + return nil, err + } alg, err := k.deriveAlgorithm() if err != nil { return nil, err @@ -520,6 +564,9 @@ func (k *Key) PublicKey() (crypto.PublicKey, error) { // PrivateKey returns a crypto.PrivateKey generated using Key's parameters. func (k *Key) PrivateKey() (crypto.PrivateKey, error) { + if err := k.validate(KeyOpSign); err != nil { + return nil, err + } alg, err := k.deriveAlgorithm() if err != nil { return nil, err @@ -591,25 +638,6 @@ func (k *Key) AlgorithmOrDefault() (Algorithm, error) { // Signer returns a Signer created using Key. func (k *Key) Signer() (Signer, error) { - if err := k.Validate(); err != nil { - return nil, err - } - - if k.KeyOps != nil { - signFound := false - - for _, kop := range k.KeyOps { - if kop == KeyOpSign { - signFound = true - break - } - } - - if !signFound { - return nil, ErrSignOpNotSupported - } - } - priv, err := k.PrivateKey() if err != nil { return nil, err @@ -620,22 +648,9 @@ func (k *Key) Signer() (Signer, error) { return nil, err } - var signer crypto.Signer - var ok bool - - switch alg { - case AlgorithmES256, AlgorithmES384, AlgorithmES512: - signer, ok = priv.(*ecdsa.PrivateKey) - if !ok { - return nil, ErrInvalidPrivKey - } - case AlgorithmEd25519: - signer, ok = priv.(ed25519.PrivateKey) - if !ok { - return nil, ErrInvalidPrivKey - } - default: - return nil, ErrAlgorithmNotSupported + signer, ok := priv.(crypto.Signer) + if !ok { + return nil, ErrInvalidPrivKey } return NewSigner(alg, signer) @@ -643,25 +658,6 @@ func (k *Key) Signer() (Signer, error) { // Verifier returns a Verifier created using Key. func (k *Key) Verifier() (Verifier, error) { - if err := k.Validate(); err != nil { - return nil, err - } - - if k.KeyOps != nil { - verifyFound := false - - for _, kop := range k.KeyOps { - if kop == KeyOpVerify { - verifyFound = true - break - } - } - - if !verifyFound { - return nil, ErrVerifyOpNotSupported - } - } - pub, err := k.PublicKey() if err != nil { return nil, err diff --git a/key_test.go b/key_test.go index 0c7b236..c2d2979 100644 --- a/key_test.go +++ b/key_test.go @@ -171,9 +171,14 @@ func Test_Key_UnmarshalCBOR(t *testing.T) { { Name: "invalid curve OKP", Value: []byte{ - 0xa2, // map (2) + 0xa3, // map (3) 0x01, 0x01, // kty: OKP 0x20, 0x01, // curve: CurveP256 + 0x21, 0x58, 0x20, // x-coordinate: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, }, WantErr: `Key type mismatch for curve "P-256" (must be EC2, found OKP)`, Validate: nil, @@ -181,9 +186,19 @@ func Test_Key_UnmarshalCBOR(t *testing.T) { { Name: "invalid curve EC2", Value: []byte{ - 0xa2, // map (2) + 0xa4, // map (4) 0x01, 0x02, // kty: EC2 0x20, 0x06, // curve: CurveEd25519 + 0x21, 0x58, 0x20, // x-coordinate: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + 0x22, 0x58, 0x20, // y-coordinate: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, }, WantErr: `Key type mismatch for curve "Ed25519" (must be OKP, found EC2)`, Validate: nil, @@ -367,14 +382,9 @@ func Test_Key_Create_and_Validate(t *testing.T) { assertEqual(t, x, key.X) assertEqual(t, y, key.Y) - key, err = NewSymmetricKey(x) - requireNoError(t, err) + key = NewSymmetricKey(x) assertEqual(t, x, key.K) - key.KeyType = KeyType(7) - err = key.Validate() - assertEqualError(t, err, "unknown key type value 7") - _, err = NewKeyFromPublic(crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) assertEqualError(t, err, "invalid public key") @@ -536,7 +546,7 @@ func Test_Key_signer_validation(t *testing.T) { key.Curve = CurveEd25519 key.KeyOps = []KeyOp{} _, err = key.Signer() - assertEqualError(t, err, ErrSignOpNotSupported.Error()) + assertEqualError(t, err, ErrOpNotSupported.Error()) key.KeyOps = []KeyOp{KeyOpSign} _, err = key.Signer() @@ -551,7 +561,7 @@ func Test_Key_signer_validation(t *testing.T) { assertEqualError(t, err, `unsupported curve "X448" for key type OKP`) } -func Test_Key_verifier_validation(t *testing.T) { +func TestKey_Verifier(t *testing.T) { pub, _, err := ed25519.GenerateKey(rand.Reader) requireNoError(t, err) @@ -563,42 +573,49 @@ func Test_Key_verifier_validation(t *testing.T) { key.KeyType = KeyTypeEC2 _, err = key.Verifier() - assertEqualError(t, err, `Key type mismatch for curve "Ed25519" (must be OKP, found EC2)`) + assertEqualError(t, err, ErrEC2NoPub.Error()) key.KeyType = KeyTypeOKP key.KeyOps = []KeyOp{} _, err = key.Verifier() - assertEqualError(t, err, ErrVerifyOpNotSupported.Error()) + assertEqualError(t, err, ErrOpNotSupported.Error()) key.KeyOps = []KeyOp{KeyOpVerify} _, err = key.Verifier() requireNoError(t, err) } -func Test_Key_crypto_keys(t *testing.T) { +func TestKey_PrivateKey(t *testing.T) { k := Key{ KeyType: KeyType(7), } - _, err := k.PublicKey() - assertEqualError(t, err, `unexpected key type "unknown key type value 7"`) - _, err = k.PrivateKey() + _, err := k.PrivateKey() assertEqualError(t, err, `unexpected key type "unknown key type value 7"`) k = Key{ KeyType: KeyTypeOKP, Curve: CurveX448, + X: make([]byte, 1), + D: make([]byte, 1), } - _, err = k.PublicKey() - assertEqualError(t, err, `unsupported curve "X448" for key type OKP`) _, err = k.PrivateKey() assertEqualError(t, err, `unsupported curve "X448" for key type OKP`) k = Key{ KeyType: KeyTypeOKP, Curve: CurveEd25519, - D: []byte{0xde, 0xad, 0xbe, 0xef}, + X: make([]byte, 1), + } + + _, err = k.PrivateKey() + assertEqualError(t, err, ErrNotPrivKey.Error()) + + k = Key{ + KeyType: KeyTypeOKP, + Curve: CurveEd25519, + D: make([]byte, 1), } _, err = k.PrivateKey() @@ -611,6 +628,30 @@ func Test_Key_crypto_keys(t *testing.T) { assertEqualError(t, err, ErrEC2NoPub.Error()) } +func TestKey_PublicKey(t *testing.T) { + k := Key{ + KeyType: KeyType(7), + } + + _, err := k.PublicKey() + assertEqualError(t, err, `unexpected key type "unknown key type value 7"`) + + k = Key{ + KeyType: KeyTypeOKP, + Curve: CurveEd25519, + D: []byte{0xde, 0xad, 0xbe, 0xef}, + } + + _, err = k.PublicKey() + assertEqualError(t, err, ErrOKPNoPub.Error()) + + k.KeyType = KeyTypeEC2 + k.Curve = CurveP256 + + _, err = k.PublicKey() + assertEqualError(t, err, ErrEC2NoPub.Error()) +} + func Test_String(t *testing.T) { // test string conversions not exercised by other test cases assertEqual(t, "OKP", KeyTypeOKP.String())