Skip to content

Commit

Permalink
noise: make it possible for the server to send early data
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Sep 13, 2022
1 parent ddfb6f9 commit 65270f1
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 63 deletions.
61 changes: 39 additions & 22 deletions p2p/security/noise/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
return fmt.Errorf("error initializing handshake state: %w", err)
}

payload, err := s.generateHandshakePayload(kp)
if err != nil {
return err
}

// set a deadline to complete the handshake, if one has been supplied.
// clear it after we're done.
if deadline, ok := ctx.Deadline(); ok {
Expand All @@ -82,7 +77,7 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
// will be the size of the maximum handshake message for the Noise XX pattern.
// Also, since we prefix every noise handshake message with its length, we need to account for
// it when we fetch the buffer from the pool
maxMsgSize := 2*noise.DH25519.DHLen() + len(payload) + 2*chacha20poly1305.Overhead
maxMsgSize := 2*noise.DH25519.DHLen() + 2*chacha20poly1305.Overhead + 1024 /* payload */
hbuf := pool.Get(maxMsgSize + LengthPrefixLength)
defer pool.Put(hbuf)

Expand All @@ -102,12 +97,22 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
if err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()); err != nil {
rcvdEd, err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
return err
}
if s.earlyDataHandler != nil {
if err := s.earlyDataHandler.Received(ctx, s.insecureConn, rcvdEd); err != nil {
return err
}
}

// stage 2 //
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
payload, err := s.generateHandshakePayload(kp, nil)
if err != nil {
return err
}
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
Expand All @@ -127,6 +132,14 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
// stage 1 //
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
// MAC(payload is encrypted)
var ed []byte
if s.earlyDataHandler != nil {
ed = s.earlyDataHandler.Send(ctx, s.insecureConn, s.remoteID)
}
payload, err := s.generateHandshakePayload(kp, ed)
if err != nil {
return err
}
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
Expand All @@ -136,7 +149,9 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
return s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
// we don't expect any early data on this message
_, err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
return err
}
}

Expand Down Expand Up @@ -214,8 +229,8 @@ func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte,

// generateHandshakePayload creates a libp2p handshake payload with a
// signature of our static noise key.
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byte, error) {
// obtain the public key from the handshake session so we can sign it with
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey, data []byte) ([]byte, error) {
// obtain the public key from the handshake session, so we can sign it with
// our libp2p secret key.
localKeyRaw, err := crypto.MarshalPublicKey(s.LocalPublicKey())
if err != nil {
Expand All @@ -230,10 +245,11 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byt
}

// create payload
payload := new(pb.NoiseHandshakePayload)
payload.IdentityKey = localKeyRaw
payload.IdentitySig = signedPayload
payloadEnc, err := proto.Marshal(payload)
payloadEnc, err := proto.Marshal(&pb.NoiseHandshakePayload{
IdentityKey: localKeyRaw,
IdentitySig: signedPayload,
Data: data,
})
if err != nil {
return nil, fmt.Errorf("error marshaling handshake payload: %w", err)
}
Expand All @@ -242,44 +258,45 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byt

// handleRemoteHandshakePayload unmarshals the handshake payload object sent
// by the remote peer and validates the signature against the peer's static Noise key.
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) error {
// It returns the data attached to the payload.
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) ([]byte, error) {
// unmarshal payload
nhp := new(pb.NoiseHandshakePayload)
err := proto.Unmarshal(payload, nhp)
if err != nil {
return fmt.Errorf("error unmarshaling remote handshake payload: %w", err)
return nil, fmt.Errorf("error unmarshaling remote handshake payload: %w", err)
}

// unpack remote peer's public libp2p key
remotePubKey, err := crypto.UnmarshalPublicKey(nhp.GetIdentityKey())
if err != nil {
return err
return nil, err
}
id, err := peer.IDFromPublicKey(remotePubKey)
if err != nil {
return err
return nil, err
}

// check the peer ID for:
// * all outbound connection
// * inbound connections, if we know which peer we want to connect to (SecureInbound called with a peer ID)
if (s.initiator && s.remoteID != id) || (!s.initiator && s.remoteID != "" && s.remoteID != id) {
// use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms.
return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
return nil, fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
}

// verify payload is signed by asserted remote libp2p key.
sig := nhp.GetIdentitySig()
msg := append([]byte(payloadSigPrefix), remoteStatic...)
ok, err := remotePubKey.Verify(msg, sig)
if err != nil {
return fmt.Errorf("error verifying signature: %w", err)
return nil, fmt.Errorf("error verifying signature: %w", err)
} else if !ok {
return fmt.Errorf("handshake signature invalid")
return nil, fmt.Errorf("handshake signature invalid")
}

// set remote peer key and id
s.remoteID = id
s.remoteKey = remotePubKey
return nil
return nhp.Data, nil
}
113 changes: 72 additions & 41 deletions p2p/security/noise/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,67 +447,98 @@ func (e *earlyDataHandler) Received(ctx context.Context, conn net.Conn, data []b
}

