Skip to content

Commit

Permalink
centralize key validations and check more invalid cases
Browse files Browse the repository at this point in the history
Signed-off-by: qmuntal <qmuntaldiaz@microsoft.com>
  • Loading branch information
qmuntal committed Jul 5, 2023
1 parent 2f778da commit b913048
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 85 deletions.
8 changes: 4 additions & 4 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ var (
ErrNoSignatures = errors.New("no signatures attached")
ErrUnavailableHashFunc = errors.New("hash function is not available")
ErrVerification = errors.New("verification error")
ErrInvalidKey = errors.New("invalid key")
ErrInvalidPubKey = errors.New("invalid public key")
ErrInvalidPrivKey = errors.New("invalid private key")
ErrNotPrivKey = errors.New("not a private key")
ErrSignOpNotSupported = errors.New("sign key_op not supported by key")
ErrVerifyOpNotSupported = errors.New("verify key_op not supported by key")
ErrEC2NoPub = errors.New("cannot create PrivateKey from EC2 key: missing X or Y")
ErrOKPNoPub = errors.New("cannot create PrivateKey from OKP key: missing X")
ErrOpNotSupported = errors.New("key_op not supported by key")
ErrEC2NoPub = errors.New("cannot create PrivateKey from EC2 key: missing x or y")
ErrOKPNoPub = errors.New("cannot create PrivateKey from OKP key: missing x")
)
120 changes: 58 additions & 62 deletions key.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func NewOKPKey(alg Algorithm, x, d []byte) (*Key, error) {
X: x,
D: d,
}
return key, key.Validate()
return key, key.validate(KeyOpInvalid)
}

// NewEC2Key returns a Key created using the provided elliptic curve key
Expand All @@ -303,17 +303,17 @@ func NewEC2Key(alg Algorithm, x, y, d []byte) (*Key, error) {
Y: y,
D: d,
}
return key, key.Validate()
return key, key.validate(KeyOpInvalid)
}

