Skip to content
This repository has been archived by the owner on Aug 19, 2022. It is now read-only.

handle TCP simultaneous open (option 4) #38

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package libp2ptls

import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"

ci "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
Expand Down Expand Up @@ -35,3 +40,117 @@ func (c *conn) RemotePeer() peer.ID {
func (c *conn) RemotePublicKey() ci.PubKey {
return c.remotePubKey
}

const (
recordTypeHandshake byte = 22
versionTLS13 = 0x0304
maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
)

var errSimultaneousConnect = errors.New("detected TCP simultaneous connect")

type teeConn struct {
vyzo marked this conversation as resolved.
Show resolved Hide resolved
net.Conn
buf *bytes.Buffer
}

func newTeeConn(c net.Conn, buf *bytes.Buffer) net.Conn {
return &teeConn{Conn: c, buf: buf}
}

func (c *teeConn) Read(b []byte) (int, error) {
n, err := c.Conn.Read(b)
c.buf.Write(b[:n])
return n, err
}

type wrappedConn struct {
// Before reading the first handshake message, this is a *teeConn.
// After reading the first handshake message, we switch it to the rawConn.
net.Conn

rawConn net.Conn
hasReadFirstMessage bool
raw *bytes.Buffer // contains a copy of every byte of the first handshake message we read from the wire

hand bytes.Buffer // used to store the first handshake message until we've completely read it
}

func newWrappedConn(c net.Conn) net.Conn {
wc := &wrappedConn{
raw: &bytes.Buffer{},
rawConn: c,
}
wc.Conn = newTeeConn(c, wc.raw)
return wc
}

func (c *wrappedConn) Read(b []byte) (int, error) {
if c.hasReadFirstMessage {
return c.Conn.Read(b)
vyzo marked this conversation as resolved.
Show resolved Hide resolved
}

// We read the first handshake message, and it was not a ClientHello.
// We now need to feed all the bytes we read from the wire into the TLS stack,
// so it can proceed with the handshake.
if c.raw.Len() > 0 {
n, err := c.raw.Read(b)
if err == io.EOF || c.raw.Len() == 0 {
c.raw = nil
c.Conn = c.rawConn
c.hasReadFirstMessage = true
err = nil
}
return n, err
}

mes, err := c.readFirstHandshakeMessage()
if err != nil {
return 0, err
}

switch mes[0] {
case 1: // ClientHello
return 0, errSimultaneousConnect
case 2: // ServerHello
return c.Read(b)
default:
return 0, fmt.Errorf("unexpected message type: %d", mes[0])
}
}

func (c *wrappedConn) readFirstHandshakeMessage() ([]byte, error) {
for c.hand.Len() < 4 {
if err := c.readRecord(); err != nil {
return nil, err
}
}
data := c.hand.Bytes()
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
return nil, fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)
}
for c.hand.Len() < 4+n {
if err := c.readRecord(); err != nil {
return nil, err
}
}
return c.hand.Next(4 + n), nil
}

func (c *wrappedConn) readRecord() error {
hdr := make([]byte, 5)
if _, err := io.ReadFull(c.Conn, hdr); err != nil {
return err
}
if hdr[0] != recordTypeHandshake {
return errors.New("expected a handshake record")
}
n := int(hdr[3])<<8 | int(hdr[4])
if n > maxCiphertextTLS13 {
return fmt.Errorf("oversized record received with length %d", n)
}
_, err := io.CopyN(&c.hand, c.Conn, int64(n))
return err
}
10 changes: 10 additions & 0 deletions crypto.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package libp2ptls

import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
Expand Down Expand Up @@ -220,3 +222,11 @@ func preferServerCipherSuites() bool {
)
return !hasGCMAsm
}

// Compare two peer IDs by their SHA256 hash.
// The result will be 0 if H(a) == H(b), -1 if H(a) < H(b), and +1 if H(a) > H(b).
func comparePeerIDs(p1, p2 peer.ID) int {
p1Hash := sha256.Sum256([]byte(p1))
p2Hash := sha256.Sum256([]byte(p2))
return bytes.Compare(p1Hash[:], p2Hash[:])
}
23 changes: 20 additions & 3 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,28 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.S
// notice this after 1 RTT when calling Read.
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
config, keyCh := t.identity.ConfigForPeer(p)
cs, err := t.handshake(ctx, tls.Client(insecure, config), keyCh)
if err != nil {
conn, err := t.handshake(ctx, tls.Client(newWrappedConn(insecure), config), keyCh)
if err == errSimultaneousConnect {
switch comparePeerIDs(t.localPeer, p) {
case 0:
return nil, errors.New("tried to simultaneous connect to oneself")
case -1:
// SHA256(our peer ID) is smaller than SHA256(their peer ID).
// We're the client in the next connection attempt.
config, keyCh := t.identity.ConfigForPeer(p)
return t.handshake(ctx, tls.Client(insecure, config), keyCh)
case 1:
// SHA256(our peer ID) is larger than SHA256(their peer ID).
// We're the server in the next connection attempt.
config, keyCh := t.identity.ConfigForPeer(p)
return t.handshake(ctx, tls.Server(insecure, config), keyCh)
default:
panic("unexpected peer ID comparison result")
}
} else if err != nil {
insecure.Close()
}
return cs, err
return conn, err
}

func (t *Transport) handshake(
Expand Down
57 changes: 57 additions & 0 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"math/big"
mrand "math/rand"
"net"
"reflect"
"time"

"github.com/onsi/gomega/gbytes"
Expand Down Expand Up @@ -188,6 +189,62 @@ var _ = Describe("Transport", func() {
Eventually(done).Should(BeClosed())
})

It("handles simultaneous open", func() {
// Avoid confusion regarding the naming.
p1, p1Key := serverID, serverKey
p2, p2Key := clientID, clientKey

// We use a normal dial / listen to establish the TCP connection,
// but we then start two clients.
c1raw, c2raw := connect()

c1Transport, err := New(p1Key)
Expect(err).ToNot(HaveOccurred())
c2Transport, err := New(p2Key)
Expect(err).ToNot(HaveOccurred())

c1ConnChan := make(chan sec.SecureConn, 1)
go func() {
defer GinkgoRecover()
conn, err := c1Transport.SecureOutbound(context.Background(), c1raw, p2)
Expect(err).ToNot(HaveOccurred())
c1ConnChan <- conn
}()

c2, err := c2Transport.SecureOutbound(context.Background(), c2raw, p1)
Expect(err).ToNot(HaveOccurred())
defer c2.Close()
var c1 sec.SecureConn
Eventually(c1ConnChan).Should(Receive(&c1))
defer c1.Close()

// check that the peers are in the correct roles
isClient := func(c sec.SecureConn) bool {
// the isClient field of the tls.Conn will tell us who is client and server
return reflect.ValueOf(c.(*conn).Conn).Elem().FieldByName("isClient").Bool()
}
switch comparePeerIDs(p1, p2) {
case -1:
// H(p1) < H(p2) => p1 acts as a client, p2 as a server
Expect(isClient(c1)).To(BeTrue())
Expect(isClient(c2)).To(BeFalse())
case 1:
// H(p1) > H(p2) => p1 acts as a server, p2 as a client
Expect(isClient(c1)).To(BeFalse())
Expect(isClient(c2)).To(BeTrue())
default:
Fail("unexpected peer comparison result")
}

// exchange some data
_, err = c1.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = c2.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(string(b)).To(Equal("foobar"))
})

Context("invalid certificates", func() {
invalidateCertChain := func(identity *Identity) {
switch identity.config.Certificates[0].PrivateKey.(type) {
Expand Down