func TestEarlyDataAccepted(t *testing.T) {
handshake := func(t *testing.T, client, server EarlyDataHandler) {
t.Helper()
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client))
require.NoError(t, err)
tpt := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, err := tpt.WithSessionOptions(EarlyData(server))
require.NoError(t, err)

initConn, respConn := newConnPair(t)

errChan := make(chan error)
go func() {
_, err := respTransport.SecureInbound(context.Background(), initConn, "")
errChan <- err
}()

conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID)
require.NoError(t, err)
defer conn.Close()
}

var receivedEarlyData []byte
serverEDH := &earlyDataHandler{
receivingEDH := &earlyDataHandler{
received: func(_ context.Context, _ net.Conn, data []byte) error {
receivedEarlyData = data
return nil
},
}
clientEDH := &earlyDataHandler{
sendingEDH := &earlyDataHandler{
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") },
}
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH))
require.NoError(t, err)
tpt := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH))
require.NoError(t, err)

initConn, respConn := newConnPair(t)

errChan := make(chan error)
go func() {
_, err := respTransport.SecureInbound(context.Background(), initConn, "")
errChan <- err
}()

conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID)
require.NoError(t, err)
defer conn.Close()
t.Run("client sending", func(t *testing.T) {
handshake(t, sendingEDH, receivingEDH)
require.Equal(t, []byte("foobar"), receivedEarlyData)
receivedEarlyData = nil
})

require.Equal(t, []byte("foobar"), receivedEarlyData)
t.Run("server sending", func(t *testing.T) {
handshake(t, receivingEDH, sendingEDH)
require.Equal(t, []byte("foobar"), receivedEarlyData)
receivedEarlyData = nil
})
}

func TestEarlyDataRejected(t *testing.T) {
serverEDH := &earlyDataHandler{
handshake := func(t *testing.T, client, server EarlyDataHandler) (clientErr, serverErr error) {
t.Helper()
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client))
require.NoError(t, err)
tpt := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, err := tpt.WithSessionOptions(EarlyData(server))
require.NoError(t, err)

initConn, respConn := newConnPair(t)

errChan := make(chan error)
go func() {
_, err := respTransport.SecureInbound(context.Background(), initConn, "")
errChan <- err
}()

_, clientErr = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID)

select {
case <-time.After(500 * time.Millisecond):
t.Fatal("timeout")
case err := <-errChan:
serverErr = err
}
return
}

receivingEDH := &earlyDataHandler{
received: func(_ context.Context, _ net.Conn, data []byte) error { return errors.New("nope") },
}
clientEDH := &earlyDataHandler{
sendingEDH := &earlyDataHandler{
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") },
}
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH))
require.NoError(t, err)
tpt := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH))
require.NoError(t, err)

initConn, respConn := newConnPair(t)

errChan := make(chan error)
go func() {
_, err := respTransport.SecureInbound(context.Background(), initConn, "")
errChan <- err
}()
t.Run("client sending", func(t *testing.T) {
clientErr, serverErr := handshake(t, sendingEDH, receivingEDH)
require.Error(t, clientErr)
require.EqualError(t, serverErr, "nope")

_, err = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID)
require.Error(t, err)
})

select {
case <-time.After(500 * time.Millisecond):
t.Fatal("timeout")
case err := <-errChan:
require.EqualError(t, err, "nope")
}
t.Run("server sending", func(t *testing.T) {
clientErr, serverErr := handshake(t, receivingEDH, sendingEDH)
require.Error(t, serverErr)
require.EqualError(t, clientErr, "nope")
})
}

func TestEarlyDataAcceptedWithNoHandler(t *testing.T) {
Expand Down

0 comments on commit 65270f1

Please sign in to comment.