diff --git a/p2p/security/noise/crypto_test.go b/p2p/security/noise/crypto_test.go index 9b7d390829..1ca11476f4 100644 --- a/p2p/security/noise/crypto_test.go +++ b/p2p/security/noise/crypto_test.go @@ -93,7 +93,7 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) { init, resp := net.Pipe() _ = resp.Close() - session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, nil, true) + session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, nil, nil, true) _, err := session.encrypt(nil, []byte("hi")) if err == nil { t.Error("expected encryption error when handshake incomplete") diff --git a/p2p/security/noise/handshake.go b/p2p/security/noise/handshake.go index 8cf306b0f5..db782aace6 100644 --- a/p2p/security/noise/handshake.go +++ b/p2p/security/noise/handshake.go @@ -85,8 +85,8 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { // stage 0 // // Handshake Msg Len = len(DH ephemeral key) var ed []byte - if s.earlyDataHandler != nil { - ed = s.earlyDataHandler.Send(ctx, s.insecureConn, s.remoteID) + if s.initiatorEarlyDataHandler != nil { + ed = s.initiatorEarlyDataHandler.Send(ctx, s.insecureConn, s.remoteID) } if err := s.sendHandshakeMessage(hs, ed, hbuf); err != nil { return fmt.Errorf("error sending handshake message: %w", err) @@ -101,8 +101,8 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { if err != nil { return err } - if s.earlyDataHandler != nil { - if err := s.earlyDataHandler.Received(ctx, s.insecureConn, rcvdEd); err != nil { + if s.initiatorEarlyDataHandler != nil { + if err := s.initiatorEarlyDataHandler.Received(ctx, s.insecureConn, rcvdEd); err != nil { return err } } @@ -123,8 +123,8 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { if err != nil { return fmt.Errorf("error reading handshake message: %w", err) } - if s.earlyDataHandler != nil { - if err := s.earlyDataHandler.Received(ctx, s.insecureConn, initialPayload); err != nil { + if s.responderEarlyDataHandler != nil { + if err := s.responderEarlyDataHandler.Received(ctx, s.insecureConn, initialPayload); err != nil { return err } } @@ -133,8 +133,8 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { // Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) + // MAC(payload is encrypted) var ed []byte - if s.earlyDataHandler != nil { - ed = s.earlyDataHandler.Send(ctx, s.insecureConn, s.remoteID) + if s.responderEarlyDataHandler != nil { + ed = s.responderEarlyDataHandler.Send(ctx, s.insecureConn, s.remoteID) } payload, err := s.generateHandshakePayload(kp, ed) if err != nil { diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index 5e3d0956cf..692c9174ad 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -36,22 +36,24 @@ type secureSession struct { dec *noise.CipherState // noise prologue - prologue []byte - earlyDataHandler EarlyDataHandler + prologue []byte + + initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler } // newSecureSession creates a Noise session over the given insecureConn Conn, using // the libp2p identity keypair from the given Transport. -func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, edh EarlyDataHandler, initiator bool) (*secureSession, error) { +func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, initiatorEDH, responderEDH EarlyDataHandler, initiator bool) (*secureSession, error) { s := &secureSession{ - insecureConn: insecure, - insecureReader: bufio.NewReader(insecure), - initiator: initiator, - localID: tpt.localID, - localKey: tpt.privateKey, - remoteID: remote, - prologue: prologue, - earlyDataHandler: edh, + insecureConn: insecure, + insecureReader: bufio.NewReader(insecure), + initiator: initiator, + localID: tpt.localID, + localKey: tpt.privateKey, + remoteID: remote, + prologue: prologue, + initiatorEarlyDataHandler: initiatorEDH, + responderEarlyDataHandler: responderEDH, } // the go-routine we create to run the handshake will diff --git a/p2p/security/noise/session_transport.go b/p2p/security/noise/session_transport.go index 973f6facf2..12313c1513 100644 --- a/p2p/security/noise/session_transport.go +++ b/p2p/security/noise/session_transport.go @@ -31,9 +31,10 @@ type EarlyDataHandler interface { Received(context.Context, net.Conn, []byte) error } -func EarlyData(h EarlyDataHandler) SessionOption { +func EarlyData(initiator, responder EarlyDataHandler) SessionOption { return func(s *SessionTransport) error { - s.earlyDataHandler = h + s.initiatorEarlyDataHandler = initiator + s.responderEarlyDataHandler = responder return nil } } @@ -45,14 +46,15 @@ var _ sec.SecureTransport = &SessionTransport{} type SessionTransport struct { t *Transport // options - prologue []byte - earlyDataHandler EarlyDataHandler + prologue []byte + + initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler } // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, false) + c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) if maErr == nil { @@ -64,5 +66,5 @@ func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, // SecureOutbound runs the Noise handshake as the initiator. func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, true) + return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, true) } diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index bd66d0fdd1..c6923698cc 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -41,7 +41,7 @@ func New(privkey crypto.PrivKey) (*Transport, error) { // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - c, err := newSecureSession(t, ctx, insecure, p, nil, nil, false) + c, err := newSecureSession(t, ctx, insecure, p, nil, nil, nil, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) if maErr == nil { @@ -53,7 +53,7 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer // SecureOutbound runs the Noise handshake as the initiator. func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - return newSecureSession(t, ctx, insecure, p, nil, nil, true) + return newSecureSession(t, ctx, insecure, p, nil, nil, nil, true) } func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTransport, error) { diff --git a/p2p/security/noise/transport_test.go b/p2p/security/noise/transport_test.go index 3a791668b7..6ebf592e74 100644 --- a/p2p/security/noise/transport_test.go +++ b/p2p/security/noise/transport_test.go @@ -449,10 +449,10 @@ func (e *earlyDataHandler) Received(ctx context.Context, conn net.Conn, data []b func TestEarlyDataAccepted(t *testing.T) { handshake := func(t *testing.T, client, server EarlyDataHandler) { t.Helper() - initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client)) + initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client, nil)) require.NoError(t, err) tpt := newTestTransport(t, crypto.Ed25519, 2048) - respTransport, err := tpt.WithSessionOptions(EarlyData(server)) + respTransport, err := tpt.WithSessionOptions(EarlyData(nil, server)) require.NoError(t, err) initConn, respConn := newConnPair(t) @@ -495,10 +495,10 @@ func TestEarlyDataAccepted(t *testing.T) { func TestEarlyDataRejected(t *testing.T) { handshake := func(t *testing.T, client, server EarlyDataHandler) (clientErr, serverErr error) { t.Helper() - initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client)) + initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client, nil)) require.NoError(t, err) tpt := newTestTransport(t, crypto.Ed25519, 2048) - respTransport, err := tpt.WithSessionOptions(EarlyData(server)) + respTransport, err := tpt.WithSessionOptions(EarlyData(nil, server)) require.NoError(t, err) initConn, respConn := newConnPair(t) @@ -545,7 +545,7 @@ func TestEarlyDataAcceptedWithNoHandler(t *testing.T) { clientEDH := &earlyDataHandler{ send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, } - initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) + initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH, nil)) require.NoError(t, err) respTransport := newTestTransport(t, crypto.Ed25519, 2048) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index ca8fe1cf35..5c51b89e2e 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -190,7 +190,7 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (* if err != nil { return nil, err } - n, err := l.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(l.checkEarlyData))) + n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataReceiver(l.checkEarlyData))) if err != nil { return nil, fmt.Errorf("failed to initialize Noise session: %w", err) } diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 03cacb133c..67c2c33840 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -208,7 +208,7 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p if err != nil { return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) } - n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes))) + n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes), nil)) if err != nil { return nil, fmt.Errorf("failed to create Noise transport: %w", err) }