diff --git a/go.mod b/go.mod index cfb84546..f1873dc6 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,7 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.61.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/quic-go/quic-go v0.48.2 // indirect + github.com/quic-go/quic-go v0.49.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect go.uber.org/mock v0.5.0 // indirect golang.org/x/crypto v0.32.0 // indirect diff --git a/go.sum b/go.sum index 4c017b7f..f812e87f 100644 --- a/go.sum +++ b/go.sum @@ -56,8 +56,8 @@ github.com/prometheus/common v0.61.0 h1:3gv/GThfX0cV2lpO7gkTUwZru38mxevy90Bj8YFS github.com/prometheus/common v0.61.0/go.mod h1:zr29OCN/2BsJRaFwG8QOBr41D6kkchKbpeNH7pAjb/s= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= -github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= -github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= +github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0= +github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= diff --git a/vendor/github.com/quic-go/quic-go/.golangci.yml b/vendor/github.com/quic-go/quic-go/.golangci.yml index 63b40cc3..1174b125 100644 --- a/vendor/github.com/quic-go/quic-go/.golangci.yml +++ b/vendor/github.com/quic-go/quic-go/.golangci.yml @@ -16,9 +16,9 @@ linters: disable-all: true enable: - asciicheck + - copyloopvar - depguard - exhaustive - - exportloopref - goimports - gofmt # redundant, since gofmt *should* be a no-op after gofumpt - gofumpt @@ -44,3 +44,8 @@ issues: linters: - exhaustive - prealloc + - unparam + - path: _test\.go + text: "SA1029:" + linters: + - staticcheck diff --git a/vendor/github.com/quic-go/quic-go/README.md b/vendor/github.com/quic-go/quic-go/README.md index 94823d99..ccc9e213 100644 --- a/vendor/github.com/quic-go/quic-go/README.md +++ b/vendor/github.com/quic-go/quic-go/README.md @@ -9,7 +9,8 @@ quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) and HTTP Datagrams ([RFC 9297](https://datatracker.ietf.org/doc/html/rfc9297)). -In addition to these base RFCs, it also implements the following RFCs: +In addition to these base RFCs, it also implements the following RFCs: + * Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) * Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899)) * QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369)) @@ -33,6 +34,7 @@ Detailed documentation can be found on [quic-go.net](https://quic-go.net/docs/). | [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) | | [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) | | [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | +| [reverst](https://github.com/flipt-io/reverst) | Reverse Tunnels in Go over HTTP/3 and QUIC | ![GitHub Repo stars](https://img.shields.io/github/stars/flipt-io/reverst?style=flat-square) | | [RoadRunner](https://github.com/roadrunner-server/roadrunner) | High-performance PHP application server, process manager written in Go and powered with plugins | ![GitHub Repo stars](https://img.shields.io/github/stars/roadrunner-server/roadrunner?style=flat-square) | | [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | | [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) | diff --git a/vendor/github.com/quic-go/quic-go/client.go b/vendor/github.com/quic-go/quic-go/client.go index 1c5654f6..29a715cc 100644 --- a/vendor/github.com/quic-go/quic-go/client.go +++ b/vendor/github.com/quic-go/quic-go/client.go @@ -7,38 +7,8 @@ import ( "net" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" - "github.com/quic-go/quic-go/logging" ) -type client struct { - sendConn sendConn - - use0RTT bool - - packetHandlers packetHandlerManager - onClose func() - - tlsConf *tls.Config - config *Config - - connIDGenerator ConnectionIDGenerator - srcConnID protocol.ConnectionID - destConnID protocol.ConnectionID - - initialPacketNumber protocol.PacketNumber - hasNegotiatedVersion bool - version protocol.Version - - handshakeChan chan struct{} - - conn quicConn - - tracer *logging.ConnectionTracer - tracingID ConnectionTracingID - logger utils.Logger -} - // make it possible to mock connection ID for initial generation in the tests var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial @@ -132,120 +102,3 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo isSingleUse: true, }, nil } - -func dial( - ctx context.Context, - conn sendConn, - connIDGenerator ConnectionIDGenerator, - packetHandlers packetHandlerManager, - tlsConf *tls.Config, - config *Config, - onClose func(), - use0RTT bool, -) (quicConn, error) { - c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT) - if err != nil { - return nil, err - } - c.packetHandlers = packetHandlers - - c.tracingID = nextConnTracingID() - if c.config.Tracer != nil { - c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) - } - if c.tracer != nil && c.tracer.StartedConnection != nil { - c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) - } - if err := c.dial(ctx); err != nil { - return nil, err - } - return c.conn, nil -} - -func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) { - srcConnID, err := connIDGenerator.GenerateConnectionID() - if err != nil { - return nil, err - } - destConnID, err := generateConnectionIDForInitial() - if err != nil { - return nil, err - } - c := &client{ - connIDGenerator: connIDGenerator, - srcConnID: srcConnID, - destConnID: destConnID, - sendConn: sendConn, - use0RTT: use0RTT, - onClose: onClose, - tlsConf: tlsConf, - config: config, - version: config.Versions[0], - handshakeChan: make(chan struct{}), - logger: utils.DefaultLogger.WithPrefix("client"), - } - return c, nil -} - -func (c *client) dial(ctx context.Context) error { - c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) - - c.conn = newClientConnection( - context.WithValue(context.WithoutCancel(ctx), ConnectionTracingKey, c.tracingID), - c.sendConn, - c.packetHandlers, - c.destConnID, - c.srcConnID, - c.connIDGenerator, - c.config, - c.tlsConf, - c.initialPacketNumber, - c.use0RTT, - c.hasNegotiatedVersion, - c.tracer, - c.logger, - c.version, - ) - c.packetHandlers.Add(c.srcConnID, c.conn) - - errorChan := make(chan error, 1) - recreateChan := make(chan errCloseForRecreating) - go func() { - err := c.conn.run() - var recreateErr *errCloseForRecreating - if errors.As(err, &recreateErr) { - recreateChan <- *recreateErr - return - } - if c.onClose != nil { - c.onClose() - } - errorChan <- err // returns as soon as the connection is closed - }() - - // only set when we're using 0-RTT - // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever. - var earlyConnChan <-chan struct{} - if c.use0RTT { - earlyConnChan = c.conn.earlyConnReady() - } - - select { - case <-ctx.Done(): - c.conn.destroy(nil) - return context.Cause(ctx) - case err := <-errorChan: - return err - case recreateErr := <-recreateChan: - c.initialPacketNumber = recreateErr.nextPacketNumber - c.version = recreateErr.nextVersion - c.hasNegotiatedVersion = true - return c.dial(ctx) - case <-earlyConnChan: - // ready to send 0-RTT data - return nil - case <-c.conn.HandshakeComplete(): - // handshake successfully completed - return nil - } -} diff --git a/vendor/github.com/quic-go/quic-go/codecov.yml b/vendor/github.com/quic-go/quic-go/codecov.yml index 59e4b58f..77e47fbe 100644 --- a/vendor/github.com/quic-go/quic-go/codecov.yml +++ b/vendor/github.com/quic-go/quic-go/codecov.yml @@ -6,6 +6,8 @@ coverage: - internal/handshake/cipher_suite.go - internal/utils/linkedlist/linkedlist.go - internal/testdata + - logging/connection_tracer_multiplexer.go + - logging/tracer_multiplexer.go - testutils/ - fuzzing/ - metrics/ diff --git a/vendor/github.com/quic-go/quic-go/conn_id_generator.go b/vendor/github.com/quic-go/quic-go/conn_id_generator.go index d7be6540..c309c2cd 100644 --- a/vendor/github.com/quic-go/quic-go/conn_id_generator.go +++ b/vendor/github.com/quic-go/quic-go/conn_id_generator.go @@ -15,19 +15,19 @@ type connIDGenerator struct { activeSrcConnIDs map[uint64]protocol.ConnectionID initialClientDestConnID *protocol.ConnectionID // nil for the client - addConnectionID func(protocol.ConnectionID) - getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken - removeConnectionID func(protocol.ConnectionID) - retireConnectionID func(protocol.ConnectionID) - replaceWithClosed func([]protocol.ConnectionID, []byte) - queueControlFrame func(wire.Frame) + addConnectionID func(protocol.ConnectionID) + statelessResetter *statelessResetter + removeConnectionID func(protocol.ConnectionID) + retireConnectionID func(protocol.ConnectionID) + replaceWithClosed func([]protocol.ConnectionID, []byte) + queueControlFrame func(wire.Frame) } func newConnIDGenerator( initialConnectionID protocol.ConnectionID, initialClientDestConnID *protocol.ConnectionID, // nil for the client addConnectionID func(protocol.ConnectionID), - getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, + statelessResetter *statelessResetter, removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), replaceWithClosed func([]protocol.ConnectionID, []byte), @@ -35,14 +35,14 @@ func newConnIDGenerator( generator ConnectionIDGenerator, ) *connIDGenerator { m := &connIDGenerator{ - generator: generator, - activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), - addConnectionID: addConnectionID, - getStatelessResetToken: getStatelessResetToken, - removeConnectionID: removeConnectionID, - retireConnectionID: retireConnectionID, - replaceWithClosed: replaceWithClosed, - queueControlFrame: queueControlFrame, + generator: generator, + activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), + addConnectionID: addConnectionID, + statelessResetter: statelessResetter, + removeConnectionID: removeConnectionID, + retireConnectionID: retireConnectionID, + replaceWithClosed: replaceWithClosed, + queueControlFrame: queueControlFrame, } m.activeSrcConnIDs[0] = initialConnectionID m.initialClientDestConnID = initialClientDestConnID @@ -104,7 +104,7 @@ func (m *connIDGenerator) issueNewConnID() error { m.queueControlFrame(&wire.NewConnectionIDFrame{ SequenceNumber: m.highestSeq + 1, ConnectionID: connID, - StatelessResetToken: m.getStatelessResetToken(connID), + StatelessResetToken: m.statelessResetter.GetStatelessResetToken(connID), }) m.highestSeq++ return nil diff --git a/vendor/github.com/quic-go/quic-go/conn_id_manager.go b/vendor/github.com/quic-go/quic-go/conn_id_manager.go index 4aa3f749..4030913d 100644 --- a/vendor/github.com/quic-go/quic-go/conn_id_manager.go +++ b/vendor/github.com/quic-go/quic-go/conn_id_manager.go @@ -35,6 +35,8 @@ type connIDManager struct { addStatelessResetToken func(protocol.StatelessResetToken) removeStatelessResetToken func(protocol.StatelessResetToken) queueControlFrame func(wire.Frame) + + closed bool } func newConnIDManager( @@ -66,6 +68,12 @@ func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { } func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { + if h.activeConnectionID.Len() == 0 { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received NEW_CONNECTION_ID frame but zero-length connection IDs are in use", + } + } // If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active // connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately. if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired { @@ -142,6 +150,7 @@ func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID } func (h *connIDManager) updateConnectionID() { + h.assertNotClosed() h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: h.activeSequenceNumber, }) @@ -160,6 +169,7 @@ func (h *connIDManager) updateConnectionID() { } func (h *connIDManager) Close() { + h.closed = true if h.activeStatelessResetToken != nil { h.removeStatelessResetToken(*h.activeStatelessResetToken) } @@ -176,6 +186,7 @@ func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) { // is called when the server provides a stateless reset token in the transport parameters func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) { + h.assertNotClosed() if h.activeSequenceNumber != 0 { panic("expected first connection ID to have sequence number 0") } @@ -203,6 +214,7 @@ func (h *connIDManager) shouldUpdateConnID() bool { } func (h *connIDManager) Get() protocol.ConnectionID { + h.assertNotClosed() if h.shouldUpdateConnID() { h.updateConnectionID() } @@ -212,3 +224,13 @@ func (h *connIDManager) Get() protocol.ConnectionID { func (h *connIDManager) SetHandshakeComplete() { h.handshakeComplete = true } + +// Using the connIDManager after it has been closed can have disastrous effects: +// If the connection ID is rotated, a new entry would be inserted into the packet handler map, +// leading to a memory leak of the connection struct. +// See https://github.com/quic-go/quic-go/pull/4852 for more details. +func (h *connIDManager) assertNotClosed() { + if h.closed { + panic("connection ID manager is closed") + } +} diff --git a/vendor/github.com/quic-go/quic-go/connection.go b/vendor/github.com/quic-go/quic-go/connection.go index 4390f5ca..c9ce930c 100644 --- a/vendor/github.com/quic-go/quic-go/connection.go +++ b/vendor/github.com/quic-go/quic-go/connection.go @@ -85,7 +85,6 @@ func (p *receivedPacket) Clone() *receivedPacket { type connRunner interface { Add(protocol.ConnectionID, packetHandler) bool - GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) ReplaceWithClosed([]protocol.ConnectionID, []byte) @@ -225,7 +224,7 @@ var newConnection = func( destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, connIDGenerator ConnectionIDGenerator, - statelessResetToken protocol.StatelessResetToken, + statelessResetter *statelessResetter, conf *Config, tlsConf *tls.Config, tokenGenerator *handshake.TokenGenerator, @@ -263,7 +262,7 @@ var newConnection = func( srcConnID, &clientDestConnID, func(connID protocol.ConnectionID) { runner.Add(connID, s) }, - runner.GetStatelessResetToken, + statelessResetter, runner.Remove, runner.Retire, runner.ReplaceWithClosed, @@ -282,6 +281,7 @@ var newConnection = func( s.logger, ) s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize)))) + statelessResetToken := statelessResetter.GetStatelessResetToken(srcConnID) params := &wire.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -340,6 +340,7 @@ var newClientConnection = func( destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, connIDGenerator ConnectionIDGenerator, + statelessResetter *statelessResetter, conf *Config, tlsConf *tls.Config, initialPacketNumber protocol.PacketNumber, @@ -372,7 +373,7 @@ var newClientConnection = func( srcConnID, nil, func(connID protocol.ConnectionID) { runner.Add(connID, s) }, - runner.GetStatelessResetToken, + statelessResetter, runner.Remove, runner.Retire, runner.ReplaceWithClosed, @@ -477,7 +478,7 @@ func (s *connection) preSetup() { uint64(s.config.MaxIncomingUniStreams), s.perspective, ) - s.framer = newFramer() + s.framer = newFramer(s.connFlowController) s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) @@ -496,12 +497,28 @@ func (s *connection) run() error { var closeErr closeError defer func() { s.ctxCancel(closeErr.err) }() + defer func() { + // Drain queued packets that will never be processed. + for { + select { + case p, ok := <-s.receivedPackets: + if !ok { + return + } + p.buffer.Decrement() + p.buffer.MaybeRelease() + default: + return + } + } + }() + s.timer = *newTimer() if err := s.cryptoStreamHandler.StartHandshake(s.ctx); err != nil { return err } - if err := s.handleHandshakeEvents(); err != nil { + if err := s.handleHandshakeEvents(time.Now()); err != nil { return err } go func() { @@ -602,7 +619,7 @@ runLoop: if timeout := s.sentPacketHandler.GetLossDetectionTimeout(); !timeout.IsZero() && timeout.Before(now) { // This could cause packets to be retransmitted. // Check it before trying to send packets. - if err := s.sentPacketHandler.OnLossDetectionTimeout(); err != nil { + if err := s.sentPacketHandler.OnLossDetectionTimeout(now); err != nil { s.closeLocal(err) } } @@ -727,7 +744,7 @@ func (s *connection) idleTimeoutStartTime() time.Time { return startTime } -func (s *connection) handleHandshakeComplete() error { +func (s *connection) handleHandshakeComplete(now time.Time) error { defer close(s.handshakeCompleteChan) // Once the handshake completes, we have derived 1-RTT keys. // There's no point in queueing undecryptable packets for later decryption anymore. @@ -748,7 +765,7 @@ func (s *connection) handleHandshakeComplete() error { } // All these only apply to the server side. - if err := s.handleHandshakeConfirmed(); err != nil { + if err := s.handleHandshakeConfirmed(now); err != nil { return err } @@ -771,23 +788,25 @@ func (s *connection) handleHandshakeComplete() error { return nil } -func (s *connection) handleHandshakeConfirmed() error { - if err := s.dropEncryptionLevel(protocol.EncryptionHandshake); err != nil { +func (s *connection) handleHandshakeConfirmed(now time.Time) error { + if err := s.dropEncryptionLevel(protocol.EncryptionInitial, now); err != nil { + return err + } + if err := s.dropEncryptionLevel(protocol.EncryptionHandshake, now); err != nil { return err } s.handshakeConfirmed = true - s.sentPacketHandler.SetHandshakeConfirmed() s.cryptoStreamHandler.SetHandshakeConfirmed() if !s.config.DisablePathMTUDiscovery && s.conn.capabilities().DF { - s.mtuDiscoverer.Start() + s.mtuDiscoverer.Start(now) } return nil } func (s *connection) handlePacketImpl(rp receivedPacket) bool { - s.sentPacketHandler.ReceivedBytes(rp.Size()) + s.sentPacketHandler.ReceivedBytes(rp.Size(), rp.rcvTime) if wire.IsVersionNegotiationPacket(rp.data) { s.handleVersionNegotiationPacket(rp) @@ -958,7 +977,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) // drop 0-RTT packets, if we are a client if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropKeyUnavailable) + s.tracer.DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -1068,6 +1087,15 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti return false } + newDestConnID := hdr.SrcConnectionID + s.receivedRetry = true + s.sentPacketHandler.ResetForRetry(rcvTime) + s.handshakeDestConnID = newDestConnID + s.retrySrcConnID = &newDestConnID + s.cryptoStreamHandler.ChangeConnectionID(newDestConnID) + s.packer.SetToken(hdr.Token) + s.connIDManager.ChangeInitialConnID(newDestConnID) + if s.logger.Debug() { s.logger.Debugf("<- Received Retry:") (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) @@ -1076,17 +1104,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti if s.tracer != nil && s.tracer.ReceivedRetry != nil { s.tracer.ReceivedRetry(hdr) } - newDestConnID := hdr.SrcConnectionID - s.receivedRetry = true - if err := s.sentPacketHandler.ResetForRetry(rcvTime); err != nil { - s.closeLocal(err) - return false - } - s.handshakeDestConnID = newDestConnID - s.retrySrcConnID = &newDestConnID - s.cryptoStreamHandler.ChangeConnectionID(newDestConnID) - s.packer.SetToken(hdr.Token) - s.connIDManager.ChangeInitialConnID(newDestConnID) + s.scheduleSending() return true } @@ -1195,7 +1213,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( !s.droppedInitialKeys { // On the server side, Initial keys are dropped as soon as the first Handshake packet is received. // See Section 4.9.1 of RFC 9001. - if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { + if err := s.dropEncryptionLevel(protocol.EncryptionInitial, rcvTime); err != nil { return err } } @@ -1210,7 +1228,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames) } } - isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log) + isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log, rcvTime) if err != nil { return err } @@ -1229,7 +1247,7 @@ func (s *connection) handleUnpackedShortHeaderPacket( s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} s.keepAlivePingSent = false - isAckEliciting, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log) + isAckEliciting, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log, rcvTime) if err != nil { return err } @@ -1241,6 +1259,7 @@ func (s *connection) handleFrames( destConnID protocol.ConnectionID, encLevel protocol.EncryptionLevel, log func([]logging.Frame), + rcvTime time.Time, ) (isAckEliciting bool, _ error) { // Only used for tracing. // If we're not tracing, this slice will always remain empty. @@ -1270,7 +1289,7 @@ func (s *connection) handleFrames( if handleErr != nil { continue } - if err := s.handleFrame(frame, encLevel, destConnID); err != nil { + if err := s.handleFrame(frame, encLevel, destConnID, rcvTime); err != nil { if log == nil { return false, err } @@ -1291,7 +1310,7 @@ func (s *connection) handleFrames( // We receive a Handshake packet that contains the CRYPTO frame that allows us to complete the handshake, // and an ACK serialized after that CRYPTO frame. In this case, we still want to process the ACK frame. if !handshakeWasComplete && s.handshakeComplete { - if err := s.handleHandshakeComplete(); err != nil { + if err := s.handleHandshakeComplete(rcvTime); err != nil { return false, err } } @@ -1299,20 +1318,25 @@ func (s *connection) handleFrames( return } -func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error { +func (s *connection) handleFrame( + f wire.Frame, + encLevel protocol.EncryptionLevel, + destConnID protocol.ConnectionID, + rcvTime time.Time, +) error { var err error wire.LogFrame(s.logger, f, false) switch frame := f.(type) { case *wire.CryptoFrame: - err = s.handleCryptoFrame(frame, encLevel) + err = s.handleCryptoFrame(frame, encLevel, rcvTime) case *wire.StreamFrame: - err = s.handleStreamFrame(frame) + err = s.handleStreamFrame(frame, rcvTime) case *wire.AckFrame: - err = s.handleAckFrame(frame, encLevel) + err = s.handleAckFrame(frame, encLevel, rcvTime) case *wire.ConnectionCloseFrame: s.handleConnectionCloseFrame(frame) case *wire.ResetStreamFrame: - err = s.handleResetStreamFrame(frame) + err = s.handleResetStreamFrame(frame, rcvTime) case *wire.MaxDataFrame: s.handleMaxDataFrame(frame) case *wire.MaxStreamDataFrame: @@ -1321,6 +1345,7 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel s.handleMaxStreamsFrame(frame) case *wire.DataBlockedFrame: case *wire.StreamDataBlockedFrame: + err = s.handleStreamDataBlockedFrame(frame) case *wire.StreamsBlockedFrame: case *wire.StopSendingFrame: err = s.handleStopSendingFrame(frame) @@ -1329,7 +1354,10 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel s.handlePathChallengeFrame(frame) case *wire.PathResponseFrame: // since we don't send PATH_CHALLENGEs, we don't expect PATH_RESPONSEs - err = errors.New("unexpected PATH_RESPONSE frame") + err = &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "unexpected PATH_RESPONSE frame", + } case *wire.NewTokenFrame: err = s.handleNewTokenFrame(frame) case *wire.NewConnectionIDFrame: @@ -1337,7 +1365,7 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel case *wire.RetireConnectionIDFrame: err = s.handleRetireConnectionIDFrame(frame, destConnID) case *wire.HandshakeDoneFrame: - err = s.handleHandshakeDoneFrame() + err = s.handleHandshakeDoneFrame(rcvTime) case *wire.DatagramFrame: err = s.handleDatagramFrame(frame) default: @@ -1376,7 +1404,7 @@ func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame }) } -func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { +func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { if err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel); err != nil { return err } @@ -1389,10 +1417,10 @@ func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protoco return err } } - return s.handleHandshakeEvents() + return s.handleHandshakeEvents(rcvTime) } -func (s *connection) handleHandshakeEvents() error { +func (s *connection) handleHandshakeEvents(now time.Time) error { for { ev := s.cryptoStreamHandler.NextEvent() var err error @@ -1413,7 +1441,7 @@ func (s *connection) handleHandshakeEvents() error { s.undecryptablePacketsToProcess = s.undecryptablePackets s.undecryptablePackets = nil case handshake.EventDiscard0RTTKeys: - err = s.dropEncryptionLevel(protocol.Encryption0RTT) + err = s.dropEncryptionLevel(protocol.Encryption0RTT, now) case handshake.EventWriteInitialData: _, err = s.initialStream.Write(ev.Data) case handshake.EventWriteHandshakeData: @@ -1425,17 +1453,15 @@ func (s *connection) handleHandshakeEvents() error { } } -func (s *connection) handleStreamFrame(frame *wire.StreamFrame) error { +func (s *connection) handleStreamFrame(frame *wire.StreamFrame, rcvTime time.Time) error { str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) if err != nil { return err } - if str == nil { - // Stream is closed and already garbage collected - // ignore this StreamFrame + if str == nil { // stream was already closed and garbage collected return nil } - return str.handleStreamFrame(frame) + return str.handleStreamFrame(frame, rcvTime) } func (s *connection) handleMaxDataFrame(frame *wire.MaxDataFrame) { @@ -1455,11 +1481,18 @@ func (s *connection) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) er return nil } +func (s *connection) handleStreamDataBlockedFrame(frame *wire.StreamDataBlockedFrame) error { + // We don't need to do anything in response to a STREAM_DATA_BLOCKED frame, + // but we need to make sure that the stream ID is valid. + _, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) + return err +} + func (s *connection) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) { s.streamsMap.HandleMaxStreamsFrame(frame) } -func (s *connection) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { +func (s *connection) handleResetStreamFrame(frame *wire.ResetStreamFrame, rcvTime time.Time) error { str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) if err != nil { return err @@ -1468,7 +1501,7 @@ func (s *connection) handleResetStreamFrame(frame *wire.ResetStreamFrame) error // stream is closed and already garbage collected return nil } - return str.handleResetStreamFrame(frame) + return str.handleResetStreamFrame(frame, rcvTime) } func (s *connection) handleStopSendingFrame(frame *wire.StopSendingFrame) error { @@ -1509,7 +1542,7 @@ func (s *connection) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFra return s.connIDGenerator.Retire(f.SequenceNumber, destConnID) } -func (s *connection) handleHandshakeDoneFrame() error { +func (s *connection) handleHandshakeDoneFrame(rcvTime time.Time) error { if s.perspective == protocol.PerspectiveServer { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, @@ -1517,12 +1550,12 @@ func (s *connection) handleHandshakeDoneFrame() error { } } if !s.handshakeConfirmed { - return s.handleHandshakeConfirmed() + return s.handleHandshakeConfirmed(rcvTime) } return nil } -func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { +func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { acked1RTTPacket, err := s.sentPacketHandler.ReceivedAck(frame, encLevel, s.lastPacketReceivedTime) if err != nil { return err @@ -1534,7 +1567,7 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr // This is only possible if the ACK was sent in a 1-RTT packet. // This is an optimization over simply waiting for a HANDSHAKE_DONE frame, see section 4.1.2 of RFC 9001. if s.perspective == protocol.PerspectiveClient && !s.handshakeConfirmed { - if err := s.handleHandshakeConfirmed(); err != nil { + if err := s.handleHandshakeConfirmed(rcvTime); err != nil { return err } } @@ -1627,6 +1660,8 @@ func (s *connection) handleCloseError(closeErr *closeError) { errors.As(e, &recreateErr), errors.As(e, &applicationErr), errors.As(e, &transportErr): + case closeErr.immediate: + e = closeErr.err default: e = &qerr.TransportError{ ErrorCode: qerr.InternalError, @@ -1635,11 +1670,16 @@ func (s *connection) handleCloseError(closeErr *closeError) { } s.streamsMap.CloseWithError(e) - s.connIDManager.Close() if s.datagramQueue != nil { s.datagramQueue.CloseWithError(e) } + // In rare instances, the connection ID manager might switch to a new connection ID + // when sending the CONNECTION_CLOSE frame. + // The connection ID manager removes the active stateless reset token from the packet + // handler map when it is closed, so we need to make sure that this happens last. + defer s.connIDManager.Close() + if s.tracer != nil && s.tracer.ClosedConnection != nil && !errors.As(e, &recreateErr) { s.tracer.ClosedConnection(e) } @@ -1666,11 +1706,11 @@ func (s *connection) handleCloseError(closeErr *closeError) { s.connIDGenerator.ReplaceWithClosed(connClosePacket) } -func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error { +func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel, now time.Time) error { if s.tracer != nil && s.tracer.DroppedEncryptionLevel != nil { s.tracer.DroppedEncryptionLevel(encLevel) } - s.sentPacketHandler.DropPackets(encLevel) + s.sentPacketHandler.DropPackets(encLevel, now) s.receivedPacketHandler.DropPackets(encLevel) //nolint:exhaustive // only Initial and 0-RTT need special treatment switch encLevel { @@ -1772,7 +1812,7 @@ func (s *connection) applyTransportParameters() { if params.MaxIdleTimeout > 0 { s.idleTimeout = min(s.idleTimeout, params.MaxIdleTimeout) } - s.keepAliveInterval = min(s.config.KeepAlivePeriod, min(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) + s.keepAliveInterval = min(s.config.KeepAlivePeriod, s.idleTimeout/2) s.streamsMap.UpdateLimits(params) s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.connFlowController.UpdateSendWindow(params.InitialMaxData) @@ -1822,28 +1862,10 @@ func (s *connection) triggerSending(now time.Time) error { case ackhandler.SendAck: // We can at most send a single ACK only packet. // There will only be a new ACK after receiving new packets. - // SendAck is only returned when we're congestion limited, so we don't need to set the pacinggs timer. + // SendAck is only returned when we're congestion limited, so we don't need to set the pacing timer. return s.maybeSendAckOnlyPacket(now) - case ackhandler.SendPTOInitial: - if err := s.sendProbePacket(protocol.EncryptionInitial, now); err != nil { - return err - } - if s.sendQueue.WouldBlock() { - s.scheduleSending() - return nil - } - return s.triggerSending(now) - case ackhandler.SendPTOHandshake: - if err := s.sendProbePacket(protocol.EncryptionHandshake, now); err != nil { - return err - } - if s.sendQueue.WouldBlock() { - s.scheduleSending() - return nil - } - return s.triggerSending(now) - case ackhandler.SendPTOAppData: - if err := s.sendProbePacket(protocol.Encryption1RTT, now); err != nil { + case ackhandler.SendPTOInitial, ackhandler.SendPTOHandshake, ackhandler.SendPTOAppData: + if err := s.sendProbePacket(sendMode, now); err != nil { return err } if s.sendQueue.WouldBlock() { @@ -1862,7 +1884,7 @@ func (s *connection) sendPackets(now time.Time) error { // Performance-wise, this doesn't matter, since we only send a very small (<10) number of // MTU probe packets per connection. if s.handshakeConfirmed && s.mtuDiscoverer != nil && s.mtuDiscoverer.ShouldSendProbe(now) { - ping, size := s.mtuDiscoverer.GetPing() + ping, size := s.mtuDiscoverer.GetPing(now) p, buf, err := s.packer.PackMTUProbePacket(ping, size, s.version) if err != nil { return err @@ -1871,15 +1893,12 @@ func (s *connection) sendPackets(now time.Time) error { s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false) s.registerPackedShortHeaderPacket(p, ecn, now) s.sendQueue.Send(buf, 0, ecn) - // This is kind of a hack. We need to trigger sending again somehow. - s.pacingDeadline = deadlineSendImmediately + // There's (likely) more data to send. Loop around again. + s.scheduleSending() return nil } - if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { - s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset}) - } - if offset := s.connFlowController.GetWindowUpdate(); offset > 0 { + if offset := s.connFlowController.GetWindowUpdate(now); offset > 0 { s.framer.QueueControlFrame(&wire.MaxDataFrame{MaximumData: offset}) } if cf := s.cryptoStreamManager.GetPostHandshakeData(protocol.MaxPostHandshakeCryptoFrameSize); cf != nil { @@ -1887,7 +1906,7 @@ func (s *connection) sendPackets(now time.Time) error { } if !s.handshakeConfirmed { - packet, err := s.packer.PackCoalescedPacket(false, s.maxPacketSize(), s.version) + packet, err := s.packer.PackCoalescedPacket(false, s.maxPacketSize(), now, s.version) if err != nil || packet == nil { return err } @@ -1999,6 +2018,7 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { return nil } + ecn = nextECN buf = getLargePacketBuffer() } } @@ -2014,7 +2034,7 @@ func (s *connection) resetPacingDeadline() { func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { if !s.handshakeConfirmed { ecn := s.sentPacketHandler.ECNMode(false) - packet, err := s.packer.PackCoalescedPacket(true, s.maxPacketSize(), s.version) + packet, err := s.packer.PackCoalescedPacket(true, s.maxPacketSize(), now, s.version) if err != nil { return err } @@ -2025,7 +2045,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { } ecn := s.sentPacketHandler.ECNMode(true) - p, buf, err := s.packer.PackAckOnlyPacket(s.maxPacketSize(), s.version) + p, buf, err := s.packer.PackAckOnlyPacket(s.maxPacketSize(), now, s.version) if err != nil { if err == errNothingToPack { return nil @@ -2038,7 +2058,19 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { return nil } -func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time.Time) error { +func (s *connection) sendProbePacket(sendMode ackhandler.SendMode, now time.Time) error { + var encLevel protocol.EncryptionLevel + //nolint:exhaustive // We only need to handle the PTO send modes here. + switch sendMode { + case ackhandler.SendPTOInitial: + encLevel = protocol.EncryptionInitial + case ackhandler.SendPTOHandshake: + encLevel = protocol.EncryptionHandshake + case ackhandler.SendPTOAppData: + encLevel = protocol.Encryption1RTT + default: + return fmt.Errorf("connection BUG: unexpected send mode: %d", sendMode) + } // Queue probe packets until we actually send out a packet, // or until there are no more packets to queue. var packet *coalescedPacket @@ -2047,7 +2079,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time break } var err error - packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), s.version) + packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), now, s.version) if err != nil { return err } @@ -2058,7 +2090,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time if packet == nil { s.retransmissionQueue.AddPing(encLevel) var err error - packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), s.version) + packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), now, s.version) if err != nil { return err } @@ -2073,7 +2105,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time // If there was nothing to pack, the returned size is 0. func (s *connection) appendOneShortHeaderPacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) { startLen := buf.Len() - p, err := s.packer.AppendPacket(buf, maxSize, s.version) + p, err := s.packer.AppendPacket(buf, maxSize, now, s.version) if err != nil { return 0, err } @@ -2111,7 +2143,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn prot !s.droppedInitialKeys { // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. // See Section 4.9.1 of RFC 9001. - if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { + if err := s.dropEncryptionLevel(protocol.EncryptionInitial, now); err != nil { return err } } @@ -2251,6 +2283,8 @@ func (s *connection) queueControlFrame(f wire.Frame) { s.scheduleSending() } +func (s *connection) onHasConnectionData() { s.scheduleSending() } + func (s *connection) onHasStreamData(id protocol.StreamID, str sendStreamI) { s.framer.AddActiveStream(id, str) s.scheduleSending() @@ -2300,17 +2334,8 @@ func (s *connection) ReceiveDatagram(ctx context.Context) ([]byte, error) { return s.datagramQueue.Receive(ctx) } -func (s *connection) LocalAddr() net.Addr { - return s.conn.LocalAddr() -} - -func (s *connection) RemoteAddr() net.Addr { - return s.conn.RemoteAddr() -} - -func (s *connection) GetVersion() protocol.Version { - return s.version -} +func (s *connection) LocalAddr() net.Addr { return s.conn.LocalAddr() } +func (s *connection) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } func (s *connection) NextConnection(ctx context.Context) (Connection, error) { // The handshake might fail after the server rejected 0-RTT. diff --git a/vendor/github.com/quic-go/quic-go/connection_logging.go b/vendor/github.com/quic-go/quic-go/connection_logging.go index f75b39f6..a314a6cd 100644 --- a/vendor/github.com/quic-go/quic-go/connection_logging.go +++ b/vendor/github.com/quic-go/quic-go/connection_logging.go @@ -125,12 +125,7 @@ func (s *connection) logShortHeaderPacket( ack = toLoggingAckFrame(ackFrame) } s.tracer.SentShortHeaderPacket( - &logging.ShortHeader{ - DestConnectionID: destConnID, - PacketNumber: pn, - PacketNumberLen: pnLen, - KeyPhase: kp, - }, + &logging.ShortHeader{DestConnectionID: destConnID, PacketNumber: pn, PacketNumberLen: pnLen, KeyPhase: kp}, size, ecn, ack, diff --git a/vendor/github.com/quic-go/quic-go/errors.go b/vendor/github.com/quic-go/quic-go/errors.go index 3fe1e0a9..4a69a7f1 100644 --- a/vendor/github.com/quic-go/quic-go/errors.go +++ b/vendor/github.com/quic-go/quic-go/errors.go @@ -50,8 +50,8 @@ type StreamError struct { } func (e *StreamError) Is(target error) bool { - _, ok := target.(*StreamError) - return ok + t, ok := target.(*StreamError) + return ok && e.StreamID == t.StreamID && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote } func (e *StreamError) Error() string { @@ -68,8 +68,8 @@ type DatagramTooLargeError struct { } func (e *DatagramTooLargeError) Is(target error) bool { - _, ok := target.(*DatagramTooLargeError) - return ok + t, ok := target.(*DatagramTooLargeError) + return ok && e.MaxDatagramPayloadSize == t.MaxDatagramPayloadSize } func (e *DatagramTooLargeError) Error() string { return "DATAGRAM frame too large" } diff --git a/vendor/github.com/quic-go/quic-go/framer.go b/vendor/github.com/quic-go/quic-go/framer.go index e162f6b8..fee31631 100644 --- a/vendor/github.com/quic-go/quic-go/framer.go +++ b/vendor/github.com/quic-go/quic-go/framer.go @@ -3,8 +3,10 @@ package quic import ( "slices" "sync" + "time" "github.com/quic-go/quic-go/internal/ackhandler" + "github.com/quic-go/quic-go/internal/flowcontrol" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils/ringbuffer" "github.com/quic-go/quic-go/internal/wire" @@ -21,7 +23,7 @@ const ( const maxStreamControlFrameSize = 25 type streamControlFrameGetter interface { - getControlFrame() (_ ackhandler.Frame, ok, hasMore bool) + getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool) } type framer struct { @@ -34,13 +36,15 @@ type framer struct { controlFrameMutex sync.Mutex controlFrames []wire.Frame pathResponses []*wire.PathResponseFrame + connFlowController flowcontrol.ConnectionFlowController queuedTooManyControlFrames bool } -func newFramer() *framer { +func newFramer(connFlowController flowcontrol.ConnectionFlowController) *framer { return &framer{ activeStreams: make(map[protocol.StreamID]sendStreamI), streamsWithControlFrames: make(map[protocol.StreamID]streamControlFrameGetter), + connFlowController: connFlowController, } } @@ -78,10 +82,80 @@ func (f *framer) QueueControlFrame(frame wire.Frame) { f.controlFrames = append(f.controlFrames, frame) } -func (f *framer) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) { +func (f *framer) Append( + frames []ackhandler.Frame, + streamFrames []ackhandler.StreamFrame, + maxLen protocol.ByteCount, + now time.Time, + v protocol.Version, +) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount) { f.controlFrameMutex.Lock() - defer f.controlFrameMutex.Unlock() + frames, controlFrameLen := f.appendControlFrames(frames, maxLen, now, v) + maxLen -= controlFrameLen + + var lastFrame ackhandler.StreamFrame + var streamFrameLen protocol.ByteCount + f.mutex.Lock() + // pop STREAM frames, until less than 128 bytes are left in the packet + numActiveStreams := f.streamQueue.Len() + for i := 0; i < numActiveStreams; i++ { + if protocol.MinStreamFrameSize > maxLen { + break + } + sf, blocked := f.getNextStreamFrame(maxLen, v) + if sf.Frame != nil { + streamFrames = append(streamFrames, sf) + maxLen -= sf.Frame.Length(v) + lastFrame = sf + streamFrameLen += sf.Frame.Length(v) + } + // If the stream just became blocked on stream flow control, attempt to pack the + // STREAM_DATA_BLOCKED into the same packet. + if blocked != nil { + l := blocked.Length(v) + // In case it doesn't fit, queue it for the next packet. + if maxLen < l { + f.controlFrames = append(f.controlFrames, blocked) + break + } + frames = append(frames, ackhandler.Frame{Frame: blocked}) + maxLen -= l + controlFrameLen += l + } + } + + // The only way to become blocked on connection-level flow control is by sending STREAM frames. + if isBlocked, offset := f.connFlowController.IsNewlyBlocked(); isBlocked { + blocked := &wire.DataBlockedFrame{MaximumData: offset} + l := blocked.Length(v) + // In case it doesn't fit, queue it for the next packet. + if maxLen >= l { + frames = append(frames, ackhandler.Frame{Frame: blocked}) + controlFrameLen += l + } else { + f.controlFrames = append(f.controlFrames, blocked) + } + } + + f.mutex.Unlock() + f.controlFrameMutex.Unlock() + + if lastFrame.Frame != nil { + // account for the smaller size of the last STREAM frame + streamFrameLen -= lastFrame.Frame.Length(v) + lastFrame.Frame.DataLenPresent = false + streamFrameLen += lastFrame.Frame.Length(v) + } + + return frames, streamFrames, controlFrameLen + streamFrameLen +} +func (f *framer) appendControlFrames( + frames []ackhandler.Frame, + maxLen protocol.ByteCount, + now time.Time, + v protocol.Version, +) ([]ackhandler.Frame, protocol.ByteCount) { var length protocol.ByteCount // add a PATH_RESPONSE first, but only pack a single PATH_RESPONSE per packet if len(f.pathResponses) > 0 { @@ -101,7 +175,7 @@ func (f *framer) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol. if remainingLen <= maxStreamControlFrameSize { break } - fr, ok, hasMore := str.getControlFrame() + fr, ok, hasMore := str.getControlFrame(now) if !hasMore { delete(f.streamsWithControlFrames, id) } @@ -163,56 +237,33 @@ func (f *framer) RemoveActiveStream(id protocol.StreamID) { delete(f.activeStreams, id) // We don't delete the stream from the streamQueue, // since we'd have to iterate over the ringbuffer. - // Instead, we check if the stream is still in activeStreams in AppendStreamFrames. + // Instead, we check if the stream is still in activeStreams when appending STREAM frames. f.mutex.Unlock() } -func (f *framer) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) { - startLen := len(frames) - var length protocol.ByteCount - f.mutex.Lock() - // pop STREAM frames, until less than 128 bytes are left in the packet - numActiveStreams := f.streamQueue.Len() - for i := 0; i < numActiveStreams; i++ { - if protocol.MinStreamFrameSize+length > maxLen { - break - } - id := f.streamQueue.PopFront() - // This should never return an error. Better check it anyway. - // The stream will only be in the streamQueue, if it enqueued itself there. - str, ok := f.activeStreams[id] - // The stream might have been removed after being enqueued. - if !ok { - continue - } - remainingLen := maxLen - length - // For the last STREAM frame, we'll remove the DataLen field later. - // Therefore, we can pretend to have more bytes available when popping - // the STREAM frame (which will always have the DataLen set). - remainingLen += protocol.ByteCount(quicvarint.Len(uint64(remainingLen))) - frame, ok, hasMoreData := str.popStreamFrame(remainingLen, v) - if hasMoreData { // put the stream back in the queue (at the end) - f.streamQueue.PushBack(id) - } else { // no more data to send. Stream is not active - delete(f.activeStreams, id) - } - // The frame can be "nil" - // * if the stream was canceled after it said it had data - // * the remaining size doesn't allow us to add another STREAM frame - if !ok { - continue - } - frames = append(frames, frame) - length += frame.Frame.Length(v) +func (f *framer) getNextStreamFrame(maxLen protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame) { + id := f.streamQueue.PopFront() + // This should never return an error. Better check it anyway. + // The stream will only be in the streamQueue, if it enqueued itself there. + str, ok := f.activeStreams[id] + // The stream might have been removed after being enqueued. + if !ok { + return ackhandler.StreamFrame{}, nil } - f.mutex.Unlock() - if len(frames) > startLen { - l := frames[len(frames)-1].Frame.Length(v) - // account for the smaller size of the last STREAM frame - frames[len(frames)-1].Frame.DataLenPresent = false - length += frames[len(frames)-1].Frame.Length(v) - l + // For the last STREAM frame, we'll remove the DataLen field later. + // Therefore, we can pretend to have more bytes available when popping + // the STREAM frame (which will always have the DataLen set). + maxLen += protocol.ByteCount(quicvarint.Len(uint64(maxLen))) + frame, blocked, hasMoreData := str.popStreamFrame(maxLen, v) + if hasMoreData { // put the stream back in the queue (at the end) + f.streamQueue.PushBack(id) + } else { // no more data to send. Stream is not active + delete(f.activeStreams, id) } - return frames, length + // Note that the frame.Frame can be nil: + // * if the stream was canceled after it said it had data + // * the remaining size doesn't allow us to add another STREAM frame + return frame, blocked } func (f *framer) Handle0RTTRejection() { diff --git a/vendor/github.com/quic-go/quic-go/interface.go b/vendor/github.com/quic-go/quic-go/interface.go index 2071b596..7f3c40c2 100644 --- a/vendor/github.com/quic-go/quic-go/interface.go +++ b/vendor/github.com/quic-go/quic-go/interface.go @@ -98,7 +98,6 @@ type ReceiveStream interface { // SetReadDeadline sets the deadline for future Read calls and // any currently-blocked Read call. // A zero value for t means Read will not time out. - SetReadDeadline(t time.Time) error } @@ -357,10 +356,10 @@ type ClientHelloInfo struct { type ConnectionState struct { // TLS contains information about the TLS connection state, incl. the tls.ConnectionState. TLS tls.ConnectionState - // SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated. - // This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams). - // If datagram support was negotiated, datagrams can be sent and received using the - // SendDatagram and ReceiveDatagram methods on the Connection. + // SupportsDatagrams indicates whether the peer advertised support for QUIC datagrams (RFC 9221). + // When true, datagrams can be sent using the Connection's SendDatagram method. + // This is a unilateral declaration by the peer - receiving datagrams is only possible if + // datagram support was enabled locally via Config.EnableDatagrams. SupportsDatagrams bool // Used0RTT says if 0-RTT resumption was used. Used0RTT bool diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go index ba8cbbda..acf95426 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go @@ -14,10 +14,9 @@ type SentPacketHandler interface { // ReceivedAck processes an ACK frame. // It does not store a copy of the frame. ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error) - ReceivedBytes(protocol.ByteCount) - DropPackets(protocol.EncryptionLevel) - ResetForRetry(rcvTime time.Time) error - SetHandshakeConfirmed() + ReceivedBytes(_ protocol.ByteCount, rcvTime time.Time) + DropPackets(_ protocol.EncryptionLevel, rcvTime time.Time) + ResetForRetry(rcvTime time.Time) // The SendMode determines if and what kind of packets can be sent. SendMode(now time.Time) SendMode @@ -34,12 +33,12 @@ type SentPacketHandler interface { PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber GetLossDetectionTimeout() time.Time - OnLossDetectionTimeout() error + OnLossDetectionTimeout(now time.Time) error } type sentPacketTracker interface { GetLowestPacketNotConfirmedAcked() protocol.PacketNumber - ReceivedPacket(protocol.EncryptionLevel) + ReceivedPacket(_ protocol.EncryptionLevel, rcvTime time.Time) } // ReceivedPacketHandler handles ACKs needed to send for incoming packets @@ -49,5 +48,5 @@ type ReceivedPacketHandler interface { DropPackets(protocol.EncryptionLevel) GetAlarmTimeout() time.Time - GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame + GetAckFrame(_ protocol.EncryptionLevel, now time.Time, onlyIfQueued bool) *wire.AckFrame } diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go index 1175c790..eda0826c 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go @@ -38,7 +38,7 @@ func (h *receivedPacketHandler) ReceivedPacket( rcvTime time.Time, ackEliciting bool, ) error { - h.sentPackets.ReceivedPacket(encLevel) + h.sentPackets.ReceivedPacket(encLevel, rcvTime) switch encLevel { case protocol.EncryptionInitial: return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) @@ -87,7 +87,7 @@ func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.appDataPackets.GetAlarmTimeout() } -func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame { +func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, now time.Time, onlyIfQueued bool) *wire.AckFrame { //nolint:exhaustive // 0-RTT packets can't contain ACK frames. switch encLevel { case protocol.EncryptionInitial: @@ -101,7 +101,7 @@ func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, o } return nil case protocol.Encryption1RTT: - return h.appDataPackets.GetAckFrame(onlyIfQueued) + return h.appDataPackets.GetAckFrame(now, onlyIfQueued) default: // 0-RTT packets can't contain ACK frames return nil diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go index 08af6f1e..d1d26f4a 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go @@ -196,8 +196,7 @@ func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, return false } -func (h *appDataReceivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { - now := time.Now() +func (h *appDataReceivedPacketTracker) GetAckFrame(now time.Time, onlyIfQueued bool) *wire.AckFrame { if onlyIfQueued && !h.ackQueued { if h.ackAlarm.IsZero() || h.ackAlarm.After(now) { return nil diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go index b84f0dcb..5276fe19 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go @@ -53,6 +53,12 @@ func newPacketNumberSpace(initialPN protocol.PacketNumber, isAppData bool) *pack } } +type alarmTimer struct { + Time time.Time + TimerType logging.TimerType + EncryptionLevel protocol.EncryptionLevel +} + type sentPacketHandler struct { initialPackets *packetNumberSpace handshakePackets *packetNumberSpace @@ -90,7 +96,7 @@ type sentPacketHandler struct { numProbesToSend int // The alarm timeout - alarm time.Time + alarm alarmTimer enableECN bool ecnTracker ecnHandler @@ -155,7 +161,7 @@ func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) { } } -func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { +func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now time.Time) { // The server won't await address validation after the handshake is confirmed. // This applies even if we didn't receive an ACK for a Handshake packet. if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake { @@ -179,6 +185,9 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { case protocol.EncryptionInitial: h.initialPackets = nil case protocol.EncryptionHandshake: + // Dropping the handshake packet number space means that the handshake is confirmed, + // see section 4.9.2 of RFC 9001. + h.handshakeConfirmed = true h.handshakePackets = nil case protocol.Encryption0RTT: // This function is only called when 0-RTT is rejected, @@ -202,21 +211,21 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { h.ptoCount = 0 h.numProbesToSend = 0 h.ptoMode = SendNone - h.setLossDetectionTimer() + h.setLossDetectionTimer(now) } -func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) { +func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount, t time.Time) { wasAmplificationLimit := h.isAmplificationLimited() h.bytesReceived += n if wasAmplificationLimit && !h.isAmplificationLimited() { - h.setLossDetectionTimer() + h.setLossDetectionTimer(t) } } -func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) { +func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel, t time.Time) { if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated { h.peerAddressValidated = true - h.setLossDetectionTimer() + h.setLossDetectionTimer(t) } } @@ -269,7 +278,7 @@ func (h *sentPacketHandler) SentPacket( if !isAckEliciting { pnSpace.history.SentNonAckElicitingPacket(pn) if !h.peerCompletedAddressValidation { - h.setLossDetectionTimer() + h.setLossDetectionTimer(t) } return } @@ -289,7 +298,7 @@ func (h *sentPacketHandler) SentPacket( if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } - h.setLossDetectionTimer() + h.setLossDetectionTimer(t) } func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace { @@ -322,7 +331,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.peerCompletedAddressValidation = true h.logger.Debugf("Peer doesn't await address validation any longer.") // Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets. - h.setLossDetectionTimer() + h.setLossDetectionTimer(rcvTime) } priorInFlight := h.bytesInFlight @@ -338,7 +347,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En if encLevel == protocol.Encryption1RTT { ackDelay = min(ack.DelayTime, h.rttStats.MaxAckDelay()) } - h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime) + h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay) if h.logger.Debug() { h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) } @@ -387,7 +396,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } - h.setLossDetectionTimer() + h.setLossDetectionTimer(rcvTime) return acked1RTTPacket, nil } @@ -498,14 +507,14 @@ func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration } // same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime -func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) { +func (h *sentPacketHandler) getPTOTimeAndSpace(now time.Time) (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) { // We only send application data probe packets once the handshake is confirmed, // because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets. if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() { if h.peerCompletedAddressValidation { return } - t := time.Now().Add(h.getScaledPTO(false)) + t := now.Add(h.getScaledPTO(false)) if h.initialPackets != nil { return t, protocol.EncryptionInitial, true } @@ -545,61 +554,53 @@ func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { return false } -func (h *sentPacketHandler) hasOutstandingPackets() bool { - return h.appDataPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets() -} - -func (h *sentPacketHandler) setLossDetectionTimer() { +func (h *sentPacketHandler) setLossDetectionTimer(now time.Time) { oldAlarm := h.alarm // only needed in case tracing is enabled - lossTime, encLevel := h.getLossTimeAndSpace() - if !lossTime.IsZero() { - // Early retransmit timer or time loss detection. - h.alarm = lossTime - if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm { - h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm) + newAlarm := h.lossDetectionTime(now) + h.alarm = newAlarm + + if newAlarm.Time.IsZero() && !oldAlarm.Time.IsZero() { + h.logger.Debugf("Canceling loss detection timer.") + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { + h.tracer.LossTimerCanceled() } - return } - // Cancel the alarm if amplification limited. + if h.tracer != nil && h.tracer.SetLossTimer != nil && newAlarm != oldAlarm { + h.tracer.SetLossTimer(newAlarm.TimerType, newAlarm.EncryptionLevel, newAlarm.Time) + } +} + +func (h *sentPacketHandler) lossDetectionTime(now time.Time) alarmTimer { + // cancel the alarm if no packets are outstanding + if h.peerCompletedAddressValidation && + !h.hasOutstandingCryptoPackets() && !h.appDataPackets.history.HasOutstandingPackets() { + return alarmTimer{} + } + + // cancel the alarm if amplification limited if h.isAmplificationLimited() { - h.alarm = time.Time{} - if !oldAlarm.IsZero() { - h.logger.Debugf("Canceling loss detection timer. Amplification limited.") - if h.tracer != nil && h.tracer.LossTimerCanceled != nil { - h.tracer.LossTimerCanceled() - } - } - return + return alarmTimer{} } - // Cancel the alarm if no packets are outstanding - if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation { - h.alarm = time.Time{} - if !oldAlarm.IsZero() { - h.logger.Debugf("Canceling loss detection timer. No packets in flight.") - if h.tracer != nil && h.tracer.LossTimerCanceled != nil { - h.tracer.LossTimerCanceled() - } + // early retransmit timer or time loss detection + lossTime, encLevel := h.getLossTimeAndSpace() + if !lossTime.IsZero() { + return alarmTimer{ + Time: lossTime, + TimerType: logging.TimerTypeACK, + EncryptionLevel: encLevel, } - return } - // PTO alarm - ptoTime, encLevel, ok := h.getPTOTimeAndSpace() + ptoTime, encLevel, ok := h.getPTOTimeAndSpace(now) if !ok { - if !oldAlarm.IsZero() { - h.alarm = time.Time{} - h.logger.Debugf("Canceling loss detection timer. No PTO needed..") - if h.tracer != nil && h.tracer.LossTimerCanceled != nil { - h.tracer.LossTimerCanceled() - } - } - return + return alarmTimer{} } - h.alarm = ptoTime - if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm { - h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm) + return alarmTimer{ + Time: ptoTime, + TimerType: logging.TimerTypePTO, + EncryptionLevel: encLevel, } } @@ -623,7 +624,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E } var packetLost bool - if p.SendTime.Before(lostSendTime) { + if !p.SendTime.After(lostSendTime) { packetLost = true if !p.skippedPacket { if h.logger.Debug() { @@ -669,8 +670,8 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E }) } -func (h *sentPacketHandler) OnLossDetectionTimeout() error { - defer h.setLossDetectionTimer() +func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error { + defer h.setLossDetectionTimer(now) earliestLossTime, encLevel := h.getLossTimeAndSpace() if !earliestLossTime.IsZero() { if h.logger.Debug() { @@ -680,13 +681,13 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) } // Early retransmit or time loss detection - return h.detectLostPackets(time.Now(), encLevel) + return h.detectLostPackets(now, encLevel) } // PTO - // When all outstanding are acknowledged, the alarm is canceled in - // setLossDetectionTimer. This doesn't reset the timer in the session though. - // When OnAlarm is called, we therefore need to make sure that there are + // When all outstanding are acknowledged, the alarm is canceled in setLossDetectionTimer. + // However, there's no way to reset the timer in the connection. + // When OnLossDetectionTimeout is called, we therefore need to make sure that there are // actually packets outstanding. if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation { h.ptoCount++ @@ -701,7 +702,7 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { return nil } - _, encLevel, ok := h.getPTOTimeAndSpace() + _, encLevel, ok := h.getPTOTimeAndSpace(now) if !ok { return nil } @@ -739,7 +740,7 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { } func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { - return h.alarm + return h.alarm.Time } func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN { @@ -864,7 +865,7 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) { p.Frames = nil } -func (h *sentPacketHandler) ResetForRetry(now time.Time) error { +func (h *sentPacketHandler) ResetForRetry(now time.Time) { h.bytesInFlight = 0 var firstPacketSendTime time.Time h.initialPackets.history.Iterate(func(p *packet) (bool, error) { @@ -890,7 +891,7 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) error { // Otherwise, we don't know which Initial the Retry was sent in response to. if h.ptoCount == 0 { // Don't set the RTT to a value lower than 5ms here. - h.rttStats.UpdateRTT(max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) + h.rttStats.UpdateRTT(max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0) if h.logger.Debug() { h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) } @@ -901,28 +902,14 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) error { h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Peek(), false) h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Peek(), true) oldAlarm := h.alarm - h.alarm = time.Time{} + h.alarm = alarmTimer{} if h.tracer != nil { if h.tracer.UpdatedPTOCount != nil { h.tracer.UpdatedPTOCount(0) } - if !oldAlarm.IsZero() && h.tracer.LossTimerCanceled != nil { + if !oldAlarm.Time.IsZero() && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } h.ptoCount = 0 - return nil -} - -func (h *sentPacketHandler) SetHandshakeConfirmed() { - if h.initialPackets != nil { - panic("didn't drop initial correctly") - } - if h.handshakePackets != nil { - panic("didn't drop handshake correctly") - } - h.handshakeConfirmed = true - // We don't send PTOs for application data packets before the handshake completes. - // Make sure the timer is armed now, if necessary. - h.setLossDetectionTimer() } diff --git a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go index 3d88d577..950e5f72 100644 --- a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go +++ b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go @@ -36,7 +36,7 @@ type baseFlowController struct { // For every offset, it only returns true once. // If it is blocked, the offset is returned. func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { - if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt { + if c.SendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt { return false, 0 } c.lastBlockedAt = c.sendWindow @@ -56,7 +56,7 @@ func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) (update return false } -func (c *baseFlowController) sendWindowSize() protocol.ByteCount { +func (c *baseFlowController) SendWindowSize() protocol.ByteCount { // this only happens during connection establishment, when data is sent before we receive the peer's transport parameters if c.bytesSent > c.sendWindow { return 0 @@ -66,11 +66,6 @@ func (c *baseFlowController) sendWindowSize() protocol.ByteCount { // needs to be called with locked mutex func (c *baseFlowController) addBytesRead(n protocol.ByteCount) { - // pretend we sent a WindowUpdate when reading the first byte - // this way auto-tuning of the window size already works for the first WindowUpdate - if c.bytesRead == 0 { - c.startNewAutoTuningEpoch(time.Now()) - } c.bytesRead += n } @@ -82,19 +77,19 @@ func (c *baseFlowController) hasWindowUpdate() bool { // getWindowUpdate updates the receive window, if necessary // it returns the new offset -func (c *baseFlowController) getWindowUpdate() protocol.ByteCount { +func (c *baseFlowController) getWindowUpdate(now time.Time) protocol.ByteCount { if !c.hasWindowUpdate() { return 0 } - c.maybeAdjustWindowSize() + c.maybeAdjustWindowSize(now) c.receiveWindow = c.bytesRead + c.receiveWindowSize return c.receiveWindow } // maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often. // For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing. -func (c *baseFlowController) maybeAdjustWindowSize() { +func (c *baseFlowController) maybeAdjustWindowSize(now time.Time) { bytesReadInEpoch := c.bytesRead - c.epochStartOffset // don't do anything if less than half the window has been consumed if bytesReadInEpoch <= c.receiveWindowSize/2 { @@ -106,7 +101,6 @@ func (c *baseFlowController) maybeAdjustWindowSize() { } fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize) - now := time.Now() if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) { // window is consumed too fast, try to increase the window size newSize := min(2*c.receiveWindowSize, c.maxReceiveWindowSize) diff --git a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go index 2efcad74..bbeb7889 100644 --- a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go +++ b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go @@ -24,7 +24,7 @@ func NewConnectionFlowController( allowWindowIncrease func(size protocol.ByteCount) bool, rttStats *utils.RTTStats, logger utils.Logger, -) ConnectionFlowController { +) *connectionFlowController { return &connectionFlowController{ baseFlowController: baseFlowController{ rttStats: rttStats, @@ -37,16 +37,17 @@ func NewConnectionFlowController( } } -func (c *connectionFlowController) SendWindowSize() protocol.ByteCount { - return c.baseFlowController.sendWindowSize() -} - // IncrementHighestReceived adds an increment to the highestReceived value -func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error { +func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount, now time.Time) error { c.mutex.Lock() defer c.mutex.Unlock() + // If this is the first frame received on this connection, start flow-control auto-tuning. + if c.highestReceived == 0 { + c.startNewAutoTuningEpoch(now) + } c.highestReceived += increment + if c.checkFlowControlViolation() { return &qerr.TransportError{ ErrorCode: qerr.FlowControlError, @@ -56,40 +57,47 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B return nil } -func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) { +func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) (hasWindowUpdate bool) { c.mutex.Lock() + defer c.mutex.Unlock() + c.baseFlowController.addBytesRead(n) - c.mutex.Unlock() + return c.baseFlowController.hasWindowUpdate() } -func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { +func (c *connectionFlowController) GetWindowUpdate(now time.Time) protocol.ByteCount { c.mutex.Lock() + defer c.mutex.Unlock() + oldWindowSize := c.receiveWindowSize - offset := c.baseFlowController.getWindowUpdate() + offset := c.baseFlowController.getWindowUpdate(now) if c.logger.Debug() && oldWindowSize < c.receiveWindowSize { c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) } - c.mutex.Unlock() return offset } // EnsureMinimumWindowSize sets a minimum window size // it should make sure that the connection-level window is increased when a stream-level window grows -func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) { +func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount, now time.Time) { c.mutex.Lock() - if inc > c.receiveWindowSize { - c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10)) - newSize := min(inc, c.maxReceiveWindowSize) - if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) { - c.receiveWindowSize = newSize + defer c.mutex.Unlock() + + if inc <= c.receiveWindowSize { + return + } + newSize := min(inc, c.maxReceiveWindowSize) + if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) { + c.receiveWindowSize = newSize + if c.logger.Debug() { + c.logger.Debugf("Increasing receive flow control window for the connection to %d, in response to stream flow control window increase", newSize) } - c.startNewAutoTuningEpoch(time.Now()) } - c.mutex.Unlock() + c.startNewAutoTuningEpoch(now) } // Reset rests the flow controller. This happens when 0-RTT is rejected. -// All stream data is invalidated, it's if we had never opened a stream and never sent any data. +// All stream data is invalidated, it's as if we had never opened a stream and never sent any data. // At that point, we only have sent stream data, but we didn't have the keys to open 1-RTT keys yet. func (c *connectionFlowController) Reset() error { c.mutex.Lock() @@ -100,5 +108,6 @@ func (c *connectionFlowController) Reset() error { } c.bytesSent = 0 c.lastBlockedAt = 0 + c.sendWindow = 0 return nil } diff --git a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go index 57d12a95..23cf30c5 100644 --- a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go +++ b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go @@ -1,6 +1,10 @@ package flowcontrol -import "github.com/quic-go/quic-go/internal/protocol" +import ( + "time" + + "github.com/quic-go/quic-go/internal/protocol" +) type flowController interface { // for sending @@ -8,17 +12,17 @@ type flowController interface { UpdateSendWindow(protocol.ByteCount) (updated bool) AddBytesSent(protocol.ByteCount) // for receiving - GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary + GetWindowUpdate(time.Time) protocol.ByteCount // returns 0 if no update is necessary } // A StreamFlowController is a flow controller for a QUIC stream. type StreamFlowController interface { flowController - AddBytesRead(protocol.ByteCount) (shouldQueueWindowUpdate bool) + AddBytesRead(protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool) // UpdateHighestReceived is called when a new highest offset is received // final has to be to true if this is the final offset of the stream, // as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame - UpdateHighestReceived(offset protocol.ByteCount, final bool) error + UpdateHighestReceived(offset protocol.ByteCount, final bool, now time.Time) error // Abandon is called when reading from the stream is aborted early, // and there won't be any further calls to AddBytesRead. Abandon() @@ -28,7 +32,7 @@ type StreamFlowController interface { // The ConnectionFlowController is the flow controller for the connection. type ConnectionFlowController interface { flowController - AddBytesRead(protocol.ByteCount) + AddBytesRead(protocol.ByteCount) (hasWindowUpdate bool) Reset() error IsNewlyBlocked() (bool, protocol.ByteCount) } @@ -37,7 +41,7 @@ type connectionFlowControllerI interface { ConnectionFlowController // The following two methods are not supposed to be called from outside this packet, but are needed internally // for sending - EnsureMinimumWindowSize(protocol.ByteCount) + EnsureMinimumWindowSize(protocol.ByteCount, time.Time) // for receiving - IncrementHighestReceived(protocol.ByteCount) error + IncrementHighestReceived(protocol.ByteCount, time.Time) error } diff --git a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go index 2d58351c..ba005122 100644 --- a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go +++ b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go @@ -2,6 +2,7 @@ package flowcontrol import ( "fmt" + "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" @@ -45,7 +46,7 @@ func NewStreamFlowController( } // UpdateHighestReceived updates the highestReceived value, if the offset is higher. -func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool) error { +func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool, now time.Time) error { // If the final offset for this stream is already known, check for consistency. if c.receivedFinalOffset { // If we receive another final offset, check that it's the same. @@ -70,9 +71,8 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, if offset == c.highestReceived { return nil } - // A higher offset was received before. - // This can happen due to reordering. - if offset <= c.highestReceived { + // A higher offset was received before. This can happen due to reordering. + if offset < c.highestReceived { if final { return &qerr.TransportError{ ErrorCode: qerr.FinalSizeError, @@ -82,23 +82,28 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, return nil } + // If this is the first frame received for this stream, start flow-control auto-tuning. + if c.highestReceived == 0 { + c.startNewAutoTuningEpoch(now) + } increment := offset - c.highestReceived c.highestReceived = offset + if c.checkFlowControlViolation() { return &qerr.TransportError{ ErrorCode: qerr.FlowControlError, ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow), } } - return c.connection.IncrementHighestReceived(increment) + return c.connection.IncrementHighestReceived(increment, now) } -func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) (shouldQueueWindowUpdate bool) { +func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool) { c.mutex.Lock() c.baseFlowController.addBytesRead(n) - shouldQueueWindowUpdate = c.shouldQueueWindowUpdate() + hasStreamWindowUpdate = c.shouldQueueWindowUpdate() c.mutex.Unlock() - c.connection.AddBytesRead(n) + hasConnWindowUpdate = c.connection.AddBytesRead(n) return } @@ -118,7 +123,7 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { } func (c *streamFlowController) SendWindowSize() protocol.ByteCount { - return min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize()) + return min(c.baseFlowController.SendWindowSize(), c.connection.SendWindowSize()) } func (c *streamFlowController) IsNewlyBlocked() bool { @@ -130,20 +135,20 @@ func (c *streamFlowController) shouldQueueWindowUpdate() bool { return !c.receivedFinalOffset && c.hasWindowUpdate() } -func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { +func (c *streamFlowController) GetWindowUpdate(now time.Time) protocol.ByteCount { // If we already received the final offset for this stream, the peer won't need any additional flow control credit. if c.receivedFinalOffset { return 0 } - // Don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler c.mutex.Lock() + defer c.mutex.Unlock() + oldWindowSize := c.receiveWindowSize - offset := c.baseFlowController.getWindowUpdate() + offset := c.baseFlowController.getWindowUpdate(now) if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size - c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10)) - c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)) + c.logger.Debugf("Increasing receive flow control window for stream %d to %d", c.streamID, c.receiveWindowSize) + c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize)*protocol.ConnectionFlowControlMultiplier), now) } - c.mutex.Unlock() return offset } diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/retry.go b/vendor/github.com/quic-go/quic-go/internal/handshake/retry.go index 30643cdf..27a09e22 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/retry.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/retry.go @@ -10,16 +10,13 @@ import ( "github.com/quic-go/quic-go/internal/protocol" ) +// Instead of using an init function, the AEADs are created lazily. +// For more details see https://github.com/quic-go/quic-go/issues/4894. var ( retryAEADv1 cipher.AEAD // used for QUIC v1 (RFC 9000) retryAEADv2 cipher.AEAD // used for QUIC v2 (RFC 9369) ) -func init() { - retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e}) - retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92}) -} - func initAEAD(key [16]byte) cipher.AEAD { aes, err := aes.NewCipher(key[:]) if err != nil { @@ -52,8 +49,14 @@ func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, ve var tag [16]byte var sealed []byte if version == protocol.Version2 { + if retryAEADv2 == nil { + retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92}) + } sealed = retryAEADv2.Seal(tag[:0], retryNonceV2[:], nil, retryBuf.Bytes()) } else { + if retryAEADv1 == nil { + retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e}) + } sealed = retryAEADv1.Seal(tag[:0], retryNonceV1[:], nil, retryBuf.Bytes()) } if len(sealed) != 16 { diff --git a/vendor/github.com/quic-go/quic-go/internal/protocol/params.go b/vendor/github.com/quic-go/quic-go/internal/protocol/params.go index 7c4d8d4d..f0aa3ad9 100644 --- a/vendor/github.com/quic-go/quic-go/internal/protocol/params.go +++ b/vendor/github.com/quic-go/quic-go/internal/protocol/params.go @@ -102,10 +102,6 @@ const DefaultIdleTimeout = 30 * time.Second // DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion. const DefaultHandshakeIdleTimeout = 5 * time.Second -// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive. -// It should be shorter than the time that NATs clear their mapping. -const MaxKeepAliveInterval = 20 * time.Second - // RetiredConnectionIDDeleteTimeout is the time we keep closed connections around in order to retransmit the CONNECTION_CLOSE. // after this time all information about the old connection will be deleted const RetiredConnectionIDDeleteTimeout = 5 * time.Second diff --git a/vendor/github.com/quic-go/quic-go/internal/qerr/errors.go b/vendor/github.com/quic-go/quic-go/internal/qerr/errors.go index 8f5936df..7fe1c293 100644 --- a/vendor/github.com/quic-go/quic-go/internal/qerr/errors.go +++ b/vendor/github.com/quic-go/quic-go/internal/qerr/errors.go @@ -48,21 +48,16 @@ func (e *TransportError) Error() string { return str + ": " + msg } -func (e *TransportError) Is(target error) bool { - return target == net.ErrClosed -} +func (e *TransportError) Unwrap() []error { return []error{net.ErrClosed, e.error} } -func (e *TransportError) Unwrap() error { - return e.error +func (e *TransportError) Is(target error) bool { + t, ok := target.(*TransportError) + return ok && e.ErrorCode == t.ErrorCode && e.FrameType == t.FrameType && e.Remote == t.Remote } // An ApplicationErrorCode is an application-defined error code. type ApplicationErrorCode uint64 -func (e *ApplicationError) Is(target error) bool { - return target == net.ErrClosed -} - // A StreamErrorCode is an error code used to cancel streams. type StreamErrorCode uint64 @@ -81,23 +76,30 @@ func (e *ApplicationError) Error() string { return fmt.Sprintf("Application error %#x (%s): %s", e.ErrorCode, getRole(e.Remote), e.ErrorMessage) } +func (e *ApplicationError) Unwrap() error { return net.ErrClosed } + +func (e *ApplicationError) Is(target error) bool { + t, ok := target.(*ApplicationError) + return ok && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote +} + type IdleTimeoutError struct{} var _ error = &IdleTimeoutError{} -func (e *IdleTimeoutError) Timeout() bool { return true } -func (e *IdleTimeoutError) Temporary() bool { return false } -func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" } -func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed } +func (e *IdleTimeoutError) Timeout() bool { return true } +func (e *IdleTimeoutError) Temporary() bool { return false } +func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" } +func (e *IdleTimeoutError) Unwrap() error { return net.ErrClosed } type HandshakeTimeoutError struct{} var _ error = &HandshakeTimeoutError{} -func (e *HandshakeTimeoutError) Timeout() bool { return true } -func (e *HandshakeTimeoutError) Temporary() bool { return false } -func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" } -func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed } +func (e *HandshakeTimeoutError) Timeout() bool { return true } +func (e *HandshakeTimeoutError) Temporary() bool { return false } +func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" } +func (e *HandshakeTimeoutError) Unwrap() error { return net.ErrClosed } // A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version. type VersionNegotiationError struct { @@ -109,25 +111,18 @@ func (e *VersionNegotiationError) Error() string { return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs) } -func (e *VersionNegotiationError) Is(target error) bool { - return target == net.ErrClosed -} +func (e *VersionNegotiationError) Unwrap() error { return net.ErrClosed } // A StatelessResetError occurs when we receive a stateless reset. -type StatelessResetError struct { - Token protocol.StatelessResetToken -} +type StatelessResetError struct{} var _ net.Error = &StatelessResetError{} func (e *StatelessResetError) Error() string { - return fmt.Sprintf("received a stateless reset with token %x", e.Token) -} - -func (e *StatelessResetError) Is(target error) bool { - return target == net.ErrClosed + return "received a stateless reset" } +func (e *StatelessResetError) Unwrap() error { return net.ErrClosed } func (e *StatelessResetError) Timeout() bool { return false } func (e *StatelessResetError) Temporary() bool { return true } diff --git a/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go b/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go index dcfac67d..92fec2e2 100644 --- a/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go @@ -58,7 +58,7 @@ func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration { } // UpdateRTT updates the RTT based on a new sample. -func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { +func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration) { if sendDelta <= 0 { return } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/frame.go new file mode 100644 index 00000000..10d4eebc --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/wire/frame.go @@ -0,0 +1,21 @@ +package wire + +import ( + "github.com/quic-go/quic-go/internal/protocol" +) + +// A Frame in QUIC +type Frame interface { + Append(b []byte, version protocol.Version) ([]byte, error) + Length(version protocol.Version) protocol.ByteCount +} + +// IsProbingFrame returns true if the frame is a probing frame. +// See section 9.1 of RFC 9000. +func IsProbingFrame(f Frame) bool { + switch f.(type) { + case *PathChallengeFrame, *PathResponseFrame, *NewConnectionIDFrame: + return true + } + return false +} diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/interface.go b/vendor/github.com/quic-go/quic-go/internal/wire/interface.go deleted file mode 100644 index bc17883b..00000000 --- a/vendor/github.com/quic-go/quic-go/internal/wire/interface.go +++ /dev/null @@ -1,11 +0,0 @@ -package wire - -import ( - "github.com/quic-go/quic-go/internal/protocol" -) - -// A Frame in QUIC -type Frame interface { - Append(b []byte, version protocol.Version) ([]byte, error) - Length(version protocol.Version) protocol.ByteCount -} diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/stream_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/stream_frame.go index f9470ecd..cdc32722 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/stream_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/stream_frame.go @@ -58,7 +58,10 @@ func parseStreamFrame(b []byte, typ uint64, _ protocol.Version) (*StreamFrame, i var frame *StreamFrame if dataLen < protocol.MinStreamFrameBufferSize { - frame = &StreamFrame{Data: make([]byte, dataLen)} + frame = &StreamFrame{} + if dataLen > 0 { + frame.Data = make([]byte, dataLen) + } } else { frame = GetStreamFrame() // The STREAM frame can't be larger than the StreamFrame we obtained from the buffer, @@ -74,7 +77,7 @@ func parseStreamFrame(b []byte, typ uint64, _ protocol.Version) (*StreamFrame, i frame.Fin = fin frame.DataLenPresent = hasDataLen - if dataLen != 0 { + if dataLen > 0 { copy(frame.Data, b) } if frame.Offset+frame.DataLen() > protocol.MaxByteCount { diff --git a/vendor/github.com/quic-go/quic-go/logging/connection_tracer.go b/vendor/github.com/quic-go/quic-go/logging/connection_tracer.go index 96bf4617..f218e046 100644 --- a/vendor/github.com/quic-go/quic-go/logging/connection_tracer.go +++ b/vendor/github.com/quic-go/quic-go/logging/connection_tracer.go @@ -5,34 +5,36 @@ import ( "time" ) +//go:generate go run generate_multiplexer.go ConnectionTracer connection_tracer.go multiplexer.tmpl connection_tracer_multiplexer.go + // A ConnectionTracer records events. type ConnectionTracer struct { StartedConnection func(local, remote net.Addr, srcConnID, destConnID ConnectionID) NegotiatedVersion func(chosen Version, clientVersions, serverVersions []Version) - ClosedConnection func(error) - SentTransportParameters func(*TransportParameters) - ReceivedTransportParameters func(*TransportParameters) + ClosedConnection func(err error) + SentTransportParameters func(parameters *TransportParameters) + ReceivedTransportParameters func(parameters *TransportParameters) RestoredTransportParameters func(parameters *TransportParameters) // for 0-RTT - SentLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) - SentShortHeaderPacket func(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) - ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, _ []Version) - ReceivedRetry func(*Header) - ReceivedLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, []Frame) - ReceivedShortHeaderPacket func(*ShortHeader, ByteCount, ECN, []Frame) - BufferedPacket func(PacketType, ByteCount) - DroppedPacket func(PacketType, PacketNumber, ByteCount, PacketDropReason) + SentLongHeaderPacket func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) + SentShortHeaderPacket func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) + ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, versions []Version) + ReceivedRetry func(hdr *Header) + ReceivedLongHeaderPacket func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) + ReceivedShortHeaderPacket func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) + BufferedPacket func(packetType PacketType, size ByteCount) + DroppedPacket func(packetType PacketType, pn PacketNumber, size ByteCount, reason PacketDropReason) UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) - AcknowledgedPacket func(EncryptionLevel, PacketNumber) - LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason) + AcknowledgedPacket func(encLevel EncryptionLevel, pn PacketNumber) + LostPacket func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) UpdatedMTU func(mtu ByteCount, done bool) - UpdatedCongestionState func(CongestionState) + UpdatedCongestionState func(state CongestionState) UpdatedPTOCount func(value uint32) - UpdatedKeyFromTLS func(EncryptionLevel, Perspective) + UpdatedKeyFromTLS func(encLevel EncryptionLevel, p Perspective) UpdatedKey func(keyPhase KeyPhase, remote bool) - DroppedEncryptionLevel func(EncryptionLevel) + DroppedEncryptionLevel func(encLevel EncryptionLevel) DroppedKey func(keyPhase KeyPhase) - SetLossTimer func(TimerType, EncryptionLevel, time.Time) - LossTimerExpired func(TimerType, EncryptionLevel) + SetLossTimer func(timerType TimerType, encLevel EncryptionLevel, time time.Time) + LossTimerExpired func(timerType TimerType, encLevel EncryptionLevel) LossTimerCanceled func() ECNStateUpdated func(state ECNState, trigger ECNStateTrigger) ChoseALPN func(protocol string) @@ -40,232 +42,3 @@ type ConnectionTracer struct { Close func() Debug func(name, msg string) } - -// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. -func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &ConnectionTracer{ - StartedConnection: func(local, remote net.Addr, srcConnID, destConnID ConnectionID) { - for _, t := range tracers { - if t.StartedConnection != nil { - t.StartedConnection(local, remote, srcConnID, destConnID) - } - } - }, - NegotiatedVersion: func(chosen Version, clientVersions, serverVersions []Version) { - for _, t := range tracers { - if t.NegotiatedVersion != nil { - t.NegotiatedVersion(chosen, clientVersions, serverVersions) - } - } - }, - ClosedConnection: func(e error) { - for _, t := range tracers { - if t.ClosedConnection != nil { - t.ClosedConnection(e) - } - } - }, - SentTransportParameters: func(tp *TransportParameters) { - for _, t := range tracers { - if t.SentTransportParameters != nil { - t.SentTransportParameters(tp) - } - } - }, - ReceivedTransportParameters: func(tp *TransportParameters) { - for _, t := range tracers { - if t.ReceivedTransportParameters != nil { - t.ReceivedTransportParameters(tp) - } - } - }, - RestoredTransportParameters: func(tp *TransportParameters) { - for _, t := range tracers { - if t.RestoredTransportParameters != nil { - t.RestoredTransportParameters(tp) - } - } - }, - SentLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { - for _, t := range tracers { - if t.SentLongHeaderPacket != nil { - t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) - } - } - }, - SentShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { - for _, t := range tracers { - if t.SentShortHeaderPacket != nil { - t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) - } - } - }, - ReceivedVersionNegotiationPacket: func(dest, src ArbitraryLenConnectionID, versions []Version) { - for _, t := range tracers { - if t.ReceivedVersionNegotiationPacket != nil { - t.ReceivedVersionNegotiationPacket(dest, src, versions) - } - } - }, - ReceivedRetry: func(hdr *Header) { - for _, t := range tracers { - if t.ReceivedRetry != nil { - t.ReceivedRetry(hdr) - } - } - }, - ReceivedLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) { - for _, t := range tracers { - if t.ReceivedLongHeaderPacket != nil { - t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) - } - } - }, - ReceivedShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) { - for _, t := range tracers { - if t.ReceivedShortHeaderPacket != nil { - t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) - } - } - }, - BufferedPacket: func(typ PacketType, size ByteCount) { - for _, t := range tracers { - if t.BufferedPacket != nil { - t.BufferedPacket(typ, size) - } - } - }, - DroppedPacket: func(typ PacketType, pn PacketNumber, size ByteCount, reason PacketDropReason) { - for _, t := range tracers { - if t.DroppedPacket != nil { - t.DroppedPacket(typ, pn, size, reason) - } - } - }, - UpdatedMetrics: func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { - for _, t := range tracers { - if t.UpdatedMetrics != nil { - t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) - } - } - }, - AcknowledgedPacket: func(encLevel EncryptionLevel, pn PacketNumber) { - for _, t := range tracers { - if t.AcknowledgedPacket != nil { - t.AcknowledgedPacket(encLevel, pn) - } - } - }, - LostPacket: func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { - for _, t := range tracers { - if t.LostPacket != nil { - t.LostPacket(encLevel, pn, reason) - } - } - }, - UpdatedMTU: func(mtu ByteCount, done bool) { - for _, t := range tracers { - if t.UpdatedMTU != nil { - t.UpdatedMTU(mtu, done) - } - } - }, - UpdatedCongestionState: func(state CongestionState) { - for _, t := range tracers { - if t.UpdatedCongestionState != nil { - t.UpdatedCongestionState(state) - } - } - }, - UpdatedPTOCount: func(value uint32) { - for _, t := range tracers { - if t.UpdatedPTOCount != nil { - t.UpdatedPTOCount(value) - } - } - }, - UpdatedKeyFromTLS: func(encLevel EncryptionLevel, perspective Perspective) { - for _, t := range tracers { - if t.UpdatedKeyFromTLS != nil { - t.UpdatedKeyFromTLS(encLevel, perspective) - } - } - }, - UpdatedKey: func(generation KeyPhase, remote bool) { - for _, t := range tracers { - if t.UpdatedKey != nil { - t.UpdatedKey(generation, remote) - } - } - }, - DroppedEncryptionLevel: func(encLevel EncryptionLevel) { - for _, t := range tracers { - if t.DroppedEncryptionLevel != nil { - t.DroppedEncryptionLevel(encLevel) - } - } - }, - DroppedKey: func(generation KeyPhase) { - for _, t := range tracers { - if t.DroppedKey != nil { - t.DroppedKey(generation) - } - } - }, - SetLossTimer: func(typ TimerType, encLevel EncryptionLevel, exp time.Time) { - for _, t := range tracers { - if t.SetLossTimer != nil { - t.SetLossTimer(typ, encLevel, exp) - } - } - }, - LossTimerExpired: func(typ TimerType, encLevel EncryptionLevel) { - for _, t := range tracers { - if t.LossTimerExpired != nil { - t.LossTimerExpired(typ, encLevel) - } - } - }, - LossTimerCanceled: func() { - for _, t := range tracers { - if t.LossTimerCanceled != nil { - t.LossTimerCanceled() - } - } - }, - ECNStateUpdated: func(state ECNState, trigger ECNStateTrigger) { - for _, t := range tracers { - if t.ECNStateUpdated != nil { - t.ECNStateUpdated(state, trigger) - } - } - }, - ChoseALPN: func(protocol string) { - for _, t := range tracers { - if t.ChoseALPN != nil { - t.ChoseALPN(protocol) - } - } - }, - Close: func() { - for _, t := range tracers { - if t.Close != nil { - t.Close() - } - } - }, - Debug: func(name, msg string) { - for _, t := range tracers { - if t.Debug != nil { - t.Debug(name, msg) - } - } - }, - } -} diff --git a/vendor/github.com/quic-go/quic-go/logging/connection_tracer_multiplexer.go b/vendor/github.com/quic-go/quic-go/logging/connection_tracer_multiplexer.go new file mode 100644 index 00000000..3a87058c --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/logging/connection_tracer_multiplexer.go @@ -0,0 +1,236 @@ +// Code generated by generate_multiplexer.go; DO NOT EDIT. + +package logging + +import ( + "net" + "time" +) + +func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &ConnectionTracer{ + StartedConnection: func(local net.Addr, remote net.Addr, srcConnID ConnectionID, destConnID ConnectionID) { + for _, t := range tracers { + if t.StartedConnection != nil { + t.StartedConnection(local, remote, srcConnID, destConnID) + } + } + }, + NegotiatedVersion: func(chosen Version, clientVersions []Version, serverVersions []Version) { + for _, t := range tracers { + if t.NegotiatedVersion != nil { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + } + } + }, + ClosedConnection: func(err error) { + for _, t := range tracers { + if t.ClosedConnection != nil { + t.ClosedConnection(err) + } + } + }, + SentTransportParameters: func(parameters *TransportParameters) { + for _, t := range tracers { + if t.SentTransportParameters != nil { + t.SentTransportParameters(parameters) + } + } + }, + ReceivedTransportParameters: func(parameters *TransportParameters) { + for _, t := range tracers { + if t.ReceivedTransportParameters != nil { + t.ReceivedTransportParameters(parameters) + } + } + }, + RestoredTransportParameters: func(parameters *TransportParameters) { + for _, t := range tracers { + if t.RestoredTransportParameters != nil { + t.RestoredTransportParameters(parameters) + } + } + }, + SentLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { + for _, t := range tracers { + if t.SentLongHeaderPacket != nil { + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) + } + } + }, + SentShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { + for _, t := range tracers { + if t.SentShortHeaderPacket != nil { + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) + } + } + }, + ReceivedVersionNegotiationPacket: func(dest ArbitraryLenConnectionID, src ArbitraryLenConnectionID, versions []Version) { + for _, t := range tracers { + if t.ReceivedVersionNegotiationPacket != nil { + t.ReceivedVersionNegotiationPacket(dest, src, versions) + } + } + }, + ReceivedRetry: func(hdr *Header) { + for _, t := range tracers { + if t.ReceivedRetry != nil { + t.ReceivedRetry(hdr) + } + } + }, + ReceivedLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) { + for _, t := range tracers { + if t.ReceivedLongHeaderPacket != nil { + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) + } + } + }, + ReceivedShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) { + for _, t := range tracers { + if t.ReceivedShortHeaderPacket != nil { + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) + } + } + }, + BufferedPacket: func(packetType PacketType, size ByteCount) { + for _, t := range tracers { + if t.BufferedPacket != nil { + t.BufferedPacket(packetType, size) + } + } + }, + DroppedPacket: func(packetType PacketType, pn PacketNumber, size ByteCount, reason PacketDropReason) { + for _, t := range tracers { + if t.DroppedPacket != nil { + t.DroppedPacket(packetType, pn, size, reason) + } + } + }, + UpdatedMetrics: func(rttStats *RTTStats, cwnd ByteCount, bytesInFlight ByteCount, packetsInFlight int) { + for _, t := range tracers { + if t.UpdatedMetrics != nil { + t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) + } + } + }, + AcknowledgedPacket: func(encLevel EncryptionLevel, pn PacketNumber) { + for _, t := range tracers { + if t.AcknowledgedPacket != nil { + t.AcknowledgedPacket(encLevel, pn) + } + } + }, + LostPacket: func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { + for _, t := range tracers { + if t.LostPacket != nil { + t.LostPacket(encLevel, pn, reason) + } + } + }, + UpdatedMTU: func(mtu ByteCount, done bool) { + for _, t := range tracers { + if t.UpdatedMTU != nil { + t.UpdatedMTU(mtu, done) + } + } + }, + UpdatedCongestionState: func(state CongestionState) { + for _, t := range tracers { + if t.UpdatedCongestionState != nil { + t.UpdatedCongestionState(state) + } + } + }, + UpdatedPTOCount: func(value uint32) { + for _, t := range tracers { + if t.UpdatedPTOCount != nil { + t.UpdatedPTOCount(value) + } + } + }, + UpdatedKeyFromTLS: func(encLevel EncryptionLevel, p Perspective) { + for _, t := range tracers { + if t.UpdatedKeyFromTLS != nil { + t.UpdatedKeyFromTLS(encLevel, p) + } + } + }, + UpdatedKey: func(keyPhase KeyPhase, remote bool) { + for _, t := range tracers { + if t.UpdatedKey != nil { + t.UpdatedKey(keyPhase, remote) + } + } + }, + DroppedEncryptionLevel: func(encLevel EncryptionLevel) { + for _, t := range tracers { + if t.DroppedEncryptionLevel != nil { + t.DroppedEncryptionLevel(encLevel) + } + } + }, + DroppedKey: func(keyPhase KeyPhase) { + for _, t := range tracers { + if t.DroppedKey != nil { + t.DroppedKey(keyPhase) + } + } + }, + SetLossTimer: func(timerType TimerType, encLevel EncryptionLevel, time time.Time) { + for _, t := range tracers { + if t.SetLossTimer != nil { + t.SetLossTimer(timerType, encLevel, time) + } + } + }, + LossTimerExpired: func(timerType TimerType, encLevel EncryptionLevel) { + for _, t := range tracers { + if t.LossTimerExpired != nil { + t.LossTimerExpired(timerType, encLevel) + } + } + }, + LossTimerCanceled: func() { + for _, t := range tracers { + if t.LossTimerCanceled != nil { + t.LossTimerCanceled() + } + } + }, + ECNStateUpdated: func(state ECNState, trigger ECNStateTrigger) { + for _, t := range tracers { + if t.ECNStateUpdated != nil { + t.ECNStateUpdated(state, trigger) + } + } + }, + ChoseALPN: func(protocol string) { + for _, t := range tracers { + if t.ChoseALPN != nil { + t.ChoseALPN(protocol) + } + } + }, + Close: func() { + for _, t := range tracers { + if t.Close != nil { + t.Close() + } + } + }, + Debug: func(name string, msg string) { + for _, t := range tracers { + if t.Debug != nil { + t.Debug(name, msg) + } + } + }, + } +} diff --git a/vendor/github.com/quic-go/quic-go/logging/generate_multiplexer.go b/vendor/github.com/quic-go/quic-go/logging/generate_multiplexer.go new file mode 100644 index 00000000..c152b846 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/logging/generate_multiplexer.go @@ -0,0 +1,161 @@ +//go:build generate + +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/printer" + "go/token" + "log" + "os" + "strings" + "text/template" + + "golang.org/x/tools/imports" +) + +func main() { + if len(os.Args) != 5 { + log.Fatalf("Usage: %s ", os.Args[0]) + } + + structName := os.Args[1] + inputFile := os.Args[2] + templateFile := os.Args[3] + outputFile := os.Args[4] + + fset := token.NewFileSet() + + // Parse the input file containing the struct type + file, err := parser.ParseFile(fset, inputFile, nil, parser.AllErrors) + if err != nil { + log.Fatalf("Failed to parse file: %v", err) + } + + var fields []*ast.Field + + // Find the specified struct type in the AST + for _, decl := range file.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.TYPE { + continue + } + for _, spec := range genDecl.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok || typeSpec.Name.Name != structName { + continue + } + structType, ok := typeSpec.Type.(*ast.StructType) + if !ok { + log.Fatalf("%s is not a struct", structName) + } + fields = structType.Fields.List + break + } + } + + if fields == nil { + log.Fatalf("Could not find %s type", structName) + } + + // Prepare data for the template + type FieldData struct { + Name string + Params string + Args string + HasParams bool + ReturnTypes string + HasReturn bool + } + + var fieldDataList []FieldData + + for _, field := range fields { + funcType, ok := field.Type.(*ast.FuncType) + if !ok { + continue + } + for _, name := range field.Names { + fieldData := FieldData{Name: name.Name} + + // extract parameters + var params []string + var args []string + if funcType.Params != nil { + for i, param := range funcType.Params.List { + // We intentionally reject unnamed (and, further down, "_") function parameters. + // We could auto-generate parameter names, + // but having meaningful variable names will be more helpful for the user. + if len(param.Names) == 0 { + log.Fatalf("encountered unnamed parameter at position %d in function %s", i, fieldData.Name) + } + var buf bytes.Buffer + printer.Fprint(&buf, fset, param.Type) + paramType := buf.String() + for _, paramName := range param.Names { + if paramName.Name == "_" { + log.Fatalf("encountered underscore parameter at position %d in function %s", i, fieldData.Name) + } + params = append(params, fmt.Sprintf("%s %s", paramName.Name, paramType)) + args = append(args, paramName.Name) + } + } + } + fieldData.Params = strings.Join(params, ", ") + fieldData.Args = strings.Join(args, ", ") + fieldData.HasParams = len(params) > 0 + + // extract return types + if funcType.Results != nil && len(funcType.Results.List) > 0 { + fieldData.HasReturn = true + var returns []string + for _, result := range funcType.Results.List { + var buf bytes.Buffer + printer.Fprint(&buf, fset, result.Type) + returns = append(returns, buf.String()) + } + if len(returns) == 1 { + fieldData.ReturnTypes = fmt.Sprintf(" %s", returns[0]) + } else { + fieldData.ReturnTypes = fmt.Sprintf(" (%s)", strings.Join(returns, ", ")) + } + } + + fieldDataList = append(fieldDataList, fieldData) + } + } + + // Read the template from file + templateContent, err := os.ReadFile(templateFile) + if err != nil { + log.Fatalf("Failed to read template file: %v", err) + } + + // Generate the code using the template + tmpl, err := template.New("multiplexer").Funcs(template.FuncMap{"join": strings.Join}).Parse(string(templateContent)) + if err != nil { + log.Fatalf("Failed to parse template: %v", err) + } + + var generatedCode bytes.Buffer + generatedCode.WriteString("// Code generated by generate_multiplexer.go; DO NOT EDIT.\n\n") + if err = tmpl.Execute(&generatedCode, map[string]interface{}{ + "Fields": fieldDataList, + "StructName": structName, + }); err != nil { + log.Fatalf("Failed to execute template: %v", err) + } + + // Format the generated code and add imports + formattedCode, err := imports.Process(outputFile, generatedCode.Bytes(), nil) + if err != nil { + log.Fatalf("Failed to process imports: %v", err) + } + + if err := os.WriteFile(outputFile, formattedCode, 0o644); err != nil { + log.Fatalf("Failed to write output file: %v", err) + } +} diff --git a/vendor/github.com/quic-go/quic-go/logging/multiplexer.tmpl b/vendor/github.com/quic-go/quic-go/logging/multiplexer.tmpl new file mode 100644 index 00000000..9ba52e0f --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/logging/multiplexer.tmpl @@ -0,0 +1,21 @@ +package logging + +func NewMultiplexed{{ .StructName }} (tracers ...*{{ .StructName }}) *{{ .StructName }} { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &{{ .StructName }}{ + {{- range .Fields }} + {{ .Name }}: func({{ .Params }}){{ .ReturnTypes }} { + for _, t := range tracers { + if t.{{ .Name }} != nil { + t.{{ .Name }}({{ .Args }}) + } + } + }, + {{- end }} + } +} diff --git a/vendor/github.com/quic-go/quic-go/logging/tracer.go b/vendor/github.com/quic-go/quic-go/logging/tracer.go index 625a809e..4fe01462 100644 --- a/vendor/github.com/quic-go/quic-go/logging/tracer.go +++ b/vendor/github.com/quic-go/quic-go/logging/tracer.go @@ -2,58 +2,13 @@ package logging import "net" +//go:generate go run generate_multiplexer.go Tracer tracer.go multiplexer.tmpl tracer_multiplexer.go + // A Tracer traces events. type Tracer struct { - SentPacket func(net.Addr, *Header, ByteCount, []Frame) - SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []Version) - DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason) + SentPacket func(dest net.Addr, hdr *Header, size ByteCount, frames []Frame) + SentVersionNegotiationPacket func(dest net.Addr, destConnID, srcConnID ArbitraryLenConnectionID, versions []Version) + DroppedPacket func(addr net.Addr, packetType PacketType, size ByteCount, reason PacketDropReason) Debug func(name, msg string) Close func() } - -// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. -func NewMultiplexedTracer(tracers ...*Tracer) *Tracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &Tracer{ - SentPacket: func(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { - for _, t := range tracers { - if t.SentPacket != nil { - t.SentPacket(remote, hdr, size, frames) - } - } - }, - SentVersionNegotiationPacket: func(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []Version) { - for _, t := range tracers { - if t.SentVersionNegotiationPacket != nil { - t.SentVersionNegotiationPacket(remote, dest, src, versions) - } - } - }, - DroppedPacket: func(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range tracers { - if t.DroppedPacket != nil { - t.DroppedPacket(remote, typ, size, reason) - } - } - }, - Debug: func(name, msg string) { - for _, t := range tracers { - if t.Debug != nil { - t.Debug(name, msg) - } - } - }, - Close: func() { - for _, t := range tracers { - if t.Close != nil { - t.Close() - } - } - }, - } -} diff --git a/vendor/github.com/quic-go/quic-go/logging/tracer_multiplexer.go b/vendor/github.com/quic-go/quic-go/logging/tracer_multiplexer.go new file mode 100644 index 00000000..f0878cfe --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/logging/tracer_multiplexer.go @@ -0,0 +1,51 @@ +// Code generated by generate_multiplexer.go; DO NOT EDIT. + +package logging + +import "net" + +func NewMultiplexedTracer(tracers ...*Tracer) *Tracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &Tracer{ + SentPacket: func(dest net.Addr, hdr *Header, size ByteCount, frames []Frame) { + for _, t := range tracers { + if t.SentPacket != nil { + t.SentPacket(dest, hdr, size, frames) + } + } + }, + SentVersionNegotiationPacket: func(dest net.Addr, destConnID ArbitraryLenConnectionID, srcConnID ArbitraryLenConnectionID, versions []Version) { + for _, t := range tracers { + if t.SentVersionNegotiationPacket != nil { + t.SentVersionNegotiationPacket(dest, destConnID, srcConnID, versions) + } + } + }, + DroppedPacket: func(addr net.Addr, packetType PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range tracers { + if t.DroppedPacket != nil { + t.DroppedPacket(addr, packetType, size, reason) + } + } + }, + Debug: func(name string, msg string) { + for _, t := range tracers { + if t.Debug != nil { + t.Debug(name, msg) + } + } + }, + Close: func() { + for _, t := range tracers { + if t.Close != nil { + t.Close() + } + } + }, + } +} diff --git a/vendor/github.com/quic-go/quic-go/mockgen.go b/vendor/github.com/quic-go/quic-go/mockgen.go index 65ec465a..1a8b28db 100644 --- a/vendor/github.com/quic-go/quic-go/mockgen.go +++ b/vendor/github.com/quic-go/quic-go/mockgen.go @@ -61,9 +61,4 @@ type PacketHandler = packetHandler //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" type PacketHandlerManager = packetHandlerManager -// Need to use source mode for the batchConn, since reflect mode follows type aliases. -// See https://github.com/golang/mock/issues/244 for details. -// -//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn" - //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn" diff --git a/vendor/github.com/quic-go/quic-go/mtu_discoverer.go b/vendor/github.com/quic-go/quic-go/mtu_discoverer.go index 3f3a640a..ee636a6d 100644 --- a/vendor/github.com/quic-go/quic-go/mtu_discoverer.go +++ b/vendor/github.com/quic-go/quic-go/mtu_discoverer.go @@ -13,16 +13,16 @@ import ( type mtuDiscoverer interface { // Start starts the MTU discovery process. // It's unnecessary to call ShouldSendProbe before that. - Start() + Start(now time.Time) ShouldSendProbe(now time.Time) bool CurrentSize() protocol.ByteCount - GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount) + GetPing(now time.Time) (ping ackhandler.Frame, datagramSize protocol.ByteCount) } const ( // At some point, we have to stop searching for a higher MTU. // We're happy to send a packet that's 10 bytes smaller than the actual MTU. - maxMTUDiff = 20 + maxMTUDiff protocol.ByteCount = 20 // send a probe packet every mtuProbeDelay RTTs mtuProbeDelay = 5 // Once maxLostMTUProbes MTU probe packets larger than a certain size are lost, @@ -94,7 +94,6 @@ type mtuFinder struct { inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight min protocol.ByteCount - limit protocol.ByteCount // on initialization, we treat the maximum size as the first "lost" packet lost [maxLostMTUProbes]protocol.ByteCount @@ -114,7 +113,6 @@ func newMTUDiscoverer( f := &mtuFinder{ inFlight: protocol.InvalidByteCount, min: start, - limit: max, rttStats: rttStats, mtuIncreased: mtuIncreased, tracer: tracer, @@ -142,8 +140,8 @@ func (f *mtuFinder) max() protocol.ByteCount { return f.lost[len(f.lost)-1] } -func (f *mtuFinder) Start() { - f.lastProbeTime = time.Now() // makes sure the first probe packet is not sent immediately +func (f *mtuFinder) Start(now time.Time) { + f.lastProbeTime = now // makes sure the first probe packet is not sent immediately } func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { @@ -156,14 +154,14 @@ func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT())) } -func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) { +func (f *mtuFinder) GetPing(now time.Time) (ackhandler.Frame, protocol.ByteCount) { var size protocol.ByteCount if f.lastProbeWasLost { size = (f.min + f.lost[0]) / 2 } else { size = (f.min + f.max()) / 2 } - f.lastProbeTime = time.Now() + f.lastProbeTime = now f.inFlight = size return ackhandler.Frame{ Frame: &wire.PingFrame{}, diff --git a/vendor/github.com/quic-go/quic-go/multiplexer.go b/vendor/github.com/quic-go/quic-go/multiplexer.go deleted file mode 100644 index 85f7f403..00000000 --- a/vendor/github.com/quic-go/quic-go/multiplexer.go +++ /dev/null @@ -1,75 +0,0 @@ -package quic - -import ( - "fmt" - "net" - "sync" - - "github.com/quic-go/quic-go/internal/utils" -) - -var ( - connMuxerOnce sync.Once - connMuxer multiplexer -) - -type indexableConn interface{ LocalAddr() net.Addr } - -type multiplexer interface { - AddConn(conn indexableConn) - RemoveConn(indexableConn) error -} - -// The connMultiplexer listens on multiple net.PacketConns and dispatches -// incoming packets to the connection handler. -type connMultiplexer struct { - mutex sync.Mutex - - conns map[string] /* LocalAddr().String() */ indexableConn - logger utils.Logger -} - -var _ multiplexer = &connMultiplexer{} - -func getMultiplexer() multiplexer { - connMuxerOnce.Do(func() { - connMuxer = &connMultiplexer{ - conns: make(map[string]indexableConn), - logger: utils.DefaultLogger.WithPrefix("muxer"), - } - }) - return connMuxer -} - -func (m *connMultiplexer) index(addr net.Addr) string { - return addr.Network() + " " + addr.String() -} - -func (m *connMultiplexer) AddConn(c indexableConn) { - m.mutex.Lock() - defer m.mutex.Unlock() - - connIndex := m.index(c.LocalAddr()) - p, ok := m.conns[connIndex] - if ok { - // Panics if we're already listening on this connection. - // This is a safeguard because we're introducing a breaking API change, see - // https://github.com/quic-go/quic-go/issues/3727 for details. - // We'll remove this at a later time, when most users of the library have made the switch. - panic("connection already exists") // TODO: write a nice message - } - m.conns[connIndex] = p -} - -func (m *connMultiplexer) RemoveConn(c indexableConn) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - connIndex := m.index(c.LocalAddr()) - if _, ok := m.conns[connIndex]; !ok { - return fmt.Errorf("cannote remove connection, connection is unknown") - } - - delete(m.conns, connIndex) - return nil -} diff --git a/vendor/github.com/quic-go/quic-go/packet_handler_map.go b/vendor/github.com/quic-go/quic-go/packet_handler_map.go index 7840202c..84841984 100644 --- a/vendor/github.com/quic-go/quic-go/packet_handler_map.go +++ b/vendor/github.com/quic-go/quic-go/packet_handler_map.go @@ -1,10 +1,6 @@ package quic import ( - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "hash" "io" "net" "sync" @@ -56,15 +52,12 @@ type packetHandlerMap struct { deleteRetiredConnsAfter time.Duration - statelessResetMutex sync.Mutex - statelessResetHasher hash.Hash - logger utils.Logger } var _ packetHandlerManager = &packetHandlerMap{} -func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap { +func newPacketHandlerMap(enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap { h := &packetHandlerMap{ closeChan: make(chan struct{}), handlers: make(map[protocol.ConnectionID]packetHandler), @@ -73,9 +66,6 @@ func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePa enqueueClosePacket: enqueueClosePacket, logger: logger, } - if key != nil { - h.statelessResetHasher = hmac.New(sha256.New, key[:]) - } if h.logger.Debug() { go h.logUsage() } @@ -236,20 +226,3 @@ func (h *packetHandlerMap) Close(e error) { h.mutex.Unlock() wg.Wait() } - -func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { - var token protocol.StatelessResetToken - if h.statelessResetHasher == nil { - // Return a random stateless reset token. - // This token will be sent in the server's transport parameters. - // By using a random token, an off-path attacker won't be able to disrupt the connection. - rand.Read(token[:]) - return token - } - h.statelessResetMutex.Lock() - h.statelessResetHasher.Write(connID.Bytes()) - copy(token[:], h.statelessResetHasher.Sum(nil)) - h.statelessResetHasher.Reset() - h.statelessResetMutex.Unlock() - return token -} diff --git a/vendor/github.com/quic-go/quic-go/packet_packer.go b/vendor/github.com/quic-go/quic-go/packet_packer.go index 8b8a03d4..7724b503 100644 --- a/vendor/github.com/quic-go/quic-go/packet_packer.go +++ b/vendor/github.com/quic-go/quic-go/packet_packer.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "time" "golang.org/x/exp/rand" @@ -18,10 +19,10 @@ import ( var errNothingToPack = errors.New("nothing to pack") type packer interface { - PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) - PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) - AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) - MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) + PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (*coalescedPacket, error) + PackAckOnlyPacket(maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) + AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, error) + MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error) PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) @@ -106,12 +107,11 @@ type sealingManager interface { type frameSource interface { HasData() bool - AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) - AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) + Append([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount, time.Time, protocol.Version) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount) } type ackFrameSource interface { - GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame + GetAckFrame(_ protocol.EncryptionLevel, now time.Time, onlyIfQueued bool) *wire.AckFrame } type packetPacker struct { @@ -328,7 +328,7 @@ func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize, // PackCoalescedPacket packs a new packet. // It packs an Initial / Handshake if there is data to send in these packet number spaces. // It should only be called before the handshake is confirmed. -func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { +func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxSize protocol.ByteCount, now time.Time, v protocol.Version) (*coalescedPacket, error) { var ( initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload @@ -342,7 +342,14 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. } var size protocol.ByteCount if initialSealer != nil { - initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, onlyAck, true, v) + initialHdr, initialPayload = p.maybeGetCryptoPacket( + maxSize-protocol.ByteCount(initialSealer.Overhead()), + protocol.EncryptionInitial, + now, + onlyAck, + true, + v, + ) if initialPayload.length > 0 { size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead()) } @@ -350,14 +357,21 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. // Add a Handshake packet. var handshakeSealer sealer - if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) { + if (onlyAck && size == 0) || (!onlyAck && size < maxSize-protocol.MinCoalescedPacketSize) { var err error handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer() if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { return nil, err } if handshakeSealer != nil { - handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, onlyAck, size == 0, v) + handshakeHdr, handshakePayload = p.maybeGetCryptoPacket( + maxSize-size-protocol.ByteCount(handshakeSealer.Overhead()), + protocol.EncryptionHandshake, + now, + onlyAck, + size == 0, + v, + ) if handshakePayload.length > 0 { s := p.longHeaderPacketLength(handshakeHdr, handshakePayload, v) + protocol.ByteCount(handshakeSealer.Overhead()) size += s @@ -370,7 +384,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. var oneRTTSealer handshake.ShortHeaderSealer var connID protocol.ConnectionID var kp protocol.KeyPhaseBit - if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) { + if (onlyAck && size == 0) || (!onlyAck && size < maxSize-protocol.MinCoalescedPacketSize) { var err error oneRTTSealer, err = p.cryptoSetup.Get1RTTSealer() if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { @@ -381,7 +395,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. connID = p.getDestConnID() oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen) - oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxPacketSize-size, onlyAck, size == 0, v) + oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxSize-size, onlyAck, size == 0, now, v) if oneRTTPayload.length > 0 { size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead()) } @@ -392,7 +406,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. return nil, err } if zeroRTTSealer != nil { - zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxPacketSize-size, v) + zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxSize-size, now, v) if zeroRTTPayload.length > 0 { size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload, v) + protocol.ByteCount(zeroRTTSealer.Overhead()) } @@ -410,7 +424,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. longHdrPackets: make([]*longHeaderPacket, 0, 3), } if initialPayload.length > 0 { - padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize) + padding := p.initialPaddingLen(initialPayload.frames, size, maxSize) cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) if err != nil { return nil, err @@ -431,7 +445,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. } packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) } else if oneRTTPayload.length > 0 { - shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxPacketSize, oneRTTSealer, false, v) + shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxSize, oneRTTSealer, false, v) if err != nil { return nil, err } @@ -442,19 +456,25 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. // PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space. // It should be called after the handshake is confirmed. -func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) { +func (p *packetPacker) PackAckOnlyPacket(maxSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) { buf := getPacketBuffer() - packet, err := p.appendPacket(buf, true, maxPacketSize, v) + packet, err := p.appendPacket(buf, true, maxSize, now, v) return packet, buf, err } // AppendPacket packs a packet in the application data packet number space. // It should be called after the handshake is confirmed. -func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) { - return p.appendPacket(buf, false, maxPacketSize, v) +func (p *packetPacker) AppendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, error) { + return p.appendPacket(buf, false, maxSize, now, v) } -func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) { +func (p *packetPacker) appendPacket( + buf *packetBuffer, + onlyAck bool, + maxPacketSize protocol.ByteCount, + now time.Time, + v protocol.Version, +) (shortHeaderPacket, error) { sealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { return shortHeaderPacket{}, err @@ -462,7 +482,7 @@ func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSi pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) connID := p.getDestConnID() hdrLen := wire.ShortHeaderLen(connID, pnLen) - pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, maxPacketSize, onlyAck, true, v) + pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, maxPacketSize, onlyAck, true, now, v) if pl.length == 0 { return shortHeaderPacket{}, errNothingToPack } @@ -471,9 +491,15 @@ func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSi return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v) } -func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.Version) (*wire.ExtendedHeader, payload) { +func (p *packetPacker) maybeGetCryptoPacket( + maxPacketSize protocol.ByteCount, + encLevel protocol.EncryptionLevel, + now time.Time, + onlyAck, ackAllowed bool, + v protocol.Version, +) (*wire.ExtendedHeader, payload) { if onlyAck { - if ack := p.acks.GetAckFrame(encLevel, true); ack != nil { + if ack := p.acks.GetAckFrame(encLevel, now, true); ack != nil { return p.getLongHeader(encLevel, v), payload{ ack: ack, length: ack.Length(v), @@ -500,7 +526,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en hasData := s.HasData() var ack *wire.AckFrame if ackAllowed { - ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData) + ack = p.acks.GetAckFrame(encLevel, now, !hasRetransmission && !hasData) } if !hasData && !hasRetransmission && ack == nil { // nothing to send @@ -518,7 +544,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en if hasRetransmission { for { var f ackhandler.Frame - //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s + //nolint:exhaustive // 0-RTT packets can't contain any retransmissions switch encLevel { case protocol.EncryptionInitial: f.Frame = p.retransmissionQueue.GetInitialFrame(maxPacketSize, v) @@ -543,23 +569,39 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en return hdr, pl } -func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.Version) (*wire.ExtendedHeader, payload) { +func (p *packetPacker) maybeGetAppDataPacketFor0RTT( + sealer sealer, + maxPacketSize protocol.ByteCount, + now time.Time, + v protocol.Version, +) (*wire.ExtendedHeader, payload) { if p.perspective != protocol.PerspectiveClient { return nil, payload{} } hdr := p.getLongHeader(protocol.Encryption0RTT, v) maxPayloadSize := maxPacketSize - hdr.GetLength(v) - protocol.ByteCount(sealer.Overhead()) - return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, v) + return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, now, v) } -func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload { +func (p *packetPacker) maybeGetShortHeaderPacket( + sealer handshake.ShortHeaderSealer, + hdrLen, maxPacketSize protocol.ByteCount, + onlyAck, ackAllowed bool, + now time.Time, + v protocol.Version, +) payload { maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, v) + return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, now, v) } -func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload { - pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v) +func (p *packetPacker) maybeGetAppDataPacket( + maxPayloadSize protocol.ByteCount, + onlyAck, ackAllowed bool, + now time.Time, + v protocol.Version, +) payload { + pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, now, v) // check if we have anything to send if len(pl.frames) == 0 && len(pl.streamFrames) == 0 { @@ -581,9 +623,14 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, return pl } -func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload { +func (p *packetPacker) composeNextPacket( + maxFrameSize protocol.ByteCount, + onlyAck, ackAllowed bool, + now time.Time, + v protocol.Version, +) payload { if onlyAck { - if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil { + if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, now, true); ack != nil { return payload{ack: ack, length: ack.Length(v)} } return payload{} @@ -595,7 +642,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc var hasAck bool var pl payload if ackAllowed { - if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil { + if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, now, !hasRetransmission && !hasData); ack != nil { pl.ack = ack pl.length += ack.Length(v) hasAck = true @@ -641,7 +688,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc if hasData { var lengthAdded protocol.ByteCount startLen := len(pl.frames) - pl.frames, lengthAdded = p.framer.AppendControlFrames(pl.frames, maxFrameSize-pl.length, v) + pl.frames, pl.streamFrames, lengthAdded = p.framer.Append(pl.frames, pl.streamFrames, maxFrameSize-pl.length, now, v) pl.length += lengthAdded // add handlers for the control frames that were added for i := startLen; i < len(pl.frames); i++ { @@ -656,14 +703,16 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler() } } - - pl.streamFrames, lengthAdded = p.framer.AppendStreamFrames(pl.streamFrames, maxFrameSize-pl.length, v) - pl.length += lengthAdded } return pl } -func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { +func (p *packetPacker) MaybePackProbePacket( + encLevel protocol.EncryptionLevel, + maxPacketSize protocol.ByteCount, + now time.Time, + v protocol.Version, +) (*coalescedPacket, error) { if encLevel == protocol.Encryption1RTT { s, err := p.cryptoSetup.Get1RTTSealer() if err != nil { @@ -673,7 +722,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m connID := p.getDestConnID() pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) hdrLen := wire.ShortHeaderLen(connID, pnLen) - pl := p.maybeGetAppDataPacket(maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, v) + pl := p.maybeGetAppDataPacket(maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, now, v) if pl.length == 0 { return nil, nil } @@ -687,8 +736,6 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m return packet, nil } - var hdr *wire.ExtendedHeader - var pl payload var sealer handshake.LongHeaderSealer //nolint:exhaustive // Probe packets are never sent for 0-RTT. switch encLevel { @@ -698,18 +745,16 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m if err != nil { return nil, err } - hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v) case protocol.EncryptionHandshake: var err error sealer, err = p.cryptoSetup.GetHandshakeSealer() if err != nil { return nil, err } - hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true, v) default: panic("unknown encryption level") } - + hdr, pl := p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), encLevel, now, false, true, v) if pl.length == 0 { return nil, nil } diff --git a/vendor/github.com/quic-go/quic-go/receive_stream.go b/vendor/github.com/quic-go/quic-go/receive_stream.go index b8535ef5..192b92f7 100644 --- a/vendor/github.com/quic-go/quic-go/receive_stream.go +++ b/vendor/github.com/quic-go/quic-go/receive_stream.go @@ -17,8 +17,8 @@ import ( type receiveStreamI interface { ReceiveStream - handleStreamFrame(*wire.StreamFrame) error - handleResetStreamFrame(*wire.ResetStreamFrame) error + handleStreamFrame(*wire.StreamFrame, time.Time) error + handleResetStreamFrame(*wire.ResetStreamFrame, time.Time) error closeForShutdown(error) } @@ -91,16 +91,19 @@ func (s *receiveStream) Read(p []byte) (int, error) { defer func() { <-s.readOnce }() s.mutex.Lock() - queuedNewControlFrame, n, err := s.readImpl(p) + queuedStreamWindowUpdate, queuedConnWindowUpdate, n, err := s.readImpl(p) completed := s.isNewlyCompleted() s.mutex.Unlock() if completed { s.sender.onStreamCompleted(s.streamID) } - if queuedNewControlFrame { + if queuedStreamWindowUpdate { s.sender.onHasStreamControlFrame(s.streamID, s) } + if queuedConnWindowUpdate { + s.sender.onHasConnectionData() + } return n, err } @@ -125,20 +128,19 @@ func (s *receiveStream) isNewlyCompleted() bool { return false } -func (s *receiveStream) readImpl(p []byte) (bool, int, error) { +func (s *receiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnWindowUpdate bool, _ int, _ error) { if s.currentFrameIsLast && s.currentFrame == nil { s.errorRead = true - return false, 0, io.EOF + return false, false, 0, io.EOF } if s.cancelledRemotely || s.cancelledLocally { s.errorRead = true - return false, 0, s.cancelErr + return false, false, 0, s.cancelErr } if s.closeForShutdownErr != nil { - return false, 0, s.closeForShutdownErr + return false, false, 0, s.closeForShutdownErr } - var queuedNewControlFrame bool var bytesRead int var deadlineTimer *utils.Timer for bytesRead < len(p) { @@ -146,23 +148,23 @@ func (s *receiveStream) readImpl(p []byte) (bool, int, error) { s.dequeueNextFrame() } if s.currentFrame == nil && bytesRead > 0 { - return queuedNewControlFrame, bytesRead, s.closeForShutdownErr + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr } for { // Stop waiting on errors if s.closeForShutdownErr != nil { - return queuedNewControlFrame, bytesRead, s.closeForShutdownErr + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr } if s.cancelledRemotely || s.cancelledLocally { s.errorRead = true - return queuedNewControlFrame, 0, s.cancelErr + return hasStreamWindowUpdate, hasConnWindowUpdate, 0, s.cancelErr } deadline := s.deadline if !deadline.IsZero() { if !time.Now().Before(deadline) { - return queuedNewControlFrame, bytesRead, errDeadline + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, errDeadline } if deadlineTimer == nil { deadlineTimer = utils.NewTimer() @@ -192,10 +194,10 @@ func (s *receiveStream) readImpl(p []byte) (bool, int, error) { } if bytesRead > len(p) { - return queuedNewControlFrame, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) } if s.readPosInFrame > len(s.currentFrame) { - return queuedNewControlFrame, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) } m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:]) @@ -205,9 +207,13 @@ func (s *receiveStream) readImpl(p []byte) (bool, int, error) { // when a RESET_STREAM was received, the flow controller was already // informed about the final byteOffset for this stream if !s.cancelledRemotely { - if queueMaxStreamData := s.flowController.AddBytesRead(protocol.ByteCount(m)); queueMaxStreamData { + hasStream, hasConn := s.flowController.AddBytesRead(protocol.ByteCount(m)) + if hasStream { s.queuedMaxStreamData = true - queuedNewControlFrame = true + hasStreamWindowUpdate = true + } + if hasConn { + hasConnWindowUpdate = true } } @@ -217,10 +223,10 @@ func (s *receiveStream) readImpl(p []byte) (bool, int, error) { s.currentFrameDone() } s.errorRead = true - return queuedNewControlFrame, bytesRead, io.EOF + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, io.EOF } } - return queuedNewControlFrame, bytesRead, nil + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, nil } func (s *receiveStream) dequeueNextFrame() { @@ -266,9 +272,9 @@ func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) (queuedNe return true } -func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { +func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame, now time.Time) error { s.mutex.Lock() - err := s.handleStreamFrameImpl(frame) + err := s.handleStreamFrameImpl(frame, now) completed := s.isNewlyCompleted() s.mutex.Unlock() @@ -279,9 +285,9 @@ func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { return err } -func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) error { +func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame, now time.Time) error { maxOffset := frame.Offset + frame.DataLen() - if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin); err != nil { + if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin, now); err != nil { return err } if frame.Fin { @@ -297,9 +303,9 @@ func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) error { return nil } -func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { +func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame, now time.Time) error { s.mutex.Lock() - err := s.handleResetStreamFrameImpl(frame) + err := s.handleResetStreamFrameImpl(frame, now) completed := s.isNewlyCompleted() s.mutex.Unlock() @@ -309,11 +315,11 @@ func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) err return err } -func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) error { +func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame, now time.Time) error { if s.closeForShutdownErr != nil { return nil } - if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil { + if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true, now); err != nil { return err } s.finalOffset = frame.FinalSize @@ -333,7 +339,7 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) return nil } -func (s *receiveStream) getControlFrame() (_ ackhandler.Frame, ok, hasMore bool) { +func (s *receiveStream) getControlFrame(now time.Time) (_ ackhandler.Frame, ok, hasMore bool) { s.mutex.Lock() defer s.mutex.Unlock() @@ -349,7 +355,10 @@ func (s *receiveStream) getControlFrame() (_ ackhandler.Frame, ok, hasMore bool) s.queuedMaxStreamData = false return ackhandler.Frame{ - Frame: &wire.MaxStreamDataFrame{StreamID: s.streamID, MaximumStreamData: s.flowController.GetWindowUpdate()}, + Frame: &wire.MaxStreamDataFrame{ + StreamID: s.streamID, + MaximumStreamData: s.flowController.GetWindowUpdate(now), + }, }, true, false } diff --git a/vendor/github.com/quic-go/quic-go/send_stream.go b/vendor/github.com/quic-go/quic-go/send_stream.go index 699c40ef..a588cc8a 100644 --- a/vendor/github.com/quic-go/quic-go/send_stream.go +++ b/vendor/github.com/quic-go/quic-go/send_stream.go @@ -18,7 +18,7 @@ type sendStreamI interface { SendStream handleStopSendingFrame(*wire.StopSendingFrame) hasData() bool - popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (frame ackhandler.StreamFrame, ok, hasMore bool) + popStreamFrame(protocol.ByteCount, protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool) closeForShutdown(error) updateSendWindow(protocol.ByteCount) } @@ -37,11 +37,10 @@ type sendStream struct { writeOffset protocol.ByteCount - cancelWriteErr *StreamError - closeForShutdownErr error - - queuedResetStreamFrame bool - queuedBlockedFrame bool + // finalError is the error that is returned by Write. + // It can either be a cancellation error or the shutdown error. + finalError error + queuedResetStreamFrame *wire.ResetStreamFrame finishedWriting bool // set once Close() is called finSent bool // set when a STREAM_FRAME with FIN bit has been sent @@ -49,6 +48,8 @@ type sendStream struct { // This can happen because the application called CancelWrite, // or because Write returned the error (for remote cancellations). cancellationFlagged bool + cancelled bool // both local and remote cancellations + closedForShutdown bool // set by closeForShutdown completed bool // set when this stream has been reported to the streamSender as completed dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out @@ -106,16 +107,15 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error) s.mutex.Lock() defer s.mutex.Unlock() + if s.finalError != nil { + if s.cancelled { + s.cancellationFlagged = true + } + return s.isNewlyCompleted(), 0, s.finalError + } if s.finishedWriting { return false, 0, fmt.Errorf("write on closed stream %d", s.streamID) } - if s.cancelWriteErr != nil { - s.cancellationFlagged = true - return s.isNewlyCompleted(), 0, s.cancelWriteErr - } - if s.closeForShutdownErr != nil { - return false, 0, s.closeForShutdownErr - } if !s.deadline.IsZero() && !time.Now().Before(s.deadline) { return false, 0, errDeadline } @@ -169,7 +169,7 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error) } deadlineTimer.Reset(deadline) } - if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil { + if s.dataForWriting == nil || s.finalError != nil { break } } @@ -198,11 +198,11 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error) if bytesWritten == len(p) { return false, bytesWritten, nil } - if s.closeForShutdownErr != nil { - return false, bytesWritten, s.closeForShutdownErr - } else if s.cancelWriteErr != nil { - s.cancellationFlagged = true - return s.isNewlyCompleted(), bytesWritten, s.cancelWriteErr + if s.finalError != nil { + if s.cancelled { + s.cancellationFlagged = true + } + return s.isNewlyCompleted(), bytesWritten, s.finalError } return false, bytesWritten, nil } @@ -217,40 +217,37 @@ func (s *sendStream) canBufferStreamFrame() bool { // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // maxBytes is the maximum length this frame (including frame header) will have. -func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) { +func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool) { s.mutex.Lock() - f, hasMoreData, queuedControlFrame := s.popNewOrRetransmittedStreamFrame(maxBytes, v) + f, blocked, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v) if f != nil { s.numOutstandingFrames++ } s.mutex.Unlock() - if queuedControlFrame { - s.sender.onHasStreamControlFrame(s.streamID, s) - } if f == nil { - return ackhandler.StreamFrame{}, false, hasMoreData + return ackhandler.StreamFrame{}, blocked, hasMoreData } return ackhandler.StreamFrame{ Frame: f, Handler: (*sendStreamAckHandler)(s), - }, true, hasMoreData + }, blocked, hasMoreData } -func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, hasMoreData, queuedControlFrame bool) { - if s.cancelWriteErr != nil || s.closeForShutdownErr != nil { - return nil, false, false +func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMoreData bool) { + if s.finalError != nil { + return nil, nil, false } if len(s.retransmissionQueue) > 0 { f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v) if f != nil || hasMoreRetransmissions { if f == nil { - return nil, true, false + return nil, nil, true } // We always claim that we have more data to send. // This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future. - return f, true, false + return f, nil, true } } @@ -262,38 +259,45 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun Offset: s.writeOffset, DataLenPresent: true, Fin: true, - }, false, false + }, nil, false } - return nil, false, false + return nil, nil, false } sendWindow := s.flowController.SendWindowSize() if sendWindow == 0 { - if s.flowController.IsNewlyBlocked() { - s.queuedBlockedFrame = true - return nil, false, true - } - return nil, true, false + return nil, nil, true } f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow, v) - if dataLen := f.DataLen(); dataLen > 0 { + if f == nil { + return nil, nil, hasMoreData + } + if f.DataLen() > 0 { s.writeOffset += f.DataLen() s.flowController.AddBytesSent(f.DataLen()) } + var blocked *wire.StreamDataBlockedFrame + // If the entire send window is used, the stream might have become blocked on stream-level flow control. + // This is not guaranteed though, because the stream might also have been blocked on connection-level flow control. + if f.DataLen() == sendWindow && s.flowController.IsNewlyBlocked() { + blocked = &wire.StreamDataBlockedFrame{StreamID: s.streamID, MaximumStreamData: s.writeOffset} + } f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent if f.Fin { s.finSent = true } - return f, hasMoreData, false + return f, blocked, hasMoreData } func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) { if s.nextFrame != nil { + maxDataLen := min(sendWindow, s.nextFrame.MaxDataLen(maxBytes, v)) + if maxDataLen == 0 { + return nil, true + } nextFrame := s.nextFrame s.nextFrame = nil - - maxDataLen := min(sendWindow, nextFrame.MaxDataLen(maxBytes, v)) if nextFrame.DataLen() > maxDataLen { s.nextFrame = wire.GetStreamFrame() s.nextFrame.StreamID = s.streamID @@ -371,7 +375,7 @@ func (s *sendStream) isNewlyCompleted() bool { return false } // We need to keep the stream around until all frames have been sent and acknowledged. - if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame { + if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame != nil { return false } // The stream is completed if we sent the FIN. @@ -384,7 +388,7 @@ func (s *sendStream) isNewlyCompleted() bool { // 2. we received a STOP_SENDING, and // * the application consumed the error via Write, or // * the application called Close - if s.cancelWriteErr != nil && (s.cancellationFlagged || s.finishedWriting) { + if s.cancelled && (s.cancellationFlagged || s.finishedWriting) { s.completed = true return true } @@ -393,13 +397,13 @@ func (s *sendStream) isNewlyCompleted() bool { func (s *sendStream) Close() error { s.mutex.Lock() - if s.closeForShutdownErr != nil { + if s.closedForShutdown || s.finishedWriting { s.mutex.Unlock() return nil } s.finishedWriting = true - cancelWriteErr := s.cancelWriteErr - if cancelWriteErr != nil { + cancelled := s.cancelled + if cancelled { s.cancellationFlagged = true } completed := s.isNewlyCompleted() @@ -408,7 +412,7 @@ func (s *sendStream) Close() error { if completed { s.sender.onStreamCompleted(s.streamID) } - if cancelWriteErr != nil { + if cancelled { return fmt.Errorf("close called for canceled stream %d", s.streamID) } s.sender.onHasStreamData(s.streamID, s) // need to send the FIN, must be called without holding the mutex @@ -418,18 +422,21 @@ func (s *sendStream) Close() error { } func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { - s.cancelWriteImpl(errorCode, false) + s.cancelWrite(errorCode, false) } -func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) { +// cancelWrite cancels the stream +// It is possible to cancel a stream after it has been closed, both locally and remotely. +// This is useful to prevent the retransmission of outstanding stream data. +func (s *sendStream) cancelWrite(errorCode qerr.StreamErrorCode, remote bool) { s.mutex.Lock() - if s.closeForShutdownErr != nil { + if s.closedForShutdown { s.mutex.Unlock() return } if !remote { s.cancellationFlagged = true - if s.cancelWriteErr != nil { + if s.cancelled { completed := s.isNewlyCompleted() s.mutex.Unlock() // The user has called CancelWrite. If the previous cancellation was @@ -441,15 +448,20 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool return } } - if s.cancelWriteErr != nil { + if s.cancelled { s.mutex.Unlock() return } - s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote} - s.ctxCancel(s.cancelWriteErr) + s.cancelled = true + s.finalError = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote} + s.ctxCancel(s.finalError) s.numOutstandingFrames = 0 s.retransmissionQueue = nil - s.queuedResetStreamFrame = true + s.queuedResetStreamFrame = &wire.ResetStreamFrame{ + StreamID: s.streamID, + FinalSize: s.writeOffset, + ErrorCode: errorCode, + } s.mutex.Unlock() s.signalWrite() @@ -470,33 +482,23 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) { } func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { - s.cancelWriteImpl(frame.ErrorCode, true) + s.cancelWrite(frame.ErrorCode, true) } -func (s *sendStream) getControlFrame() (_ ackhandler.Frame, ok, hasMore bool) { +func (s *sendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool) { s.mutex.Lock() defer s.mutex.Unlock() - if !s.queuedBlockedFrame && !s.queuedResetStreamFrame { + if s.queuedResetStreamFrame == nil { return ackhandler.Frame{}, false, false } - if s.queuedBlockedFrame { - s.queuedBlockedFrame = false - return ackhandler.Frame{ - Frame: &wire.StreamDataBlockedFrame{StreamID: s.streamID, MaximumStreamData: s.writeOffset}, - }, true, s.queuedResetStreamFrame - } - // RESET_STREAM frame - s.queuedResetStreamFrame = false s.numOutstandingFrames++ - return ackhandler.Frame{ - Frame: &wire.ResetStreamFrame{ - StreamID: s.streamID, - FinalSize: s.writeOffset, - ErrorCode: s.cancelWriteErr.ErrorCode, - }, + f := ackhandler.Frame{ + Frame: s.queuedResetStreamFrame, Handler: (*sendStreamResetStreamHandler)(s), - }, true, false + } + s.queuedResetStreamFrame = nil + return f, true, false } func (s *sendStream) Context() context.Context { @@ -516,7 +518,10 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error { // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. func (s *sendStream) closeForShutdown(err error) { s.mutex.Lock() - s.closeForShutdownErr = err + s.closedForShutdown = true + if s.finalError == nil && !s.finishedWriting { + s.finalError = err + } s.mutex.Unlock() s.signalWrite() } @@ -537,7 +542,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) { sf := f.(*wire.StreamFrame) sf.PutBack() s.mutex.Lock() - if s.cancelWriteErr != nil { + if s.cancelled { s.mutex.Unlock() return } @@ -556,7 +561,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) { func (s *sendStreamAckHandler) OnLost(f wire.Frame) { sf := f.(*wire.StreamFrame) s.mutex.Lock() - if s.cancelWriteErr != nil { + if s.cancelled { s.mutex.Unlock() return } @@ -589,9 +594,10 @@ func (s *sendStreamResetStreamHandler) OnAcked(wire.Frame) { } } -func (s *sendStreamResetStreamHandler) OnLost(wire.Frame) { +func (s *sendStreamResetStreamHandler) OnLost(f wire.Frame) { s.mutex.Lock() - s.queuedResetStreamFrame = true + s.queuedResetStreamFrame = f.(*wire.ResetStreamFrame) + s.numOutstandingFrames-- s.mutex.Unlock() s.sender.onHasStreamControlFrame(s.streamID, (*sendStream)(s)) } diff --git a/vendor/github.com/quic-go/quic-go/server.go b/vendor/github.com/quic-go/quic-go/server.go index 0cf45aca..2bb821ab 100644 --- a/vendor/github.com/quic-go/quic-go/server.go +++ b/vendor/github.com/quic-go/quic-go/server.go @@ -72,9 +72,10 @@ type baseServer struct { tokenGenerator *handshake.TokenGenerator maxTokenAge time.Duration - connIDGenerator ConnectionIDGenerator - connHandler packetHandlerManager - onClose func() + connIDGenerator ConnectionIDGenerator + statelessResetter *statelessResetter + connHandler packetHandlerManager + onClose func() receivedPackets chan receivedPacket @@ -95,7 +96,7 @@ type baseServer struct { protocol.ConnectionID, /* destination connection ID */ protocol.ConnectionID, /* source connection ID */ ConnectionIDGenerator, - protocol.StatelessResetToken, + *statelessResetter, *Config, *tls.Config, *handshake.TokenGenerator, @@ -105,15 +106,24 @@ type baseServer struct { protocol.Version, ) quicConn - closeMx sync.Mutex - errorChan chan struct{} // is closed when the server is closed - closeErr error - running chan struct{} // closed as soon as run() returns + closeMx sync.Mutex + // errorChan is closed when Close is called. This has two effects: + // 1. it cancels handshakes that are still in flight (using CONNECTION_REFUSED) errors + // 2. it stops handling of packets passed to this server + errorChan chan struct{} + // acceptChan is closed when Close returns. + // This only happens once all handshake in flight have either completed and canceled. + // Calls to Accept will first drain the queue of connections that have completed the handshake, + // and then return ErrServerClosed. + stopAccepting chan struct{} + closeErr error + running chan struct{} // closed as soon as run() returns versionNegotiationQueue chan receivedPacket invalidTokenQueue chan rejectedPacket connectionRefusedQueue chan rejectedPacket retryQueue chan rejectedPacket + handshakingCount sync.WaitGroup verifySourceAddress func(net.Addr) bool @@ -239,6 +249,7 @@ func newServer( conn rawConn, connHandler packetHandlerManager, connIDGenerator ConnectionIDGenerator, + statelessResetter *statelessResetter, connContext func(context.Context) context.Context, tlsConf *tls.Config, config *Config, @@ -259,9 +270,11 @@ func newServer( maxTokenAge: maxTokenAge, verifySourceAddress: verifySourceAddress, connIDGenerator: connIDGenerator, + statelessResetter: statelessResetter, connHandler: connHandler, connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize), errorChan: make(chan struct{}), + stopAccepting: make(chan struct{}), running: make(chan struct{}), receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), versionNegotiationQueue: make(chan receivedPacket, 4), @@ -332,7 +345,13 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) { return nil, ctx.Err() case conn := <-s.connQueue: return conn, nil - case <-s.errorChan: + case <-s.stopAccepting: + // first drain the queue + select { + case conn := <-s.connQueue: + return conn, nil + default: + } return nil, s.closeErr } } @@ -356,6 +375,9 @@ func (s *baseServer) close(e error, notifyOnClose bool) { if notifyOnClose { s.onClose() } + // wait until all handshakes in flight have terminated + s.handshakingCount.Wait() + close(s.stopAccepting) } // Addr returns the server's network address @@ -366,6 +388,8 @@ func (s *baseServer) Addr() net.Addr { func (s *baseServer) handlePacket(p receivedPacket) { select { case s.receivedPackets <- p: + case <-s.errorChan: + return default: s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) if s.tracer != nil && s.tracer.DroppedPacket != nil { @@ -686,7 +710,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error hdr.SrcConnectionID, connID, s.connIDGenerator, - s.connHandler.GetStatelessResetToken(connID), + s.statelessResetter, config, s.tlsConf, s.tokenGenerator, @@ -713,43 +737,42 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error delete(s.zeroRTTQueues, hdr.DestConnectionID) } - go conn.run() + s.handshakingCount.Add(1) go func() { - if completed := s.handleNewConn(conn); !completed { - return - } - - select { - case s.connQueue <- conn: - default: - conn.closeWithTransportError(ConnectionRefused) - } + defer s.handshakingCount.Done() + s.handleNewConn(conn) }() + go conn.run() return nil } -func (s *baseServer) handleNewConn(conn quicConn) bool { +func (s *baseServer) handleNewConn(conn quicConn) { if s.acceptEarlyConns { // wait until the early connection is ready, the handshake fails, or the server is closed select { case <-s.errorChan: conn.closeWithTransportError(ConnectionRefused) - return false + return case <-conn.Context().Done(): - return false + return case <-conn.earlyConnReady(): - return true + } + } else { + // wait until the handshake completes, fails, or the server is closed + select { + case <-s.errorChan: + conn.closeWithTransportError(ConnectionRefused) + return + case <-conn.Context().Done(): + return + case <-conn.HandshakeComplete(): } } - // wait until the handshake completes, fails, or the server is closed + select { - case <-s.errorChan: + case s.connQueue <- conn: + default: conn.closeWithTransportError(ConnectionRefused) - return false - case <-conn.Context().Done(): - return false - case <-conn.HandshakeComplete(): - return true } } diff --git a/vendor/github.com/quic-go/quic-go/stateless_reset.go b/vendor/github.com/quic-go/quic-go/stateless_reset.go new file mode 100644 index 00000000..cd0059a5 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/stateless_reset.go @@ -0,0 +1,42 @@ +package quic + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "hash" + "sync" + + "github.com/quic-go/quic-go/internal/protocol" +) + +type statelessResetter struct { + mx sync.Mutex + h hash.Hash +} + +// newStatelessRetter creates a new stateless reset generator. +// It is valid to use a nil key. In that case, a random key will be used. +// This makes is impossible for on-path attackers to shut down established connections. +func newStatelessResetter(key *StatelessResetKey) *statelessResetter { + var h hash.Hash + if key != nil { + h = hmac.New(sha256.New, key[:]) + } else { + b := make([]byte, 32) + _, _ = rand.Read(b) + h = hmac.New(sha256.New, b) + } + return &statelessResetter{h: h} +} + +func (r *statelessResetter) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { + r.mx.Lock() + defer r.mx.Unlock() + + var token protocol.StatelessResetToken + r.h.Write(connID.Bytes()) + copy(token[:], r.h.Sum(nil)) + r.h.Reset() + return token +} diff --git a/vendor/github.com/quic-go/quic-go/stream.go b/vendor/github.com/quic-go/quic-go/stream.go index 1ed26323..9cd2695d 100644 --- a/vendor/github.com/quic-go/quic-go/stream.go +++ b/vendor/github.com/quic-go/quic-go/stream.go @@ -24,6 +24,7 @@ var errDeadline net.Error = &deadlineError{} // The streamSender is notified by the stream about various events. type streamSender interface { + onHasConnectionData() onHasStreamData(protocol.StreamID, sendStreamI) onHasStreamControlFrame(protocol.StreamID, streamControlFrameGetter) // must be called without holding the mutex that is acquired by closeForShutdown @@ -52,12 +53,12 @@ type streamI interface { Stream closeForShutdown(error) // for receiving - handleStreamFrame(*wire.StreamFrame) error - handleResetStreamFrame(*wire.ResetStreamFrame) error + handleStreamFrame(*wire.StreamFrame, time.Time) error + handleResetStreamFrame(*wire.ResetStreamFrame, time.Time) error // for sending hasData() bool handleStopSendingFrame(*wire.StopSendingFrame) - popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool) + popStreamFrame(protocol.ByteCount, protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool) updateSendWindow(protocol.ByteCount) } @@ -131,12 +132,12 @@ func (s *stream) Close() error { return s.sendStream.Close() } -func (s *stream) getControlFrame() (_ ackhandler.Frame, ok, hasMore bool) { - f, ok, _ := s.sendStream.getControlFrame() +func (s *stream) getControlFrame(now time.Time) (_ ackhandler.Frame, ok, hasMore bool) { + f, ok, _ := s.sendStream.getControlFrame(now) if ok { return f, true, true } - return s.receiveStream.getControlFrame() + return s.receiveStream.getControlFrame(now) } func (s *stream) SetDeadline(t time.Time) error { diff --git a/vendor/github.com/quic-go/quic-go/sys_conn.go b/vendor/github.com/quic-go/quic-go/sys_conn.go index 71cc4607..811131d9 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn.go @@ -58,8 +58,8 @@ func wrapConn(pc net.PacketConn) (rawConn, error) { return nil, err } + // only set DF on UDP sockets if _, ok := pc.LocalAddr().(*net.UDPAddr); ok { - // Only set DF on sockets that we expect to be able to handle that configuration. var err error supportsDF, err = setDF(rawConn) if err != nil { diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_df_darwin.go b/vendor/github.com/quic-go/quic-go/sys_conn_df_darwin.go index b51cd8f1..8ed273ee 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_df_darwin.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_df_darwin.go @@ -4,47 +4,67 @@ package quic import ( "errors" + "fmt" "strconv" "strings" "syscall" "golang.org/x/sys/unix" +) - "github.com/quic-go/quic-go/internal/utils" +// for macOS versions, see https://en.wikipedia.org/wiki/Darwin_(operating_system)#Darwin_20_onwards +const ( + macOSVersion11 = 20 + macOSVersion15 = 24 ) func setDF(rawConn syscall.RawConn) (bool, error) { - // Setting DF bit is only supported from macOS11 + // Setting DF bit is only supported from macOS 11. // https://github.com/chromium/chromium/blob/117.0.5881.2/net/socket/udp_socket_posix.cc#L555 - if supportsDF, err := isAtLeastMacOS11(); !supportsDF || err != nil { + version, err := getMacOSVersion() + if err != nil || version < macOSVersion11 { return false, err } - // Enabling IP_DONTFRAG will force the kernel to return "sendto: message too long" - // and the datagram will not be fragmented - var errDFIPv4, errDFIPv6 error + var controlErr error + var disableDF bool if err := rawConn.Control(func(fd uintptr) { - errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1) - errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1) + addr, err := unix.Getsockname(int(fd)) + if err != nil { + controlErr = fmt.Errorf("getsockname: %w", err) + return + } + + // Dual-stack sockets are effectively IPv6 sockets (with IPV6_ONLY set to 0). + // On macOS, the DF bit on dual-stack sockets is controlled by the IPV6_DONTFRAG option. + // See https://datatracker.ietf.org/doc/draft-seemann-tsvwg-udp-fragmentation/ for details. + switch addr.(type) { + case *unix.SockaddrInet4: + controlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1) + case *unix.SockaddrInet6: + controlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1) + + // Setting the DF bit on dual-stack sockets works since macOS Sequoia. + // Disable DF on dual-stack sockets before Sequoia. + if version < macOSVersion15 { + // check if this is a dual-stack socket by reading the IPV6_V6ONLY flag + v6only, err := unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY) + if err != nil { + controlErr = fmt.Errorf("getting IPV6_V6ONLY: %w", err) + return + } + disableDF = v6only == 0 + } + default: + controlErr = fmt.Errorf("unknown address type: %T", addr) + } }); err != nil { return false, err } - switch { - case errDFIPv4 == nil && errDFIPv6 == nil: - utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.") - case errDFIPv4 == nil && errDFIPv6 != nil: - utils.DefaultLogger.Debugf("Setting DF for IPv4.") - case errDFIPv4 != nil && errDFIPv6 == nil: - utils.DefaultLogger.Debugf("Setting DF for IPv6.") - // On macOS, the syscall for setting DF bit for IPv4 fails on dual-stack listeners. - // Treat the connection as not having DF enabled, even though the DF bit will be set - // when used for IPv6. - // See https://github.com/quic-go/quic-go/issues/3793 for details. - return false, nil - case errDFIPv4 != nil && errDFIPv6 != nil: - return false, errors.New("setting DF failed for both IPv4 and IPv6") + if controlErr != nil { + return false, controlErr } - return true, nil + return !disableDF, nil } func isSendMsgSizeErr(err error) bool { @@ -53,22 +73,20 @@ func isSendMsgSizeErr(err error) bool { func isRecvMsgSizeErr(error) bool { return false } -func isAtLeastMacOS11() (bool, error) { +func getMacOSVersion() (int, error) { uname := &unix.Utsname{} - err := unix.Uname(uname) - if err != nil { - return false, err + if err := unix.Uname(uname); err != nil { + return 0, err } release := string(uname.Release[:]) - if idx := strings.Index(release, "."); idx != -1 { - version, err := strconv.Atoi(release[:idx]) - if err != nil { - return false, err - } - // Darwin version 20 is macOS version 11 - // https://en.wikipedia.org/wiki/Darwin_(operating_system)#Darwin_20_onwards - return version >= 20, nil + idx := strings.Index(release, ".") + if idx == -1 { + return 0, nil + } + version, err := strconv.Atoi(release[:idx]) + if err != nil { + return 0, err } - return false, nil + return version, nil } diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_df_windows.go b/vendor/github.com/quic-go/quic-go/sys_conn_df_windows.go index e27635ec..4c140f00 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_df_windows.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_df_windows.go @@ -12,21 +12,19 @@ import ( ) const ( - // IP_DONTFRAGMENT controls the Don't Fragment (DF) bit. - // - // It's the same code point for both IPv4 and IPv6 on Windows. - // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAG.html - // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html - // + // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAGMENT.html //nolint:stylecheck IP_DONTFRAGMENT = 14 + // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html + //nolint:stylecheck + IPV6_DONTFRAG = 14 ) func setDF(rawConn syscall.RawConn) (bool, error) { var errDFIPv4, errDFIPv6 error if err := rawConn.Control(func(fd uintptr) { errDFIPv4 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1) - errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IP_DONTFRAGMENT, 1) + errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_DONTFRAG, 1) }); err != nil { return false, err } diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_oob.go b/vendor/github.com/quic-go/quic-go/sys_conn_oob.go index a6795ca2..75979682 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_oob.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_oob.go @@ -83,7 +83,7 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { if err != nil { return nil, err } - needsPacketInfo := false + var needsPacketInfo bool if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() { needsPacketInfo = true } diff --git a/vendor/github.com/quic-go/quic-go/transport.go b/vendor/github.com/quic-go/quic-go/transport.go index 059f30f5..41dbc7ab 100644 --- a/vendor/github.com/quic-go/quic-go/transport.go +++ b/vendor/github.com/quic-go/quic-go/transport.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/tls" "errors" + "fmt" "net" "sync" "sync/atomic" @@ -16,6 +17,27 @@ import ( "github.com/quic-go/quic-go/logging" ) +// ErrTransportClosed is returned by the Transport's Listen or Dial method after it was closed. +var ErrTransportClosed = &errTransportClosed{} + +type errTransportClosed struct { + err error +} + +func (e *errTransportClosed) Unwrap() []error { return []error{net.ErrClosed, e.err} } + +func (e *errTransportClosed) Error() string { + if e.err == nil { + return "quic: transport closed" + } + return fmt.Sprintf("quic: transport closed: %s", e.err) +} + +func (e *errTransportClosed) Is(target error) bool { + _, ok := target.(*errTransportClosed) + return ok +} + var errListenerAlreadySet = errors.New("listener already set") // The Transport is the central point to manage incoming and outgoing QUIC connections. @@ -115,7 +137,8 @@ type Transport struct { connIDLen int // Set in init. // If no ConnectionIDGenerator is set, this is set to a default. - connIDGenerator ConnectionIDGenerator + connIDGenerator ConnectionIDGenerator + statelessResetter *statelessResetter server *baseServer @@ -125,7 +148,7 @@ type Transport struct { statelessResetQueue chan receivedPacket listening chan struct{} // is closed when listen returns - closed bool + closeErr error createdConn bool isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial @@ -168,6 +191,9 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo t.mutex.Lock() defer t.mutex.Unlock() + if t.closeErr != nil { + return nil, t.closeErr + } if t.server != nil { return nil, errListenerAlreadySet } @@ -175,17 +201,22 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo if err := t.init(false); err != nil { return nil, err } + maxTokenAge := t.MaxTokenAge + if maxTokenAge == 0 { + maxTokenAge = 24 * time.Hour + } s := newServer( t.conn, t.handlerMap, t.connIDGenerator, + t.statelessResetter, t.ConnContext, tlsConf, conf, t.Tracer, t.closeServer, *t.TokenGeneratorKey, - t.MaxTokenAge, + maxTokenAge, t.VerifySourceAddress, t.DisableVersionNegotiationPackets, allow0RTT, @@ -205,20 +236,137 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C } func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) { + if err := t.init(t.isSingleUse); err != nil { + return nil, err + } if err := validateConfig(conf); err != nil { return nil, err } conf = populateConfig(conf) - if err := t.init(t.isSingleUse); err != nil { + tlsConf = tlsConf.Clone() + setTLSConfigServerName(tlsConf, addr, host) + return t.doDial(ctx, + newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), + tlsConf, + conf, + 0, + false, + use0RTT, + conf.Versions[0], + ) +} + +func (t *Transport) doDial( + ctx context.Context, + sendConn sendConn, + tlsConf *tls.Config, + config *Config, + initialPacketNumber protocol.PacketNumber, + hasNegotiatedVersion bool, + use0RTT bool, + version protocol.Version, +) (quicConn, error) { + srcConnID, err := t.connIDGenerator.GenerateConnectionID() + if err != nil { return nil, err } - var onClose func() - if t.isSingleUse { - onClose = func() { t.Close() } + destConnID, err := generateConnectionIDForInitial() + if err != nil { + return nil, err + } + + tracingID := nextConnTracingID() + ctx = context.WithValue(ctx, ConnectionTracingKey, tracingID) + + t.mutex.Lock() + if t.closeErr != nil { + t.mutex.Unlock() + return nil, t.closeErr + } + + var tracer *logging.ConnectionTracer + if config.Tracer != nil { + tracer = config.Tracer(ctx, protocol.PerspectiveClient, destConnID) + } + if tracer != nil && tracer.StartedConnection != nil { + tracer.StartedConnection(sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID) + } + + logger := utils.DefaultLogger.WithPrefix("client") + logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", tlsConf.ServerName, sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID, version) + + conn := newClientConnection( + context.WithoutCancel(ctx), + sendConn, + t.handlerMap, + destConnID, + srcConnID, + t.connIDGenerator, + t.statelessResetter, + config, + tlsConf, + initialPacketNumber, + use0RTT, + hasNegotiatedVersion, + tracer, + logger, + version, + ) + t.handlerMap.Add(srcConnID, conn) + t.mutex.Unlock() + + // The error channel needs to be buffered, as the run loop will continue running + // after doDial returns (if the handshake is successful). + errChan := make(chan error, 1) + recreateChan := make(chan errCloseForRecreating) + go func() { + err := conn.run() + var recreateErr *errCloseForRecreating + if errors.As(err, &recreateErr) { + recreateChan <- *recreateErr + return + } + if t.isSingleUse { + t.Close() + } + errChan <- err + }() + + // Only set when we're using 0-RTT. + // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever. + var earlyConnChan <-chan struct{} + if use0RTT { + earlyConnChan = conn.earlyConnReady() + } + + select { + case <-ctx.Done(): + conn.destroy(nil) + // wait until the Go routine that called Connection.run() returns + select { + case <-errChan: + case <-recreateChan: + } + return nil, context.Cause(ctx) + case params := <-recreateChan: + return t.doDial(ctx, + sendConn, + tlsConf, + config, + params.nextPacketNumber, + true, + use0RTT, + params.nextVersion, + ) + case err := <-errChan: + return nil, err + case <-earlyConnChan: + // ready to send 0-RTT data + return conn, nil + case <-conn.HandshakeComplete(): + // handshake successfully completed + return conn, nil } - tlsConf = tlsConf.Clone() - setTLSConfigServerName(tlsConf, addr, host) - return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT) } func (t *Transport) init(allowZeroLengthConnIDs bool) error { @@ -237,7 +385,9 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.logger = utils.DefaultLogger // TODO: make this configurable t.conn = conn - t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) + if t.handlerMap == nil { // allows mocking the handlerMap in tests + t.handlerMap = newPacketHandlerMap(t.enqueueClosePacket, t.logger) + } t.listening = make(chan struct{}) t.closeQueue = make(chan closePacket, 4) @@ -262,8 +412,8 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.connIDLen = connIDLen t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen} } + t.statelessResetter = newStatelessResetter(t.StatelessResetKey) - getMultiplexer().AddConn(t.Conn) go t.listen(conn) go t.runSendQueue() }) @@ -300,11 +450,14 @@ func (t *Transport) runSendQueue() { } } -// Close closes the underlying connection. +// Close stops listening for UDP datagrams on the Transport.Conn. // If any listener was started, it will be closed as well. // It is invalid to start new listeners or connections after that. func (t *Transport) Close() error { - t.close(errors.New("closing")) + // avoid race condition if the transport is currently being initialized + t.init(false) + + t.close(nil) if t.createdConn { if err := t.Conn.Close(); err != nil { return err @@ -323,7 +476,7 @@ func (t *Transport) closeServer() { t.mutex.Lock() t.server = nil if t.isSingleUse { - t.closed = true + t.closeErr = ErrServerClosed } t.mutex.Unlock() if t.createdConn { @@ -339,10 +492,12 @@ func (t *Transport) closeServer() { func (t *Transport) close(e error) { t.mutex.Lock() defer t.mutex.Unlock() - if t.closed { + + if t.closeErr != nil { return } + e = &errTransportClosed{err: e} if t.handlerMap != nil { t.handlerMap.Close(e) } @@ -352,7 +507,7 @@ func (t *Transport) close(e error) { if t.Tracer != nil && t.Tracer.Close != nil { t.Tracer.Close() } - t.closed = true + t.closeErr = e } // only print warnings about the UDP receive buffer size once @@ -360,7 +515,6 @@ var setBufferWarningOnce sync.Once func (t *Transport) listen(conn rawConn) { defer close(t.listening) - defer getMultiplexer().RemoveConn(t.Conn) for { p, err := conn.ReadPacket() @@ -370,7 +524,7 @@ func (t *Transport) listen(conn rawConn) { // See https://github.com/quic-go/quic-go/issues/1737 for details. if nerr, ok := err.(net.Error); ok && nerr.Temporary() { t.mutex.Lock() - closed := t.closed + closed := t.closeErr != nil t.mutex.Unlock() if closed { return @@ -424,7 +578,12 @@ func (t *Transport) handlePacket(p receivedPacket) { return } if !wire.IsLongHeaderPacket(p.data[0]) { - t.maybeSendStatelessReset(p) + if statelessResetQueued := t.maybeSendStatelessReset(p); !statelessResetQueued { + if t.Tracer != nil && t.Tracer.DroppedPacket != nil { + t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnknownConnectionID) + } + p.buffer.Release() + } return } @@ -432,29 +591,32 @@ func (t *Transport) handlePacket(p receivedPacket) { defer t.mutex.Unlock() if t.server == nil { // no server set t.logger.Debugf("received a packet with an unexpected connection ID %s", connID) + if t.Tracer != nil && t.Tracer.DroppedPacket != nil { + t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnknownConnectionID) + } + p.buffer.MaybeRelease() return } t.server.handlePacket(p) } -func (t *Transport) maybeSendStatelessReset(p receivedPacket) { +func (t *Transport) maybeSendStatelessReset(p receivedPacket) (statelessResetQueued bool) { if t.StatelessResetKey == nil { - p.buffer.Release() - return + return false } // Don't send a stateless reset in response to very small packets. // This includes packets that could be stateless resets. if len(p.data) <= protocol.MinStatelessResetSize { - p.buffer.Release() - return + return false } select { case t.statelessResetQueue <- p: + return true default: // it's fine to not send a stateless reset when we're busy - p.buffer.Release() + return false } } @@ -466,7 +628,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) return } - token := t.handlerMap.GetStatelessResetToken(connID) + token := t.statelessResetter.GetStatelessResetToken(connID) t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) rand.Read(data) @@ -489,7 +651,7 @@ func (t *Transport) maybeHandleStatelessReset(data []byte) bool { token := *(*protocol.StatelessResetToken)(data[len(data)-16:]) if conn, ok := t.handlerMap.GetByResetToken(token); ok { t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) - go conn.destroy(&StatelessResetError{Token: token}) + go conn.destroy(&StatelessResetError{}) return true } return false diff --git a/vendor/modules.txt b/vendor/modules.txt index d5a83f0f..e185c100 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -91,7 +91,7 @@ github.com/prometheus/common/model github.com/prometheus/procfs github.com/prometheus/procfs/internal/fs github.com/prometheus/procfs/internal/util -# github.com/quic-go/quic-go v0.48.2 +# github.com/quic-go/quic-go v0.49.1 ## explicit; go 1.22 github.com/quic-go/quic-go github.com/quic-go/quic-go/internal/ackhandler