diff --git a/internal/quic/conn_close.go b/internal/quic/conn_close.go index daf425b76..a9ef0db5e 100644 --- a/internal/quic/conn_close.go +++ b/internal/quic/conn_close.go @@ -62,7 +62,7 @@ func (c *Conn) lifetimeAdvance(now time.Time) (done bool) { c.lifetime.drainEndTime = time.Time{} if c.lifetime.finalErr == nil { // The peer never responded to our CONNECTION_CLOSE. - c.enterDraining(errNoPeerResponse) + c.enterDraining(now, errNoPeerResponse) } return true } @@ -152,11 +152,17 @@ func (c *Conn) sendOK(now time.Time) bool { } // enterDraining enters the draining state. -func (c *Conn) enterDraining(err error) { +func (c *Conn) enterDraining(now time.Time, err error) { if c.isDraining() { return } - if e, ok := c.lifetime.localErr.(localTransportError); ok && e.code != errNo { + if err == errStatelessReset { + // If we've received a stateless reset, then we must not send a CONNECTION_CLOSE. + // Setting connCloseSentTime here prevents us from doing so. + c.lifetime.finalErr = errStatelessReset + c.lifetime.localErr = errStatelessReset + c.lifetime.connCloseSentTime = now + } else if e, ok := c.lifetime.localErr.(localTransportError); ok && e.code != errNo { // If we've terminated the connection due to a peer protocol violation, // record the final error on the connection as our reason for termination. c.lifetime.finalErr = c.lifetime.localErr @@ -239,14 +245,14 @@ func (c *Conn) abort(now time.Time, err error) { // The connection does not send a CONNECTION_CLOSE, and skips the draining period. func (c *Conn) abortImmediately(now time.Time, err error) { c.abort(now, err) - c.enterDraining(err) + c.enterDraining(now, err) c.exited = true } // exit fully terminates a connection immediately. func (c *Conn) exit() { c.sendMsg(func(now time.Time, c *Conn) { - c.enterDraining(errors.New("connection closed")) + c.enterDraining(now, errors.New("connection closed")) c.exited = true }) } diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index 8fa3a3906..896c6d74e 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -56,7 +56,7 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { if len(buf) == len(dgram.b) && len(buf) > statelessResetTokenLen { var token statelessResetToken copy(token[:], buf[len(buf)-len(token):]) - c.handleStatelessReset(token) + c.handleStatelessReset(now, token) } // Invalid data at the end of a datagram is ignored. break @@ -525,7 +525,7 @@ func (c *Conn) handleConnectionCloseTransportFrame(now time.Time, payload []byte if n < 0 { return -1 } - c.enterDraining(peerTransportError{code: code, reason: reason}) + c.enterDraining(now, peerTransportError{code: code, reason: reason}) return n } @@ -534,7 +534,7 @@ func (c *Conn) handleConnectionCloseApplicationFrame(now time.Time, payload []by if n < 0 { return -1 } - c.enterDraining(&ApplicationError{Code: code, Reason: reason}) + c.enterDraining(now, &ApplicationError{Code: code, Reason: reason}) return n } @@ -556,9 +556,9 @@ func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payloa var errStatelessReset = errors.New("received stateless reset") -func (c *Conn) handleStatelessReset(resetToken statelessResetToken) { +func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToken) { if !c.connIDState.isValidStatelessResetToken(resetToken) { return } - c.enterDraining(errStatelessReset) + c.enterDraining(now, errStatelessReset) } diff --git a/internal/quic/listener.go b/internal/quic/listener.go index 8b31dcbe8..ca8f9b25a 100644 --- a/internal/quic/listener.go +++ b/internal/quic/listener.go @@ -253,12 +253,18 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { if len(m.b) < minimumValidPacketSize { return } + var now time.Time + if l.testHooks != nil { + now = l.testHooks.timeNow() + } else { + now = time.Now() + } // Check to see if this is a stateless reset. var token statelessResetToken copy(token[:], m.b[len(m.b)-len(token):]) if c := l.connsMap.byResetToken[token]; c != nil { c.sendMsg(func(now time.Time, c *Conn) { - c.handleStatelessReset(token) + c.handleStatelessReset(now, token) }) return } @@ -290,12 +296,6 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { // https://www.rfc-editor.org/rfc/rfc9000#section-10.3-16 return } - var now time.Time - if l.testHooks != nil { - now = l.testHooks.timeNow() - } else { - now = time.Now() - } cids := newServerConnIDs{ srcConnID: p.srcConnID, dstConnID: p.dstConnID, diff --git a/internal/quic/stateless_reset_test.go b/internal/quic/stateless_reset_test.go index b12e97560..8a16597c4 100644 --- a/internal/quic/stateless_reset_test.go +++ b/internal/quic/stateless_reset_test.go @@ -14,6 +14,7 @@ import ( "errors" "net/netip" "testing" + "time" ) func TestStatelessResetClientSendsStatelessResetTokenTransportParameter(t *testing.T) { @@ -154,7 +155,9 @@ func TestStatelessResetSuccessfulNewConnectionID(t *testing.T) { if err := tc.conn.Wait(canceledContext()); !errors.Is(err, errStatelessReset) { t.Errorf("conn.Wait() = %v, want errStatelessReset", err) } - tc.wantIdle("closed connection is idle") + tc.wantIdle("closed connection is idle in draining") + tc.advance(1 * time.Second) // long enough to exit the draining state + tc.wantIdle("closed connection is idle after draining") } func TestStatelessResetSuccessfulTransportParameter(t *testing.T) {