// NewSymmetricKey returns a Key created using the provided Symmetric key
// bytes.
func NewSymmetricKey(k []byte) (*Key, error) {
func NewSymmetricKey(k []byte) *Key {
key := &Key{
KeyType: KeyTypeSymmetric,
K: k,
}
return key, key.Validate()
return key
}

// NewKeyFromPublic returns a Key created using the provided crypto.PublicKey.
Expand Down Expand Up @@ -355,10 +355,24 @@ func NewKeyFromPrivate(priv crypto.PrivateKey) (*Key, error) {
}

// Validate ensures that the parameters set inside the Key are internally
// consistent (e.g., that the key type is appropriate to the curve.)
func (k Key) Validate() error {
// consistent (e.g., that the key type is appropriate to the curve).
// It also checks that the key is valid for the requested operation.
func (k Key) validate(op KeyOp) error {
switch k.KeyType {
case KeyTypeEC2:
switch op {
case KeyOpVerify:
if len(k.X) == 0 || len(k.Y) == 0 {
return ErrEC2NoPub
}
case KeyOpSign:
if len(k.D) == 0 {
return ErrNotPrivKey
}
}
if k.Curve == CurveInvalid || (len(k.X) == 0 && len(k.Y) == 0 && len(k.D) == 0) {
return ErrInvalidKey
}
switch k.Curve {
case CurveX25519, CurveX448, CurveEd25519, CurveEd448:
return fmt.Errorf(
Expand All @@ -370,6 +384,19 @@ func (k Key) Validate() error {
// see https://www.rfc-editor.org/rfc/rfc8152#section-13.1.1
}
case KeyTypeOKP:
switch op {
case KeyOpVerify:
if len(k.X) == 0 {
return ErrOKPNoPub
}
case KeyOpSign:
if len(k.D) == 0 {
return ErrNotPrivKey
}
}
if k.Curve == CurveInvalid || (len(k.X) == 0 && len(k.D) == 0) {
return ErrInvalidKey
}
switch k.Curve {
case CurveP256, CurveP384, CurveP521:
return fmt.Errorf(
Expand All @@ -381,8 +408,22 @@ func (k Key) Validate() error {
// see https://www.rfc-editor.org/rfc/rfc8152#section-13.2
}
case KeyTypeSymmetric:
// Nothing to validate
default:
return errors.New(k.KeyType.String())
// Unknown key type, we can't validate custom parameters.
}

if op != KeyOpInvalid && k.KeyOps != nil {
found := false
for _, kop := range k.KeyOps {
if kop == op {
found = true
break
}
}
if !found {
return ErrOpNotSupported
}
}

// If Algorithm is set, it must match the specified key parameters.
Expand Down Expand Up @@ -483,11 +524,14 @@ func (k *Key) UnmarshalCBOR(data []byte) error {
return fmt.Errorf("unexpected key type %q", k.KeyType.String())
}

return k.Validate()
return k.validate(KeyOpInvalid)
}

// PublicKey returns a crypto.PublicKey generated using Key's parameters.
func (k *Key) PublicKey() (crypto.PublicKey, error) {
if err := k.validate(KeyOpVerify); err != nil {
return nil, err
}
alg, err := k.deriveAlgorithm()
if err != nil {
return nil, err
Expand Down Expand Up @@ -520,6 +564,9 @@ func (k *Key) PublicKey() (crypto.PublicKey, error) {

// PrivateKey returns a crypto.PrivateKey generated using Key's parameters.
func (k *Key) PrivateKey() (crypto.PrivateKey, error) {
if err := k.validate(KeyOpSign); err != nil {
return nil, err
}
alg, err := k.deriveAlgorithm()
if err != nil {
return nil, err
Expand Down Expand Up @@ -591,25 +638,6 @@ func (k *Key) AlgorithmOrDefault() (Algorithm, error) {

// Signer returns a Signer created using Key.
func (k *Key) Signer() (Signer, error) {
if err := k.Validate(); err != nil {
return nil, err
}

if k.KeyOps != nil {
signFound := false

for _, kop := range k.KeyOps {
if kop == KeyOpSign {
signFound = true
break
}
}

if !signFound {
return nil, ErrSignOpNotSupported
}
}

priv, err := k.PrivateKey()
if err != nil {
return nil, err
Expand All @@ -620,48 +648,16 @@ func (k *Key) Signer() (Signer, error) {
return nil, err
}

var signer crypto.Signer
var ok bool

switch alg {
case AlgorithmES256, AlgorithmES384, AlgorithmES512:
signer, ok = priv.(*ecdsa.PrivateKey)
if !ok {
return nil, ErrInvalidPrivKey
}
case AlgorithmEd25519:
signer, ok = priv.(ed25519.PrivateKey)
if !ok {
return nil, ErrInvalidPrivKey
}
default:
return nil, ErrAlgorithmNotSupported
signer, ok := priv.(crypto.Signer)
if !ok {
return nil, ErrInvalidPrivKey
}

return NewSigner(alg, signer)
}

// Verifier returns a Verifier created using Key.
func (k *Key) Verifier() (Verifier, error) {
if err := k.Validate(); err != nil {
return nil, err
}

if k.KeyOps != nil {
verifyFound := false

for _, kop := range k.KeyOps {
if kop == KeyOpVerify {
verifyFound = true
break
}
}

if !verifyFound {
return nil, ErrVerifyOpNotSupported
}
}

pub, err := k.PublicKey()
if err != nil {
return nil, err
Expand Down
79 changes: 60 additions & 19 deletions key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,34 @@ func Test_Key_UnmarshalCBOR(t *testing.T) {
{
Name: "invalid curve OKP",
Value: []byte{
0xa2, // map (2)
0xa3, // map (3)
0x01, 0x01, // kty: OKP
0x20, 0x01, // curve: CurveP256
0x21, 0x58, 0x20, // x-coordinate: bytes(32)
0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value
0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9,
0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9,
0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46,
},
WantErr: `Key type mismatch for curve "P-256" (must be EC2, found OKP)`,
Validate: nil,
},
{
Name: "invalid curve EC2",
Value: []byte{
0xa2, // map (2)
0xa4, // map (4)
0x01, 0x02, // kty: EC2
0x20, 0x06, // curve: CurveEd25519
0x21, 0x58, 0x20, // x-coordinate: bytes(32)
0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value
0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9,
0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9,
0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46,
0x22, 0x58, 0x20, // y-coordinate: bytes(32)
0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value
0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9,
0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9,
0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46,
},
WantErr: `Key type mismatch for curve "Ed25519" (must be OKP, found EC2)`,
Validate: nil,
Expand Down Expand Up @@ -367,14 +382,9 @@ func Test_Key_Create_and_Validate(t *testing.T) {
assertEqual(t, x, key.X)
assertEqual(t, y, key.Y)

key, err = NewSymmetricKey(x)
requireNoError(t, err)
key = NewSymmetricKey(x)
assertEqual(t, x, key.K)

key.KeyType = KeyType(7)
err = key.Validate()
assertEqualError(t, err, "unknown key type value 7")

_, err = NewKeyFromPublic(crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef}))
assertEqualError(t, err, "invalid public key")

Expand Down Expand Up @@ -536,7 +546,7 @@ func Test_Key_signer_validation(t *testing.T) {
key.Curve = CurveEd25519
key.KeyOps = []KeyOp{}
_, err = key.Signer()
assertEqualError(t, err, ErrSignOpNotSupported.Error())
assertEqualError(t, err, ErrOpNotSupported.Error())

key.KeyOps = []KeyOp{KeyOpSign}
_, err = key.Signer()
Expand All @@ -551,7 +561,7 @@ func Test_Key_signer_validation(t *testing.T) {
assertEqualError(t, err, `unsupported curve "X448" for key type OKP`)
}

func Test_Key_verifier_validation(t *testing.T) {
func TestKey_Verifier(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
requireNoError(t, err)

Expand All @@ -563,42 +573,49 @@ func Test_Key_verifier_validation(t *testing.T) {

key.KeyType = KeyTypeEC2
_, err = key.Verifier()
assertEqualError(t, err, `Key type mismatch for curve "Ed25519" (must be OKP, found EC2)`)
assertEqualError(t, err, ErrEC2NoPub.Error())

key.KeyType = KeyTypeOKP
key.KeyOps = []KeyOp{}
_, err = key.Verifier()
assertEqualError(t, err, ErrVerifyOpNotSupported.Error())
assertEqualError(t, err, ErrOpNotSupported.Error())

key.KeyOps = []KeyOp{KeyOpVerify}
_, err = key.Verifier()
requireNoError(t, err)
}

func Test_Key_crypto_keys(t *testing.T) {
func TestKey_PrivateKey(t *testing.T) {
k := Key{
KeyType: KeyType(7),
}

_, err := k.PublicKey()
assertEqualError(t, err, `unexpected key type "unknown key type value 7"`)
_, err = k.PrivateKey()
_, err := k.PrivateKey()
assertEqualError(t, err, `unexpected key type "unknown key type value 7"`)

k = Key{
KeyType: KeyTypeOKP,
Curve: CurveX448,
X: make([]byte, 1),
D: make([]byte, 1),
}

_, err = k.PublicKey()
assertEqualError(t, err, `unsupported curve "X448" for key type OKP`)
_, err = k.PrivateKey()
assertEqualError(t, err, `unsupported curve "X448" for key type OKP`)

k = Key{
KeyType: KeyTypeOKP,
Curve: CurveEd25519,
D: []byte{0xde, 0xad, 0xbe, 0xef},
X: make([]byte, 1),
}

_, err = k.PrivateKey()
assertEqualError(t, err, ErrNotPrivKey.Error())

k = Key{
KeyType: KeyTypeOKP,
Curve: CurveEd25519,
D: make([]byte, 1),
}

_, err = k.PrivateKey()
Expand All @@ -611,6 +628,30 @@ func Test_Key_crypto_keys(t *testing.T) {
assertEqualError(t, err, ErrEC2NoPub.Error())
}

func TestKey_PublicKey(t *testing.T) {
k := Key{
KeyType: KeyType(7),
}

_, err := k.PublicKey()
assertEqualError(t, err, `unexpected key type "unknown key type value 7"`)

k = Key{
KeyType: KeyTypeOKP,
Curve: CurveEd25519,
D: []byte{0xde, 0xad, 0xbe, 0xef},
}

_, err = k.PublicKey()
assertEqualError(t, err, ErrOKPNoPub.Error())

k.KeyType = KeyTypeEC2
k.Curve = CurveP256

_, err = k.PublicKey()
assertEqualError(t, err, ErrEC2NoPub.Error())
}

func Test_String(t *testing.T) {
// test string conversions not exercised by other test cases
assertEqual(t, "OKP", KeyTypeOKP.String())
Expand Down

0 comments on commit b913048

Please sign in to comment.