TUN-8371: Bump quic-go to v0.42.0

## Summary
We discovered that we were being impacted by a bug in quic-go,
that could create deadlocks and not close connections.

This commit bumps quic-go to the version that contains the fix
to prevent that from happening.
This commit is contained in:
João "Pisco" Fernandes 2024-04-18 18:26:01 +01:00 committed by chungthuang
parent 5e5f2f4d8c
commit 84833011ec
79 changed files with 730 additions and 763 deletions

4
go.mod
View File

@ -25,7 +25,7 @@ require (
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.13.0
github.com/prometheus/client_model v0.2.0
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6
github.com/quic-go/quic-go v0.42.0
github.com/rs/zerolog v1.20.0
github.com/stretchr/testify v1.8.4
github.com/urfave/cli/v2 v2.3.0
@ -84,7 +84,7 @@ require (
github.com/prometheus/procfs v0.8.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
go.opentelemetry.io/otel/metric v1.21.0 // indirect
go.uber.org/mock v0.3.0 // indirect
go.uber.org/mock v0.4.0 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.11.0 // indirect
golang.org/x/oauth2 v0.13.0 // indirect

10
go.sum
View File

@ -322,8 +322,8 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6 h1:OI4WiysowCcxLtcZMGBZildo12di3ljcMN4vWdUQpoU=
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA=
github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
@ -381,8 +381,8 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo=
go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -552,6 +552,8 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=

View File

@ -2,16 +2,6 @@ run:
skip-files:
- internal/handshake/cipher_suite.go
linters-settings:
depguard:
rules:
qtls:
list-mode: lax
files:
- "!internal/qtls/**"
- "$all"
deny:
- pkg: github.com/quic-go/qtls-go1-20
desc: "importing qtls only allowed in internal/qtls"
misspell:
ignore-words:
- ect
@ -20,7 +10,6 @@ linters:
disable-all: true
enable:
- asciicheck
- depguard
- exhaustive
- exportloopref
- goimports

View File

@ -183,26 +183,20 @@ quic-go logs a wide range of events defined in [draft-ietf-quic-qlog-quic-events
qlog files can be processed by a number of 3rd-party tools. [qviz](https://qvis.quictools.info/) has proven very useful for debugging all kinds of QUIC connection failures.
qlog is activated by setting a `Tracer` callback on the `Config`. It is called as soon as quic-go decides to starts the QUIC handshake on a new connection.
A useful implementation of this callback could look like this:
qlog can be activated by setting the `Tracer` callback on the `Config`. It is called as soon as quic-go decides to start the QUIC handshake on a new connection.
`qlog.DefaultTracer` provides a tracer implementation which writes qlog files to a directory specified by the `QLOGDIR` environment variable, if set.
The default qlog tracer can be used like this:
```go
quic.Config{
Tracer: func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
role := "server"
if p == logging.PerspectiveClient {
role = "client"
}
filename := fmt.Sprintf("./log_%s_%s.qlog", connID, role)
f, err := os.Create(filename)
// handle the error
return qlog.NewConnectionTracer(f, p, connID)
}
Tracer: qlog.DefaultTracer,
}
```
This implementation of the callback creates a new qlog file in the current directory named `log_<client / server>_<QUIC connection ID>.qlog`.
This example creates a new qlog file under `<QLOGDIR>/<Original Destination Connection ID>_<Vantage Point>.qlog`, e.g. `qlogs/2e0407da_client.qlog`.
For custom qlog behavior, `qlog.NewConnectionTracer` can be used.
## Using HTTP/3
### As a server
@ -232,11 +226,13 @@ http.Client{
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) |
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) |
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) |
| [frp](https://github.com/fatedier/frp) | A fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet | ![GitHub Repo stars](https://img.shields.io/github/stars/fatedier/frp?style=flat-square) |
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) |
| [gost](https://github.com/go-gost/gost) | A simple security tunnel written in Go | ![GitHub Repo stars](https://img.shields.io/github/stars/go-gost/gost?style=flat-square) |
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) |
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
| [RoadRunner](https://github.com/roadrunner-server/roadrunner) | High-performance PHP application server, process manager written in Go and powered with plugins | ![GitHub Repo stars](https://img.shields.io/github/stars/roadrunner-server/roadrunner?style=flat-square) |
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) |
| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) |
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) |

View File

@ -28,7 +28,7 @@ type client struct {
initialPacketNumber protocol.PacketNumber
hasNegotiatedVersion bool
version protocol.VersionNumber
version protocol.Version
handshakeChan chan struct{}
@ -232,7 +232,7 @@ func (c *client) dial(ctx context.Context) error {
select {
case <-ctx.Done():
c.conn.shutdown()
c.conn.destroy(nil)
return context.Cause(ctx)
case err := <-errorChan:
return err

View File

@ -4,7 +4,6 @@ import (
"math/bits"
"net"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
@ -13,7 +12,6 @@ import (
// with an exponential backoff.
type closedLocalConn struct {
counter uint32
perspective protocol.Perspective
logger utils.Logger
sendPacket func(net.Addr, packetInfo)
@ -22,10 +20,9 @@ type closedLocalConn struct {
var _ packetHandler = &closedLocalConn{}
// newClosedLocalConn creates a new closedLocalConn and runs it.
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logger) packetHandler {
return &closedLocalConn{
sendPacket: sendPacket,
perspective: pers,
logger: logger,
}
}
@ -41,24 +38,20 @@ func (c *closedLocalConn) handlePacket(p receivedPacket) {
c.sendPacket(p.remoteAddr, p.info)
}
func (c *closedLocalConn) shutdown() {}
func (c *closedLocalConn) destroy(error) {}
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {}
// A closedRemoteConn is a connection that was closed remotely.
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
// We can just ignore those packets.
type closedRemoteConn struct {
perspective protocol.Perspective
}
type closedRemoteConn struct{}
var _ packetHandler = &closedRemoteConn{}
func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
return &closedRemoteConn{perspective: pers}
func newClosedRemoteConn() packetHandler {
return &closedRemoteConn{}
}
func (s *closedRemoteConn) handlePacket(receivedPacket) {}
func (s *closedRemoteConn) shutdown() {}
func (s *closedRemoteConn) destroy(error) {}
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }
func (c *closedRemoteConn) handlePacket(receivedPacket) {}
func (c *closedRemoteConn) destroy(error) {}
func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {}

View File

@ -5,6 +5,8 @@ coverage:
- interop/
- internal/handshake/cipher_suite.go
- internal/utils/linkedlist/linkedlist.go
- internal/testdata
- testutils/
- fuzzing/
- metrics/
status:

View File

@ -2,7 +2,6 @@ package quic
import (
"fmt"
"net"
"time"
"github.com/quic-go/quic-go/internal/protocol"
@ -49,16 +48,6 @@ func validateConfig(config *Config) error {
return nil
}
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config {
config = populateConfig(config)
if config.RequireAddressValidation == nil {
config.RequireAddressValidation = func(net.Addr) bool { return false }
}
return config
}
// populateConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateConfig(config *Config) *Config {
@ -111,7 +100,6 @@ func populateConfig(config *Config) *Config {
Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout,
RequireAddressValidation: config.RequireAddressValidation,
KeepAlivePeriod: config.KeepAlivePeriod,
InitialStreamReceiveWindow: initialStreamReceiveWindow,
MaxStreamReceiveWindow: maxStreamReceiveWindow,

View File

@ -19,7 +19,7 @@ type connIDGenerator struct {
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte)
replaceWithClosed func([]protocol.ConnectionID, []byte)
queueControlFrame func(wire.Frame)
}
@ -30,7 +30,7 @@ func newConnIDGenerator(
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
replaceWithClosed func([]protocol.ConnectionID, []byte),
queueControlFrame func(wire.Frame),
generator ConnectionIDGenerator,
) *connIDGenerator {
@ -126,7 +126,7 @@ func (m *connIDGenerator) RemoveAll() {
}
}
func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) {
func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
if m.initialClientDestConnID != nil {
connIDs = append(connIDs, *m.initialClientDestConnID)
@ -134,5 +134,5 @@ func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose
for _, connID := range m.activeSrcConnIDs {
connIDs = append(connIDs, connID)
}
m.replaceWithClosed(connIDs, pers, connClose)
m.replaceWithClosed(connIDs, connClose)
}

View File

@ -25,7 +25,7 @@ import (
)
type unpacker interface {
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error)
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error)
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
}
@ -93,7 +93,7 @@ type connRunner interface {
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
Retire(protocol.ConnectionID)
Remove(protocol.ConnectionID)
ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte)
ReplaceWithClosed([]protocol.ConnectionID, []byte)
AddResetToken(protocol.StatelessResetToken, packetHandler)
RemoveResetToken(protocol.StatelessResetToken)
}
@ -106,7 +106,7 @@ type closeError struct {
type errCloseForRecreating struct {
nextPacketNumber protocol.PacketNumber
nextVersion protocol.VersionNumber
nextVersion protocol.Version
}
func (e *errCloseForRecreating) Error() string {
@ -128,7 +128,7 @@ type connection struct {
srcConnIDLen int
perspective protocol.Perspective
version protocol.VersionNumber
version protocol.Version
config *Config
conn sendConn
@ -177,6 +177,7 @@ type connection struct {
earlyConnReadyChan chan struct{}
sentFirstPacket bool
droppedInitialKeys bool
handshakeComplete bool
handshakeConfirmed bool
@ -235,7 +236,7 @@ var newConnection = func(
tracer *logging.ConnectionTracer,
tracingID uint64,
logger utils.Logger,
v protocol.VersionNumber,
v protocol.Version,
) quicConn {
s := &connection{
conn: conn,
@ -348,7 +349,7 @@ var newClientConnection = func(
tracer *logging.ConnectionTracer,
tracingID uint64,
logger utils.Logger,
v protocol.VersionNumber,
v protocol.Version,
) quicConn {
s := &connection{
conn: conn,
@ -453,7 +454,7 @@ func (s *connection) preSetup() {
s.handshakeStream = newCryptoStream()
s.sendQueue = newSendQueue(s.conn)
s.retransmissionQueue = newRetransmissionQueue()
s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams)
s.frameParser = *wire.NewFrameParser(s.config.EnableDatagrams)
s.rttStats = &utils.RTTStats{}
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
@ -520,6 +521,9 @@ func (s *connection) run() error {
runLoop:
for {
if s.framer.QueuedTooManyControlFrames() {
s.closeLocal(&qerr.TransportError{ErrorCode: InternalError})
}
// Close immediately if requested
select {
case closeErr = <-s.closeChan:
@ -1148,7 +1152,7 @@ func (s *connection) handleUnpackedLongHeaderPacket(
if !s.receivedFirstPacket {
s.receivedFirstPacket = true
if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil {
var clientVersions, serverVersions []protocol.VersionNumber
var clientVersions, serverVersions []protocol.Version
switch s.perspective {
case protocol.PerspectiveClient:
clientVersions = s.config.Versions
@ -1185,7 +1189,8 @@ func (s *connection) handleUnpackedLongHeaderPacket(
}
}
if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake {
if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake &&
!s.droppedInitialKeys {
// On the server side, Initial keys are dropped as soon as the first Handshake packet is received.
// See Section 4.9.1 of RFC 9001.
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
@ -1572,13 +1577,6 @@ func (s *connection) closeRemote(e error) {
})
}
// Close the connection. It sends a NO_ERROR application error.
// It waits until the run loop has stopped before returning
func (s *connection) shutdown() {
s.closeLocal(nil)
<-s.ctx.Done()
}
func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error {
s.closeLocal(&qerr.ApplicationError{
ErrorCode: code,
@ -1588,6 +1586,11 @@ func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) erro
return nil
}
func (s *connection) closeWithTransportError(code TransportErrorCode) {
s.closeLocal(&qerr.TransportError{ErrorCode: code})
<-s.ctx.Done()
}
func (s *connection) handleCloseError(closeErr *closeError) {
e := closeErr.err
if e == nil {
@ -1632,7 +1635,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
// If this is a remote close we're done here
if closeErr.remote {
s.connIDGenerator.ReplaceWithClosed(s.perspective, nil)
s.connIDGenerator.ReplaceWithClosed(nil)
return
}
if closeErr.immediate {
@ -1649,7 +1652,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
if err != nil {
s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err)
}
s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket)
s.connIDGenerator.ReplaceWithClosed(connClosePacket)
}
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error {
@ -1661,6 +1664,7 @@ func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) erro
//nolint:exhaustive // only Initial and 0-RTT need special treatment
switch encLevel {
case protocol.EncryptionInitial:
s.droppedInitialKeys = true
s.cryptoStreamHandler.DiscardInitialKeys()
case protocol.Encryption0RTT:
s.streamsMap.ResetFor0RTT()
@ -2077,7 +2081,8 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn prot
largestAcked = p.ack.LargestAcked()
}
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false)
if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake {
if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake &&
!s.droppedInitialKeys {
// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
// See Section 4.9.1 of RFC 9001.
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
@ -2377,11 +2382,7 @@ func (s *connection) RemoteAddr() net.Addr {
return s.conn.RemoteAddr()
}
func (s *connection) getPerspective() protocol.Perspective {
return s.perspective
}
func (s *connection) GetVersion() protocol.VersionNumber {
func (s *connection) GetVersion() protocol.Version {
return s.version
}

View File

@ -15,15 +15,25 @@ type framer interface {
HasData() bool
QueueControlFrame(wire.Frame)
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)
AddActiveStream(protocol.StreamID)
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)
Handle0RTTRejection() error
// QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length.
// This is a hack.
// It is easier to implement than propagating an error return value in QueueControlFrame.
// The correct solution would be to queue frames with their respective structs.
// See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames.
QueuedTooManyControlFrames() bool
}
const maxPathResponses = 256
const (
maxPathResponses = 256
maxControlFrames = 16 << 10
)
type framerI struct {
mutex sync.Mutex
@ -36,6 +46,7 @@ type framerI struct {
controlFrameMutex sync.Mutex
controlFrames []wire.Frame
pathResponses []*wire.PathResponseFrame
queuedTooManyControlFrames bool
}
var _ framer = &framerI{}
@ -73,10 +84,15 @@ func (f *framerI) QueueControlFrame(frame wire.Frame) {
f.pathResponses = append(f.pathResponses, pr)
return
}
// This is a hack.
if len(f.controlFrames) >= maxControlFrames {
f.queuedTooManyControlFrames = true
return
}
f.controlFrames = append(f.controlFrames, frame)
}
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) {
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) {
f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock()
@ -105,6 +121,10 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol
return frames, length
}
func (f *framerI) QueuedTooManyControlFrames() bool {
return f.queuedTooManyControlFrames
}
func (f *framerI) AddActiveStream(id protocol.StreamID) {
f.mutex.Lock()
if _, ok := f.activeStreams[id]; !ok {
@ -114,7 +134,7 @@ func (f *framerI) AddActiveStream(id protocol.StreamID) {
f.mutex.Unlock()
}
func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) {
func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) {
startLen := len(frames)
var length protocol.ByteCount
f.mutex.Lock()

View File

@ -16,8 +16,12 @@ import (
// The StreamID is the ID of a QUIC stream.
type StreamID = protocol.StreamID
// A Version is a QUIC version number.
type Version = protocol.Version
// A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber
// Deprecated: VersionNumber was renamed to Version.
type VersionNumber = Version
const (
// Version1 is RFC 9000
@ -159,6 +163,9 @@ type Connection interface {
OpenStream() (Stream, error)
// OpenStreamSync opens a new bidirectional QUIC stream.
// It blocks until a new stream can be opened.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
// If the error is non-nil, it satisfies the net.Error interface.
// If the connection was closed due to a timeout, Timeout() will be true.
OpenStreamSync(context.Context) (Stream, error)
@ -255,7 +262,7 @@ type Config struct {
GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
Versions []VersionNumber
Versions []Version
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
// If we don't receive any packet from the peer within this time, the connection attempt is aborted.
// Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted.
@ -267,11 +274,6 @@ type Config struct {
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 30 seconds.
MaxIdleTimeout time.Duration
// RequireAddressValidation determines if a QUIC Retry packet is sent.
// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
// If not set, every client is forced to prove its remote address.
RequireAddressValidation func(net.Addr) bool
// The TokenStore stores tokens received from the server.
// Tokens are used to skip address validation on future connection attempts.
// The key used to store tokens is the ServerName from the tls.Config, if set
@ -331,8 +333,15 @@ type Config struct {
Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer
}
// ClientHelloInfo contains information about an incoming connection attempt.
type ClientHelloInfo struct {
// RemoteAddr is the remote address on the Initial packet.
// Unless AddrVerified is set, the address is not yet verified, and could be a spoofed IP address.
RemoteAddr net.Addr
// AddrVerified says if the remote address was verified using QUIC's Retry mechanism.
// Note that the Retry mechanism costs one network roundtrip,
// and is not performed unless Transport.MaxUnvalidatedHandshakes is surpassed.
AddrVerified bool
}
// ConnectionState records basic details about a QUIC connection
@ -347,7 +356,7 @@ type ConnectionState struct {
// Used0RTT says if 0-RTT resumption was used.
Used0RTT bool
// Version is the QUIC version of the QUIC connection.
Version VersionNumber
Version Version
// GSO says if generic segmentation offload is used
GSO bool
}

View File

@ -20,5 +20,5 @@ func NewAckHandler(
logger utils.Logger,
) (SentPacketHandler, ReceivedPacketHandler) {
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger)
return sph, newReceivedPacketHandler(sph, rttStats, logger)
return sph, newReceivedPacketHandler(sph, logger)
}

View File

@ -14,23 +14,19 @@ type receivedPacketHandler struct {
initialPackets *receivedPacketTracker
handshakePackets *receivedPacketTracker
appDataPackets *receivedPacketTracker
appDataPackets appDataReceivedPacketTracker
lowest1RTTPacket protocol.PacketNumber
}
var _ ReceivedPacketHandler = &receivedPacketHandler{}
func newReceivedPacketHandler(
sentPackets sentPacketTracker,
rttStats *utils.RTTStats,
logger utils.Logger,
) ReceivedPacketHandler {
func newReceivedPacketHandler(sentPackets sentPacketTracker, logger utils.Logger) ReceivedPacketHandler {
return &receivedPacketHandler{
sentPackets: sentPackets,
initialPackets: newReceivedPacketTracker(rttStats, logger),
handshakePackets: newReceivedPacketTracker(rttStats, logger),
appDataPackets: newReceivedPacketTracker(rttStats, logger),
initialPackets: newReceivedPacketTracker(),
handshakePackets: newReceivedPacketTracker(),
appDataPackets: *newAppDataReceivedPacketTracker(logger),
lowest1RTTPacket: protocol.InvalidPacketNumber,
}
}
@ -88,41 +84,28 @@ func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
var initialAlarm, handshakeAlarm time.Time
if h.initialPackets != nil {
initialAlarm = h.initialPackets.GetAlarmTimeout()
}
if h.handshakePackets != nil {
handshakeAlarm = h.handshakePackets.GetAlarmTimeout()
}
oneRTTAlarm := h.appDataPackets.GetAlarmTimeout()
return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
return h.appDataPackets.GetAlarmTimeout()
}
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame {
var ack *wire.AckFrame
//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
switch encLevel {
case protocol.EncryptionInitial:
if h.initialPackets != nil {
ack = h.initialPackets.GetAckFrame(onlyIfQueued)
return h.initialPackets.GetAckFrame()
}
return nil
case protocol.EncryptionHandshake:
if h.handshakePackets != nil {
ack = h.handshakePackets.GetAckFrame(onlyIfQueued)
return h.handshakePackets.GetAckFrame()
}
return nil
case protocol.Encryption1RTT:
// 0-RTT packets can't contain ACK frames
return h.appDataPackets.GetAckFrame(onlyIfQueued)
default:
// 0-RTT packets can't contain ACK frames
return nil
}
// For Initial and Handshake ACKs, the delay time is ignored by the receiver.
// Set it to 0 in order to save bytes.
if ack != nil {
ack.DelayTime = 0
}
return ack
}
func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {

View File

@ -9,40 +9,19 @@ import (
"github.com/quic-go/quic-go/internal/wire"
)
// number of ack-eliciting packets received before sending an ack.
const packetsBeforeAck = 2
// The receivedPacketTracker tracks packets for the Initial and Handshake packet number space.
// Every received packet is acknowledged immediately.
type receivedPacketTracker struct {
largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber
largestObservedRcvdTime time.Time
ect0, ect1, ecnce uint64
packetHistory *receivedPacketHistory
packetHistory receivedPacketHistory
maxAckDelay time.Duration
rttStats *utils.RTTStats
hasNewAck bool // true as soon as we received an ack-eliciting new packet
ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets
ackElicitingPacketsReceivedSinceLastAck int
ackAlarm time.Time
lastAck *wire.AckFrame
logger utils.Logger
hasNewAck bool // true as soon as we received an ack-eliciting new packet
}
func newReceivedPacketTracker(
rttStats *utils.RTTStats,
logger utils.Logger,
) *receivedPacketTracker {
return &receivedPacketTracker{
packetHistory: newReceivedPacketHistory(),
maxAckDelay: protocol.MaxAckDelay,
rttStats: rttStats,
logger: logger,
}
func newReceivedPacketTracker() *receivedPacketTracker {
return &receivedPacketTracker{packetHistory: *newReceivedPacketHistory()}
}
func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
@ -50,16 +29,6 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn)
}
isMissing := h.isMissing(pn)
if pn >= h.largestObserved {
h.largestObserved = pn
h.largestObservedRcvdTime = rcvTime
}
if ackEliciting {
h.hasNewAck = true
h.maybeQueueACK(pn, rcvTime, ecn, isMissing)
}
//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE.
switch ecn {
case protocol.ECT0:
@ -69,12 +38,99 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
case protocol.ECNCE:
h.ecnce++
}
if !ackEliciting {
return nil
}
h.hasNewAck = true
return nil
}
func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame {
if !h.hasNewAck {
return nil
}
// This function always returns the same ACK frame struct, filled with the most recent values.
ack := h.lastAck
if ack == nil {
ack = &wire.AckFrame{}
}
ack.Reset()
ack.ECT0 = h.ect0
ack.ECT1 = h.ect1
ack.ECNCE = h.ecnce
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
h.lastAck = ack
h.hasNewAck = false
return ack
}
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
return h.packetHistory.IsPotentiallyDuplicate(pn)
}
// number of ack-eliciting packets received before sending an ACK
const packetsBeforeAck = 2
// The appDataReceivedPacketTracker tracks packets received in the Application Data packet number space.
// It waits until at least 2 packets were received before queueing an ACK, or until the max_ack_delay was reached.
type appDataReceivedPacketTracker struct {
receivedPacketTracker
largestObservedRcvdTime time.Time
largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber
maxAckDelay time.Duration
ackQueued bool // true if we need send a new ACK
ackElicitingPacketsReceivedSinceLastAck int
ackAlarm time.Time
logger utils.Logger
}
func newAppDataReceivedPacketTracker(logger utils.Logger) *appDataReceivedPacketTracker {
h := &appDataReceivedPacketTracker{
receivedPacketTracker: *newReceivedPacketTracker(),
maxAckDelay: protocol.MaxAckDelay,
logger: logger,
}
return h
}
func (h *appDataReceivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
if err := h.receivedPacketTracker.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil {
return err
}
if pn >= h.largestObserved {
h.largestObserved = pn
h.largestObservedRcvdTime = rcvTime
}
if !ackEliciting {
return nil
}
h.ackElicitingPacketsReceivedSinceLastAck++
isMissing := h.isMissing(pn)
if !h.ackQueued && h.shouldQueueACK(pn, ecn, isMissing) {
h.ackQueued = true
h.ackAlarm = time.Time{} // cancel the ack alarm
}
if !h.ackQueued {
// No ACK queued, but we'll need to acknowledge the packet after max_ack_delay.
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
}
}
return nil
}
// IgnoreBelow sets a lower limit for acknowledging packets.
// Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
func (h *appDataReceivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
if pn <= h.ignoreBelow {
return
}
@ -86,14 +142,14 @@ func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
}
// isMissing says if a packet was reported missing in the last ACK.
func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
func (h *appDataReceivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil || p < h.ignoreBelow {
return false
}
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
}
func (h *receivedPacketTracker) hasNewMissingPackets() bool {
func (h *appDataReceivedPacketTracker) hasNewMissingPackets() bool {
if h.lastAck == nil {
return false
}
@ -101,31 +157,21 @@ func (h *receivedPacketTracker) hasNewMissingPackets() bool {
return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1
}
// maybeQueueACK queues an ACK, if necessary.
func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime time.Time, ecn protocol.ECN, wasMissing bool) {
func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool {
// always acknowledge the first packet
if h.lastAck == nil {
if !h.ackQueued {
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
return true
}
h.ackQueued = true
return
}
if h.ackQueued {
return
}
h.ackElicitingPacketsReceivedSinceLastAck++
// Send an ACK if this packet was reported missing in an ACK sent before.
// Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
// missing packets we reported in the previous ACK, send an ACK immediately.
if wasMissing {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn)
}
h.ackQueued = true
return true
}
// send an ACK every 2 ack-eliciting packets
@ -133,68 +179,42 @@ func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck)
}
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
}
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
return true
}
// queue an ACK if there are new missing packets to report
if h.hasNewMissingPackets() {
h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.")
h.ackQueued = true
return true
}
// queue an ACK if the packet was ECN-CE marked
if ecn == protocol.ECNCE {
h.logger.Debugf("\tQueuing ACK because the packet was ECN-CE marked.")
h.ackQueued = true
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
return true
}
return false
}
func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
if !h.hasNewAck {
return nil
}
func (h *appDataReceivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
now := time.Now()
if onlyIfQueued {
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
if onlyIfQueued && !h.ackQueued {
if h.ackAlarm.IsZero() || h.ackAlarm.After(now) {
return nil
}
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
if h.logger.Debug() && !h.ackAlarm.IsZero() {
h.logger.Debugf("Sending ACK because the ACK timer expired.")
}
}
// This function always returns the same ACK frame struct, filled with the most recent values.
ack := h.lastAck
ack := h.receivedPacketTracker.GetAckFrame()
if ack == nil {
ack = &wire.AckFrame{}
return nil
}
ack.Reset()
ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime))
ack.ECT0 = h.ect0
ack.ECT1 = h.ect1
ack.ECNCE = h.ecnce
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.hasNewAck = false
h.ackAlarm = time.Time{}
h.ackElicitingPacketsReceivedSinceLastAck = 0
return ack
}
func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
return h.packetHistory.IsPotentiallyDuplicate(pn)
}
func (h *appDataReceivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }

View File

@ -48,10 +48,12 @@ func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
}
// UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame.
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) (updated bool) {
if offset > c.sendWindow {
c.sendWindow = offset
return true
}
return false
}
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {

View File

@ -5,7 +5,7 @@ import "github.com/quic-go/quic-go/internal/protocol"
type flowController interface {
// for sending
SendWindowSize() protocol.ByteCount
UpdateSendWindow(protocol.ByteCount)
UpdateSendWindow(protocol.ByteCount) (updated bool)
AddBytesSent(protocol.ByteCount)
// for receiving
AddBytesRead(protocol.ByteCount)
@ -16,12 +16,11 @@ type flowController interface {
// A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface {
flowController
// for receiving
// UpdateHighestReceived should be called when a new highest offset is received
// UpdateHighestReceived is called when a new highest offset is received
// final has to be to true if this is the final offset of the stream,
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
// Abandon should be called when reading from the stream is aborted early,
// Abandon is called when reading from the stream is aborted early,
// and there won't be any further calls to AddBytesRead.
Abandon()
}

View File

@ -1,13 +1,12 @@
package handshake
import (
"crypto/cipher"
"encoding/binary"
"github.com/quic-go/quic-go/internal/protocol"
)
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.Version) *xorNonceAEAD {
keyLabel := hkdfLabelKeyV1
ivLabel := hkdfLabelIVV1
if v == protocol.Version2 {
@ -20,28 +19,26 @@ func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumb
}
type longHeaderSealer struct {
aead cipher.AEAD
aead *xorNonceAEAD
headerProtector headerProtector
// use a single slice to avoid allocations
nonceBuf []byte
nonceBuf [8]byte
}
var _ LongHeaderSealer = &longHeaderSealer{}
func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
func newLongHeaderSealer(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderSealer {
if aead.NonceSize() != 8 {
panic("unexpected nonce size")
}
return &longHeaderSealer{
aead: aead,
headerProtector: headerProtector,
nonceBuf: make([]byte, aead.NonceSize()),
}
}
func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return s.aead.Seal(dst, s.nonceBuf, src, ad)
binary.BigEndian.PutUint64(s.nonceBuf[:], uint64(pn))
return s.aead.Seal(dst, s.nonceBuf[:], src, ad)
}
func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
@ -53,21 +50,23 @@ func (s *longHeaderSealer) Overhead() int {
}
type longHeaderOpener struct {
aead cipher.AEAD
aead *xorNonceAEAD
headerProtector headerProtector
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
// use a single slice to avoid allocations
nonceBuf []byte
// use a single array to avoid allocations
nonceBuf [8]byte
}
var _ LongHeaderOpener = &longHeaderOpener{}
func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
func newLongHeaderOpener(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderOpener {
if aead.NonceSize() != 8 {
panic("unexpected nonce size")
}
return &longHeaderOpener{
aead: aead,
headerProtector: headerProtector,
nonceBuf: make([]byte, aead.NonceSize()),
}
}
@ -76,10 +75,8 @@ func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wire
}
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
binary.BigEndian.PutUint64(o.nonceBuf[:], uint64(pn))
dec, err := o.aead.Open(dst, o.nonceBuf[:], src, ad)
if err == nil {
o.highestRcvdPN = max(o.highestRcvdPN, pn)
} else {

View File

@ -18,7 +18,7 @@ type cipherSuite struct {
ID uint16
Hash crypto.Hash
KeyLen int
AEAD func(key, nonceMask []byte) cipher.AEAD
AEAD func(key, nonceMask []byte) *xorNonceAEAD
}
func (s cipherSuite) IVLen() int { return aeadNonceLength }
@ -36,7 +36,7 @@ func getCipherSuite(id uint16) *cipherSuite {
}
}
func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
func aeadAESGCMTLS13(key, nonceMask []byte) *xorNonceAEAD {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
@ -54,7 +54,7 @@ func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
return ret
}
func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD {
func aeadChaCha20Poly1305(key, nonceMask []byte) *xorNonceAEAD {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}

View File

@ -8,7 +8,6 @@ import (
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"time"
@ -33,7 +32,7 @@ type cryptoSetup struct {
events []Event
version protocol.VersionNumber
version protocol.Version
ourParams *wire.TransportParameters
peerParams *wire.TransportParameters
@ -48,8 +47,6 @@ type cryptoSetup struct {
perspective protocol.Perspective
mutex sync.Mutex // protects all members below
handshakeCompleteTime time.Time
zeroRTTOpener LongHeaderOpener // only set for the server
@ -79,7 +76,7 @@ func NewCryptoSetupClient(
rttStats *utils.RTTStats,
tracer *logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
version protocol.Version,
) CryptoSetup {
cs := newCryptoSetup(
connID,
@ -114,7 +111,7 @@ func NewCryptoSetupServer(
rttStats *utils.RTTStats,
tracer *logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
version protocol.Version,
) CryptoSetup {
cs := newCryptoSetup(
connID,
@ -172,7 +169,7 @@ func newCryptoSetup(
tracer *logging.ConnectionTracer,
logger utils.Logger,
perspective protocol.Perspective,
version protocol.VersionNumber,
version protocol.Version,
) *cryptoSetup {
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
if tracer != nil && tracer.UpdatedKeyFromTLS != nil {
@ -269,10 +266,10 @@ func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) {
case tls.QUICNoEvent:
return true, nil
case tls.QUICSetReadSecret:
h.SetReadKey(ev.Level, ev.Suite, ev.Data)
h.setReadKey(ev.Level, ev.Suite, ev.Data)
return false, nil
case tls.QUICSetWriteSecret:
h.SetWriteKey(ev.Level, ev.Suite, ev.Data)
h.setWriteKey(ev.Level, ev.Suite, ev.Data)
return false, nil
case tls.QUICTransportParameters:
return false, h.handleTransportParameters(ev.Data)
@ -434,19 +431,16 @@ func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bo
func (h *cryptoSetup) rejected0RTT() {
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
h.mutex.Lock()
had0RTTKeys := h.zeroRTTSealer != nil
h.zeroRTTSealer = nil
h.mutex.Unlock()
if had0RTTKeys {
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
}
}
func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID)
h.mutex.Lock()
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el {
case tls.QUICEncryptionLevelEarly:
@ -478,16 +472,14 @@ func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra
default:
panic("unexpected read encryption level")
}
h.mutex.Unlock()
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
}
}
func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID)
h.mutex.Lock()
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el {
case tls.QUICEncryptionLevelEarly:
@ -498,7 +490,6 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
)
h.mutex.Unlock()
if h.logger.Debug() {
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
}
@ -533,7 +524,6 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
default:
panic("unexpected write encryption level")
}
h.mutex.Unlock()
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
}
@ -555,11 +545,9 @@ func (h *cryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) {
}
func (h *cryptoSetup) DiscardInitialKeys() {
h.mutex.Lock()
dropped := h.initialOpener != nil
h.initialOpener = nil
h.initialSealer = nil
h.mutex.Unlock()
if dropped {
h.logger.Debugf("Dropping Initial keys.")
}
@ -574,22 +562,17 @@ func (h *cryptoSetup) SetHandshakeConfirmed() {
h.aead.SetHandshakeConfirmed()
// drop Handshake keys
var dropped bool
h.mutex.Lock()
if h.handshakeOpener != nil {
h.handshakeOpener = nil
h.handshakeSealer = nil
dropped = true
}
h.mutex.Unlock()
if dropped {
h.logger.Debugf("Dropping Handshake keys.")
}
}
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.initialSealer == nil {
return nil, ErrKeysDropped
}
@ -597,9 +580,6 @@ func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
}
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTSealer == nil {
return nil, ErrKeysDropped
}
@ -607,9 +587,6 @@ func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
}
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.handshakeSealer == nil {
if h.initialSealer == nil {
return nil, ErrKeysDropped
@ -620,9 +597,6 @@ func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
}
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if !h.has1RTTSealer {
return nil, ErrKeysNotYetAvailable
}
@ -630,9 +604,6 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
}
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.initialOpener == nil {
return nil, ErrKeysDropped
}
@ -640,9 +611,6 @@ func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
}
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener == nil {
if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable
@ -654,9 +622,6 @@ func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
}
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.handshakeOpener == nil {
if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable
@ -668,9 +633,6 @@ func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
}
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
h.zeroRTTOpener = nil
h.logger.Debugf("Dropping 0-RTT keys.")

View File

@ -17,14 +17,14 @@ type headerProtector interface {
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
}
func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string {
func hkdfHeaderProtectionLabel(v protocol.Version) string {
if v == protocol.Version2 {
return "quicv2 hp"
}
return "quic hp"
}
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.Version) headerProtector {
hkdfLabel := hkdfHeaderProtectionLabel(v)
switch suite.ID {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
@ -37,7 +37,7 @@ func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader b
}
type aesHeaderProtector struct {
mask []byte
mask [16]byte // AES always has a 16 byte block size
block cipher.Block
isLongHeader bool
}
@ -52,7 +52,6 @@ func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeade
}
return &aesHeaderProtector{
block: block,
mask: make([]byte, block.BlockSize()),
isLongHeader: isLongHeader,
}
}
@ -69,7 +68,7 @@ func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []by
if len(sample) != len(p.mask) {
panic("invalid sample size")
}
p.block.Encrypt(p.mask, sample)
p.block.Encrypt(p.mask[:], sample)
if p.isLongHeader {
*firstByte ^= p.mask[0] & 0xf
} else {

View File

@ -7,7 +7,7 @@ import (
"golang.org/x/crypto/hkdf"
)
// hkdfExpandLabel HKDF expands a label.
// hkdfExpandLabel HKDF expands a label as defined in RFC 8446, section 7.1.
// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
// hkdfExpandLabel in the standard library.
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {

View File

@ -21,7 +21,7 @@ const (
hkdfLabelIVV2 = "quicv2 iv"
)
func getSalt(v protocol.VersionNumber) []byte {
func getSalt(v protocol.Version) []byte {
if v == protocol.Version2 {
return quicSaltV2
}
@ -31,7 +31,7 @@ func getSalt(v protocol.VersionNumber) []byte {
var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256)
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) {
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.Version) (LongHeaderSealer, LongHeaderOpener) {
clientSecret, serverSecret := computeSecrets(connID, v)
var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient {
@ -51,14 +51,14 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))
}
func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) {
func computeSecrets(connID protocol.ConnectionID, v protocol.Version) (clientSecret, serverSecret []byte) {
initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v))
clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
return
}
func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) {
func computeInitialKeyAndIV(secret []byte, v protocol.Version) (key, iv []byte) {
keyLabel := hkdfLabelKeyV1
ivLabel := hkdfLabelIVV1
if v == protocol.Version2 {

View File

@ -40,7 +40,7 @@ var (
)
// GetRetryIntegrityTag calculates the integrity tag on a Retry packet
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte {
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.Version) *[16]byte {
retryMutex.Lock()
defer retryMutex.Unlock()

View File

@ -59,7 +59,7 @@ type updatableAEAD struct {
tracer *logging.ConnectionTracer
logger utils.Logger
version protocol.VersionNumber
version protocol.Version
// use a single slice to avoid allocations
nonceBuf []byte
@ -70,7 +70,7 @@ var (
_ ShortHeaderSealer = &updatableAEAD{}
)
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD {
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.Version) *updatableAEAD {
return &updatableAEAD{
firstPacketNumber: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber,
@ -133,7 +133,7 @@ func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) {
// SetWriteKey sets the write key.
// For the client, this function is called after SetReadKey.
// For the server, this function is called before SetWriteKey.
// For the server, this function is called before SetReadKey.
func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)

View File

@ -17,9 +17,9 @@ func (p Perspective) Opposite() Perspective {
func (p Perspective) String() string {
switch p {
case PerspectiveServer:
return "Server"
return "server"
case PerspectiveClient:
return "Client"
return "client"
default:
return "invalid perspective"
}

View File

@ -1,14 +1,17 @@
package protocol
import (
"crypto/rand"
"encoding/binary"
"fmt"
"math"
"sync"
"time"
"golang.org/x/exp/rand"
)
// VersionNumber is a version number as int
type VersionNumber uint32
// Version is a version number as int
type Version uint32
// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
const (
@ -18,22 +21,22 @@ const (
// The version numbers, making grepping easier
const (
VersionUnknown VersionNumber = math.MaxUint32
versionDraft29 VersionNumber = 0xff00001d // draft-29 used to be a widely deployed version
Version1 VersionNumber = 0x1
Version2 VersionNumber = 0x6b3343cf
VersionUnknown Version = math.MaxUint32
versionDraft29 Version = 0xff00001d // draft-29 used to be a widely deployed version
Version1 Version = 0x1
Version2 Version = 0x6b3343cf
)
// SupportedVersions lists the versions that the server supports
// must be in sorted descending order
var SupportedVersions = []VersionNumber{Version1, Version2}
var SupportedVersions = []Version{Version1, Version2}
// IsValidVersion says if the version is known to quic-go
func IsValidVersion(v VersionNumber) bool {
func IsValidVersion(v Version) bool {
return v == Version1 || IsSupportedVersion(SupportedVersions, v)
}
func (vn VersionNumber) String() string {
func (vn Version) String() string {
//nolint:exhaustive
switch vn {
case VersionUnknown:
@ -52,16 +55,16 @@ func (vn VersionNumber) String() string {
}
}
func (vn VersionNumber) isGQUIC() bool {
func (vn Version) isGQUIC() bool {
return vn > gquicVersion0 && vn <= maxGquicVersion
}
func (vn VersionNumber) toGQUICVersion() int {
func (vn Version) toGQUICVersion() int {
return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
}
// IsSupportedVersion returns true if the server supports this version
func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
func IsSupportedVersion(supported []Version, v Version) bool {
for _, t := range supported {
if t == v {
return true
@ -74,7 +77,7 @@ func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
// ours is a slice of versions that we support, sorted by our preference (descending)
// theirs is a slice of versions offered by the peer. The order does not matter.
// The bool returned indicates if a matching version was found.
func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) {
func ChooseSupportedVersion(ours, theirs []Version) (Version, bool) {
for _, ourVer := range ours {
for _, theirVer := range theirs {
if ourVer == theirVer {
@ -85,19 +88,25 @@ func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool)
return 0, false
}
// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a)
func generateReservedVersion() VersionNumber {
b := make([]byte, 4)
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa)
var (
versionNegotiationMx sync.Mutex
versionNegotiationRand = rand.New(rand.NewSource(uint64(time.Now().UnixNano())))
)
// generateReservedVersion generates a reserved version (v & 0x0f0f0f0f == 0x0a0a0a0a)
func generateReservedVersion() Version {
var b [4]byte
_, _ = versionNegotiationRand.Read(b[:]) // ignore the error here. Failure to read random data doesn't break anything
return Version((binary.BigEndian.Uint32(b[:]) | 0x0a0a0a0a) & 0xfafafafa)
}
// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position
func GetGreasedVersions(supported []VersionNumber) []VersionNumber {
b := make([]byte, 1)
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
randPos := int(b[0]) % (len(supported) + 1)
greased := make([]VersionNumber, len(supported)+1)
// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position.
// It doesn't modify the supported slice.
func GetGreasedVersions(supported []Version) []Version {
versionNegotiationMx.Lock()
defer versionNegotiationMx.Unlock()
randPos := rand.Intn(len(supported) + 1)
greased := make([]Version, len(supported)+1)
copy(greased, supported[:randPos])
greased[randPos] = generateReservedVersion()
copy(greased[randPos+1:], supported[randPos:])

View File

@ -101,8 +101,8 @@ func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.Err
// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version.
type VersionNegotiationError struct {
Ours []protocol.VersionNumber
Theirs []protocol.VersionNumber
Ours []protocol.Version
Theirs []protocol.Version
}
func (e *VersionNegotiationError) Error() string {

View File

@ -1,23 +1,11 @@
package qtls
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"fmt"
"unsafe"
)
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13
var cipherSuitesTLS13 []unsafe.Pointer

View File

@ -1,12 +1,12 @@
//go:build go1.21
package qtls
import (
"crypto/tls"
"sync"
)
type clientSessionCache struct {
mx sync.Mutex
getData func(earlyData bool) []byte
setData func(data []byte, earlyData bool) (allowEarlyData bool)
wrapped tls.ClientSessionCache
@ -14,7 +14,10 @@ type clientSessionCache struct {
var _ tls.ClientSessionCache = &clientSessionCache{}
func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
func (c *clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
c.mx.Lock()
defer c.mx.Unlock()
if cs == nil {
c.wrapped.Put(key, nil)
return
@ -34,7 +37,10 @@ func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
c.wrapped.Put(key, newCS)
}
func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
func (c *clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
c.mx.Lock()
defer c.mx.Unlock()
cs, ok := c.wrapped.Get(key)
if !ok || cs == nil {
return cs, ok

View File

@ -27,18 +27,6 @@ func MinTime(a, b time.Time) time.Time {
return a
}
// MinNonZeroTime returns the earlist time that is not time.Time{}
// If both a and b are time.Time{}, it returns time.Time{}
func MinNonZeroTime(a, b time.Time) time.Time {
if a.IsZero() {
return b
}
if b.IsZero() {
return a
}
return MinTime(a, b)
}
// MaxTime returns the later time
func MaxTime(a, b time.Time) time.Time {
if a.After(b) {

View File

@ -22,7 +22,7 @@ type AckFrame struct {
}
// parseAckFrame reads an ACK frame
func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) error {
func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.Version) error {
ecn := typ == ackECNFrameType
la, err := quicvarint.Read(r)
@ -110,7 +110,7 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen
}
// Append appends an ACK frame.
func (f *AckFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *AckFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
if hasECN {
b = append(b, ackECNFrameType)
@ -143,7 +143,7 @@ func (f *AckFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
}
// Length of a written frame
func (f *AckFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
func (f *AckFrame) Length(_ protocol.Version) protocol.ByteCount {
largestAcked := f.AckRanges[0].Largest
numRanges := f.numEncodableAckRanges()

View File

@ -16,7 +16,7 @@ type ConnectionCloseFrame struct {
ReasonPhrase string
}
func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*ConnectionCloseFrame, error) {
func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, error) {
f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType}
ec, err := quicvarint.Read(r)
if err != nil {
@ -53,7 +53,7 @@ func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNu
}
// Length of a written frame
func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount {
func (f *ConnectionCloseFrame) Length(protocol.Version) protocol.ByteCount {
length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
if !f.IsApplicationError {
length += quicvarint.Len(f.FrameType) // for the frame type
@ -61,7 +61,7 @@ func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount
return length
}
func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
if f.IsApplicationError {
b = append(b, applicationCloseFrameType)
} else {

View File

@ -14,7 +14,7 @@ type CryptoFrame struct {
Data []byte
}
func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) {
func parseCryptoFrame(r *bytes.Reader, _ protocol.Version) (*CryptoFrame, error) {
frame := &CryptoFrame{}
offset, err := quicvarint.Read(r)
if err != nil {
@ -38,7 +38,7 @@ func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame,
return frame, nil
}
func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, cryptoFrameType)
b = quicvarint.Append(b, uint64(f.Offset))
b = quicvarint.Append(b, uint64(len(f.Data)))
@ -47,7 +47,7 @@ func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error)
}
// Length of a written frame
func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
func (f *CryptoFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data))
}
@ -71,7 +71,7 @@ func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount
// The frame might not be split if:
// * the size is large enough to fit the whole frame
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*CryptoFrame, bool /* was splitting required */) {
func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*CryptoFrame, bool /* was splitting required */) {
if f.Length(version) <= maxSize {
return nil, false
}

View File

@ -12,7 +12,7 @@ type DataBlockedFrame struct {
MaximumData protocol.ByteCount
}
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) {
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*DataBlockedFrame, error) {
offset, err := quicvarint.Read(r)
if err != nil {
return nil, err
@ -20,12 +20,12 @@ func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBloc
return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil
}
func (f *DataBlockedFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) {
func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
b = append(b, dataBlockedFrameType)
return quicvarint.Append(b, uint64(f.MaximumData)), nil
}
// Length of a written frame
func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
func (f *DataBlockedFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.MaximumData))
}

View File

@ -20,7 +20,7 @@ type DatagramFrame struct {
Data []byte
}
func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*DatagramFrame, error) {
func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*DatagramFrame, error) {
f := &DatagramFrame{}
f.DataLenPresent = typ&0x1 > 0
@ -45,7 +45,7 @@ func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (
return f, nil
}
func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *DatagramFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
typ := uint8(0x30)
if f.DataLenPresent {
typ ^= 0b1
@ -59,7 +59,7 @@ func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, erro
}
// MaxDataLen returns the maximum data length
func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount {
func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount {
headerLen := protocol.ByteCount(1)
if f.DataLenPresent {
// pretend that the data size will be 1 bytes
@ -77,7 +77,7 @@ func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.
}
// Length of a written frame
func (f *DatagramFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
func (f *DatagramFrame) Length(_ protocol.Version) protocol.ByteCount {
length := 1 + protocol.ByteCount(len(f.Data))
if f.DataLenPresent {
length += quicvarint.Len(uint64(len(f.Data)))

View File

@ -32,7 +32,7 @@ type ExtendedHeader struct {
parsedLen protocol.ByteCount
}
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) {
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.Version) (bool /* reserved bits valid */, error) {
startLen := b.Len()
// read the (now unencrypted) first byte
var err error
@ -51,7 +51,7 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool
return reservedBitsValid, err
}
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.Version) (bool /* reserved bits valid */, error) {
if err := h.readPacketNumber(b); err != nil {
return false, err
}
@ -95,7 +95,7 @@ func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
}
// Append appends the Header.
func (h *ExtendedHeader) Append(b []byte, v protocol.VersionNumber) ([]byte, error) {
func (h *ExtendedHeader) Append(b []byte, v protocol.Version) ([]byte, error) {
if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
}
@ -162,7 +162,7 @@ func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
}
// GetLength determines the length of the Header.
func (h *ExtendedHeader) GetLength(_ protocol.VersionNumber) protocol.ByteCount {
func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount {
length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
if h.Type == protocol.PacketTypeInitial {
length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))

View File

@ -36,7 +36,8 @@ const (
handshakeDoneFrameType = 0x1e
)
type frameParser struct {
// The FrameParser parses QUIC frames, one by one.
type FrameParser struct {
r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them
ackDelayExponent uint8
@ -47,11 +48,9 @@ type frameParser struct {
ackFrame *AckFrame
}
var _ FrameParser = &frameParser{}
// NewFrameParser creates a new frame parser.
func NewFrameParser(supportsDatagrams bool) *frameParser {
return &frameParser{
func NewFrameParser(supportsDatagrams bool) *FrameParser {
return &FrameParser{
r: *bytes.NewReader(nil),
supportsDatagrams: supportsDatagrams,
ackFrame: &AckFrame{},
@ -60,7 +59,7 @@ func NewFrameParser(supportsDatagrams bool) *frameParser {
// ParseNext parses the next frame.
// It skips PADDING frames.
func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (int, Frame, error) {
func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) {
startLen := len(data)
p.r.Reset(data)
frame, err := p.parseNext(&p.r, encLevel, v)
@ -69,7 +68,7 @@ func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel,
return n, frame, err
}
func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) {
func (p *FrameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
for r.Len() != 0 {
typ, err := quicvarint.Read(r)
if err != nil {
@ -95,7 +94,7 @@ func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLev
return nil, nil
}
func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) {
func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
var frame Frame
var err error
if typ&0xf8 == 0x8 {
@ -163,7 +162,7 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
return frame, nil
}
func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
switch encLevel {
case protocol.EncryptionInitial, protocol.EncryptionHandshake:
switch f.(type) {
@ -186,6 +185,8 @@ func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionL
}
}
func (p *frameParser) SetAckDelayExponent(exp uint8) {
// SetAckDelayExponent sets the acknowledgment delay exponent (sent in the transport parameters).
// This value is used to scale the ACK Delay field in the ACK frame.
func (p *FrameParser) SetAckDelayExponent(exp uint8) {
p.ackDelayExponent = exp
}

View File

@ -7,11 +7,11 @@ import (
// A HandshakeDoneFrame is a HANDSHAKE_DONE frame
type HandshakeDoneFrame struct{}
func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
return append(b, handshakeDoneFrameType), nil
}
// Length of a written frame
func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
func (f *HandshakeDoneFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1
}

View File

@ -85,11 +85,11 @@ func IsLongHeaderPacket(firstByte byte) bool {
// ParseVersion parses the QUIC version.
// It should only be called for Long Header packets (Short Header packets don't contain a version number).
func ParseVersion(data []byte) (protocol.VersionNumber, error) {
func ParseVersion(data []byte) (protocol.Version, error) {
if len(data) < 5 {
return 0, io.EOF
}
return protocol.VersionNumber(binary.BigEndian.Uint32(data[1:5])), nil
return protocol.Version(binary.BigEndian.Uint32(data[1:5])), nil
}
// IsVersionNegotiationPacket says if this is a version negotiation packet
@ -109,7 +109,7 @@ func Is0RTTPacket(b []byte) bool {
if !IsLongHeaderPacket(b[0]) {
return false
}
version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5]))
version := protocol.Version(binary.BigEndian.Uint32(b[1:5]))
//nolint:exhaustive // We only need to test QUIC versions that we support.
switch version {
case protocol.Version1:
@ -128,7 +128,7 @@ type Header struct {
typeByte byte
Type protocol.PacketType
Version protocol.VersionNumber
Version protocol.Version
SrcConnectionID protocol.ConnectionID
DestConnectionID protocol.ConnectionID
@ -184,7 +184,7 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error {
if err != nil {
return err
}
h.Version = protocol.VersionNumber(v)
h.Version = protocol.Version(v)
if h.Version != 0 && h.typeByte&0x40 == 0 {
return errors.New("not a QUIC packet")
}
@ -278,7 +278,7 @@ func (h *Header) ParsedLen() protocol.ByteCount {
// ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header.
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) {
extHdr := h.toExtendedHeader()
reservedBitsValid, err := extHdr.parse(b, ver)
if err != nil {

View File

@ -6,12 +6,6 @@ import (
// A Frame in QUIC
type Frame interface {
Append(b []byte, version protocol.VersionNumber) ([]byte, error)
Length(version protocol.VersionNumber) protocol.ByteCount
}
// A FrameParser parses QUIC frames, one by one.
type FrameParser interface {
ParseNext([]byte, protocol.EncryptionLevel, protocol.VersionNumber) (int, Frame, error)
SetAckDelayExponent(uint8)
Append(b []byte, version protocol.Version) ([]byte, error)
Length(version protocol.Version) protocol.ByteCount
}

View File

@ -63,7 +63,9 @@ func LogFrame(logger utils.Logger, frame Frame, sent bool) {
logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit)
}
case *NewConnectionIDFrame:
logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.ConnectionID, f.StatelessResetToken)
logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, RetirePriorTo: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.RetirePriorTo, f.ConnectionID, f.StatelessResetToken)
case *RetireConnectionIDFrame:
logger.Debugf("\t%s &wire.RetireConnectionIDFrame{SequenceNumber: %d}", dir, f.SequenceNumber)
case *NewTokenFrame:
logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token)
default:

View File

@ -13,7 +13,7 @@ type MaxDataFrame struct {
}
// parseMaxDataFrame parses a MAX_DATA frame
func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame, error) {
func parseMaxDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxDataFrame, error) {
frame := &MaxDataFrame{}
byteOffset, err := quicvarint.Read(r)
if err != nil {
@ -23,13 +23,13 @@ func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame
return frame, nil
}
func (f *MaxDataFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, maxDataFrameType)
b = quicvarint.Append(b, uint64(f.MaximumData))
return b, nil
}
// Length of a written frame
func (f *MaxDataFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
func (f *MaxDataFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.MaximumData))
}

View File

@ -13,7 +13,7 @@ type MaxStreamDataFrame struct {
MaximumStreamData protocol.ByteCount
}
func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamDataFrame, error) {
func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxStreamDataFrame, error) {
sid, err := quicvarint.Read(r)
if err != nil {
return nil, err
@ -29,7 +29,7 @@ func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStr
}, nil
}
func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) {
func (f *MaxStreamDataFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
b = append(b, maxStreamDataFrameType)
b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.MaximumStreamData))
@ -37,6 +37,6 @@ func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([
}
// Length of a written frame
func (f *MaxStreamDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
func (f *MaxStreamDataFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
}

View File

@ -14,7 +14,7 @@ type MaxStreamsFrame struct {
MaxStreamNum protocol.StreamNum
}
func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*MaxStreamsFrame, error) {
func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*MaxStreamsFrame, error) {
f := &MaxStreamsFrame{}
switch typ {
case bidiMaxStreamsFrameType:
@ -33,7 +33,7 @@ func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber)
return f, nil
}
func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
switch f.Type {
case protocol.StreamTypeBidi:
b = append(b, bidiMaxStreamsFrameType)
@ -45,6 +45,6 @@ func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, er
}
// Length of a written frame
func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount {
func (f *MaxStreamsFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.MaxStreamNum))
}

View File

@ -18,7 +18,7 @@ type NewConnectionIDFrame struct {
StatelessResetToken protocol.StatelessResetToken
}
func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewConnectionIDFrame, error) {
func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*NewConnectionIDFrame, error) {
seq, err := quicvarint.Read(r)
if err != nil {
return nil, err
@ -57,7 +57,7 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC
return frame, nil
}
func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, newConnectionIDFrameType)
b = quicvarint.Append(b, f.SequenceNumber)
b = quicvarint.Append(b, f.RetirePriorTo)
@ -72,6 +72,6 @@ func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byt
}
// Length of a written frame
func (f *NewConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount {
func (f *NewConnectionIDFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16
}

View File

@ -14,7 +14,7 @@ type NewTokenFrame struct {
Token []byte
}
func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFrame, error) {
func parseNewTokenFrame(r *bytes.Reader, _ protocol.Version) (*NewTokenFrame, error) {
tokenLen, err := quicvarint.Read(r)
if err != nil {
return nil, err
@ -32,7 +32,7 @@ func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFra
return &NewTokenFrame{Token: token}, nil
}
func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, newTokenFrameType)
b = quicvarint.Append(b, uint64(len(f.Token)))
b = append(b, f.Token...)
@ -40,6 +40,6 @@ func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, erro
}
// Length of a written frame
func (f *NewTokenFrame) Length(protocol.VersionNumber) protocol.ByteCount {
func (f *NewTokenFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token))
}

View File

@ -12,7 +12,7 @@ type PathChallengeFrame struct {
Data [8]byte
}
func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathChallengeFrame, error) {
func parsePathChallengeFrame(r *bytes.Reader, _ protocol.Version) (*PathChallengeFrame, error) {
frame := &PathChallengeFrame{}
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
if err == io.ErrUnexpectedEOF {
@ -23,13 +23,13 @@ func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathCh
return frame, nil
}
func (f *PathChallengeFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, pathChallengeFrameType)
b = append(b, f.Data[:]...)
return b, nil
}
// Length of a written frame
func (f *PathChallengeFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
func (f *PathChallengeFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + 8
}

View File

@ -12,7 +12,7 @@ type PathResponseFrame struct {
Data [8]byte
}
func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathResponseFrame, error) {
func parsePathResponseFrame(r *bytes.Reader, _ protocol.Version) (*PathResponseFrame, error) {
frame := &PathResponseFrame{}
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
if err == io.ErrUnexpectedEOF {
@ -23,13 +23,13 @@ func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathRes
return frame, nil
}
func (f *PathResponseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, pathResponseFrameType)
b = append(b, f.Data[:]...)
return b, nil
}
// Length of a written frame
func (f *PathResponseFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
func (f *PathResponseFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + 8
}

View File

@ -7,11 +7,11 @@ import (
// A PingFrame is a PING frame
type PingFrame struct{}
func (f *PingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *PingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
return append(b, pingFrameType), nil
}
// Length of a written frame
func (f *PingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
func (f *PingFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1
}

View File

@ -15,7 +15,7 @@ type ResetStreamFrame struct {
FinalSize protocol.ByteCount
}
func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStreamFrame, error) {
func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFrame, error) {
var streamID protocol.StreamID
var byteOffset protocol.ByteCount
sid, err := quicvarint.Read(r)
@ -40,7 +40,7 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStr
}, nil
}
func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, resetStreamFrameType)
b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.ErrorCode))
@ -49,6 +49,6 @@ func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, e
}
// Length of a written frame
func (f *ResetStreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
func (f *ResetStreamFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize))
}

View File

@ -12,7 +12,7 @@ type RetireConnectionIDFrame struct {
SequenceNumber uint64
}
func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*RetireConnectionIDFrame, error) {
func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*RetireConnectionIDFrame, error) {
seq, err := quicvarint.Read(r)
if err != nil {
return nil, err
@ -20,13 +20,13 @@ func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*R
return &RetireConnectionIDFrame{SequenceNumber: seq}, nil
}
func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, retireConnectionIDFrameType)
b = quicvarint.Append(b, f.SequenceNumber)
return b, nil
}
// Length of a written frame
func (f *RetireConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount {
func (f *RetireConnectionIDFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(f.SequenceNumber)
}

View File

@ -15,7 +15,7 @@ type StopSendingFrame struct {
}
// parseStopSendingFrame parses a STOP_SENDING frame
func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) {
func parseStopSendingFrame(r *bytes.Reader, _ protocol.Version) (*StopSendingFrame, error) {
streamID, err := quicvarint.Read(r)
if err != nil {
return nil, err
@ -32,11 +32,11 @@ func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSend
}
// Length of a written frame
func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
func (f *StopSendingFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode))
}
func (f *StopSendingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *StopSendingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, stopSendingFrameType)
b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.ErrorCode))

View File

@ -13,7 +13,7 @@ type StreamDataBlockedFrame struct {
MaximumStreamData protocol.ByteCount
}
func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamDataBlockedFrame, error) {
func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*StreamDataBlockedFrame, error) {
sid, err := quicvarint.Read(r)
if err != nil {
return nil, err
@ -29,7 +29,7 @@ func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*St
}, nil
}
func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, 0x15)
b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.MaximumStreamData))
@ -37,6 +37,6 @@ func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]b
}
// Length of a written frame
func (f *StreamDataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
func (f *StreamDataBlockedFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
}

View File

@ -20,7 +20,7 @@ type StreamFrame struct {
fromPool bool
}
func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamFrame, error) {
func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamFrame, error) {
hasOffset := typ&0b100 > 0
fin := typ&0b1 > 0
hasDataLen := typ&0b10 > 0
@ -79,7 +79,7 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*S
}
// Write writes a STREAM frame
func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
if len(f.Data) == 0 && !f.Fin {
return nil, errors.New("StreamFrame: attempting to write empty frame without FIN")
}
@ -108,7 +108,7 @@ func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error)
}
// Length returns the total length of the STREAM frame
func (f *StreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
func (f *StreamFrame) Length(version protocol.Version) protocol.ByteCount {
length := 1 + quicvarint.Len(uint64(f.StreamID))
if f.Offset != 0 {
length += quicvarint.Len(uint64(f.Offset))
@ -126,7 +126,7 @@ func (f *StreamFrame) DataLen() protocol.ByteCount {
// MaxDataLen returns the maximum data length
// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data).
func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount {
func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount {
headerLen := 1 + quicvarint.Len(uint64(f.StreamID))
if f.Offset != 0 {
headerLen += quicvarint.Len(uint64(f.Offset))
@ -151,7 +151,7 @@ func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Ve
// The frame might not be split if:
// * the size is large enough to fit the whole frame
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*StreamFrame, bool /* was splitting required */) {
func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*StreamFrame, bool /* was splitting required */) {
if maxSize >= f.Length(version) {
return nil, false
}

View File

@ -14,7 +14,7 @@ type StreamsBlockedFrame struct {
StreamLimit protocol.StreamNum
}
func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) {
func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, error) {
f := &StreamsBlockedFrame{}
switch typ {
case bidiStreamBlockedFrameType:
@ -33,7 +33,7 @@ func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNum
return f, nil
}
func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
switch f.Type {
case protocol.StreamTypeBidi:
b = append(b, bidiStreamBlockedFrameType)
@ -45,6 +45,6 @@ func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte
}
// Length of a written frame
func (f *StreamsBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
func (f *StreamsBlockedFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamLimit))
}

View File

@ -7,7 +7,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"sort"
"time"
@ -51,10 +51,7 @@ const (
// PreferredAddress is the value encoding in the preferred_address transport parameter
type PreferredAddress struct {
IPv4 net.IP
IPv4Port uint16
IPv6 net.IP
IPv6Port uint16
IPv4, IPv6 netip.AddrPort
ConnectionID protocol.ConnectionID
StatelessResetToken protocol.StatelessResetToken
}
@ -218,26 +215,24 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error {
remainingLen := r.Len()
pa := &PreferredAddress{}
ipv4 := make([]byte, 4)
if _, err := io.ReadFull(r, ipv4); err != nil {
var ipv4 [4]byte
if _, err := io.ReadFull(r, ipv4[:]); err != nil {
return err
}
pa.IPv4 = net.IP(ipv4)
port, err := utils.BigEndian.ReadUint16(r)
if err != nil {
return err
}
pa.IPv4Port = port
ipv6 := make([]byte, 16)
if _, err := io.ReadFull(r, ipv6); err != nil {
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port)
var ipv6 [16]byte
if _, err := io.ReadFull(r, ipv6[:]); err != nil {
return err
}
pa.IPv6 = net.IP(ipv6)
port, err = utils.BigEndian.ReadUint16(r)
if err != nil {
return err
}
pa.IPv6Port = port
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port)
connIDLen, err := r.ReadByte()
if err != nil {
return err
@ -384,13 +379,12 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
if p.PreferredAddress != nil {
b = quicvarint.Append(b, uint64(preferredAddressParameterID))
b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16)
ipv4 := p.PreferredAddress.IPv4
b = append(b, ipv4[len(ipv4)-4:]...)
b = append(b, []byte{0, 0}...)
binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv4Port)
b = append(b, p.PreferredAddress.IPv6...)
b = append(b, []byte{0, 0}...)
binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv6Port)
ip4 := p.PreferredAddress.IPv4.Addr().As4()
b = append(b, ip4[:]...)
b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv4.Port())
ip6 := p.PreferredAddress.IPv6.Addr().As16()
b = append(b, ip6[:]...)
b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv6.Port())
b = append(b, uint8(p.PreferredAddress.ConnectionID.Len()))
b = append(b, p.PreferredAddress.ConnectionID.Bytes()...)
b = append(b, p.PreferredAddress.StatelessResetToken[:]...)

View File

@ -1,17 +1,15 @@
package wire
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
// ParseVersionNegotiationPacket parses a Version Negotiation packet.
func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber, _ error) {
func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.Version, _ error) {
n, dest, src, err := ParseArbitraryLenConnectionIDs(b)
if err != nil {
return nil, nil, nil, err
@ -25,32 +23,31 @@ func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenCon
//nolint:stylecheck
return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length")
}
versions := make([]protocol.VersionNumber, len(b)/4)
versions := make([]protocol.Version, len(b)/4)
for i := 0; len(b) > 0; i++ {
versions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(b[:4]))
versions[i] = protocol.Version(binary.BigEndian.Uint32(b[:4]))
b = b[4:]
}
return dest, src, versions, nil
}
// ComposeVersionNegotiation composes a Version Negotiation
func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.VersionNumber) []byte {
func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.Version) []byte {
greasedVersions := protocol.GetGreasedVersions(versions)
expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4
buf := bytes.NewBuffer(make([]byte, 0, expectedLen))
r := make([]byte, 1)
_, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here.
buf := make([]byte, 1+4 /* type byte and version field */, expectedLen)
_, _ = rand.Read(buf[:1]) // ignore the error here. It is not critical to have perfect random here.
// Setting the "QUIC bit" (0x40) is not required by the RFC,
// but it allows clients to demultiplex QUIC with a long list of other protocols.
// See RFC 9443 and https://mailarchive.ietf.org/arch/msg/quic/oR4kxGKY6mjtPC1CZegY1ED4beg/ for details.
buf.WriteByte(r[0] | 0xc0)
utils.BigEndian.WriteUint32(buf, 0) // version 0
buf.WriteByte(uint8(destConnID.Len()))
buf.Write(destConnID.Bytes())
buf.WriteByte(uint8(srcConnID.Len()))
buf.Write(srcConnID.Bytes())
buf[0] |= 0xc0
// The next 4 bytes are left at 0 (version number).
buf = append(buf, uint8(destConnID.Len()))
buf = append(buf, destConnID.Bytes()...)
buf = append(buf, uint8(srcConnID.Len()))
buf = append(buf, srcConnID.Bytes()...)
for _, v := range greasedVersions {
utils.BigEndian.WriteUint32(buf, uint32(v))
buf = binary.BigEndian.AppendUint32(buf, uint32(v))
}
return buf.Bytes()
return buf
}

View File

@ -27,9 +27,9 @@ type ConnectionTracer struct {
UpdatedCongestionState func(CongestionState)
UpdatedPTOCount func(value uint32)
UpdatedKeyFromTLS func(EncryptionLevel, Perspective)
UpdatedKey func(generation KeyPhase, remote bool)
UpdatedKey func(keyPhase KeyPhase, remote bool)
DroppedEncryptionLevel func(EncryptionLevel)
DroppedKey func(generation KeyPhase)
DroppedKey func(keyPhase KeyPhase)
SetLossTimer func(TimerType, EncryptionLevel, time.Time)
LossTimerExpired func(TimerType, EncryptionLevel)
LossTimerCanceled func()

View File

@ -37,7 +37,7 @@ type (
// The StreamType is the type of the stream (unidirectional or bidirectional).
StreamType = protocol.StreamType
// The VersionNumber is the QUIC version.
VersionNumber = protocol.VersionNumber
VersionNumber = protocol.Version
// The Header is the QUIC packet header, before removing header protection.
Header = wire.Header

View File

@ -7,6 +7,8 @@ type Tracer struct {
SentPacket func(net.Addr, *Header, ByteCount, []Frame)
SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason)
Debug func(name, msg string)
Close func()
}
// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers.
@ -39,5 +41,19 @@ func NewMultiplexedTracer(tracers ...*Tracer) *Tracer {
}
}
},
Debug: func(name, msg string) {
for _, t := range tracers {
if t.Debug != nil {
t.Debug(name, msg)
}
}
},
Close: func() {
for _, t := range tracers {
if t.Close != nil {
t.Close()
}
}
},
}
}

View File

@ -3,12 +3,12 @@
# Install Go manually, since oss-fuzz ships with an outdated Go version.
# See https://github.com/google/oss-fuzz/pull/10643.
export CXX="${CXX} -lresolv" # required by Go 1.20
wget https://go.dev/dl/go1.21.5.linux-amd64.tar.gz \
wget https://go.dev/dl/go1.22.0.linux-amd64.tar.gz \
&& mkdir temp-go \
&& rm -rf /root/.go/* \
&& tar -C temp-go/ -xzf go1.21.5.linux-amd64.tar.gz \
&& tar -C temp-go/ -xzf go1.22.0.linux-amd64.tar.gz \
&& mv temp-go/go/* /root/.go/ \
&& rm -rf temp-go go1.21.5.linux-amd64.tar.gz
&& rm -rf temp-go go1.22.0.linux-amd64.tar.gz
(
# fuzz qpack

View File

@ -129,7 +129,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
return true
}
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool {
h.mutex.Lock()
defer h.mutex.Unlock()
@ -137,12 +137,8 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false
}
conn, ok := fn()
if !ok {
return false
}
h.handlers[clientDestConnID] = conn
h.handlers[newConnID] = conn
h.handlers[clientDestConnID] = handler
h.handlers[newConnID] = handler
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
return true
}
@ -168,18 +164,17 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
// Depending on which side closed the connection, we need to:
// * remote close: absorb delayed packets
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers protocol.Perspective, connClosePacket []byte) {
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte) {
var handler packetHandler
if connClosePacket != nil {
handler = newClosedLocalConn(
func(addr net.Addr, info packetInfo) {
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
},
pers,
h.logger,
)
} else {
handler = newClosedRemoteConn(pers)
handler = newClosedRemoteConn()
}
h.mutex.Lock()
@ -191,7 +186,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
time.AfterFunc(h.deleteRetiredConnsAfter, func() {
h.mutex.Lock()
handler.shutdown()
for _, id := range ids {
delete(h.handlers, id)
}

View File

@ -18,13 +18,13 @@ import (
var errNothingToPack = errors.New("nothing to pack")
type packer interface {
PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error)
PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error)
MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error)
PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error)
MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
SetToken([]byte)
}
@ -106,8 +106,8 @@ type sealingManager interface {
type frameSource interface {
HasData() bool
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)
}
type ackFrameSource interface {
@ -170,7 +170,7 @@ func newPacketPacker(
}
// PackConnectionClose packs a packet that closes the connection with a transport error.
func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
var reason string
// don't send details of crypto errors
if !e.ErrorCode.IsCryptoError() {
@ -180,7 +180,7 @@ func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize
}
// PackApplicationClose packs a packet that closes the connection with an application error.
func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v)
}
@ -190,7 +190,7 @@ func (p *packetPacker) packConnectionClose(
frameType uint64,
reason string,
maxPacketSize protocol.ByteCount,
v protocol.VersionNumber,
v protocol.Version,
) (*coalescedPacket, error) {
var sealers [4]sealer
var hdrs [3]*wire.ExtendedHeader
@ -293,7 +293,7 @@ func (p *packetPacker) packConnectionClose(
// longHeaderPacketLength calculates the length of a serialized long header packet.
// It takes into account that packets that have a tiny payload need to be padded,
// such that len(payload) + packet number len >= 4 + AEAD overhead
func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.VersionNumber) protocol.ByteCount {
func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.Version) protocol.ByteCount {
var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(hdr.PacketNumberLen)
if pl.length < 4-pnLen {
@ -328,7 +328,7 @@ func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize,
// PackCoalescedPacket packs a new packet.
// It packs an Initial / Handshake if there is data to send in these packet number spaces.
// It should only be called before the handshake is confirmed.
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
var (
initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader
initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
@ -442,7 +442,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
// PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) {
func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
buf := getPacketBuffer()
packet, err := p.appendPacket(buf, true, maxPacketSize, v)
return packet, buf, err
@ -450,11 +450,11 @@ func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v pro
// AppendPacket packs a packet in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) {
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) {
return p.appendPacket(buf, false, maxPacketSize, v)
}
func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) {
func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) {
sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return shortHeaderPacket{}, err
@ -471,7 +471,7 @@ func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSi
return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v)
}
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) {
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.Version) (*wire.ExtendedHeader, payload) {
if onlyAck {
if ack := p.acks.GetAckFrame(encLevel, true); ack != nil {
return p.getLongHeader(encLevel, v), payload{
@ -543,7 +543,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
return hdr, pl
}
func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) {
func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.Version) (*wire.ExtendedHeader, payload) {
if p.perspective != protocol.PerspectiveClient {
return nil, payload{}
}
@ -553,12 +553,12 @@ func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize
return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, v)
}
func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead())
return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, v)
}
func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v)
// check if we have anything to send
@ -581,7 +581,7 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount,
return pl
}
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
if onlyAck {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil {
return payload{ack: ack, length: ack.Length(v)}
@ -589,12 +589,11 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
return payload{}
}
pl := payload{streamFrames: make([]ackhandler.StreamFrame, 0, 1)}
hasData := p.framer.HasData()
hasRetransmission := p.retransmissionQueue.HasAppData()
var hasAck bool
var pl payload
if ackAllowed {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil {
pl.ack = ack
@ -661,7 +660,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
return pl
}
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
if encLevel == protocol.Encryption1RTT {
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
@ -727,7 +726,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m
return packet, nil
}
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) {
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
pl := payload{
frames: []ackhandler.Frame{ping},
length: ping.Frame.Length(v),
@ -745,7 +744,7 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
return packet, buffer, err
}
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader {
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.Version) *wire.ExtendedHeader {
pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
hdr := &wire.ExtendedHeader{
PacketNumber: pn,
@ -768,7 +767,7 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protoc
return hdr
}
func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) {
func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.Version) (*longHeaderPacket, error) {
var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(header.PacketNumberLen)
if pl.length < 4-pnLen {
@ -814,7 +813,7 @@ func (p *packetPacker) appendShortHeaderPacket(
padding, maxPacketSize protocol.ByteCount,
sealer sealer,
isMTUProbePacket bool,
v protocol.VersionNumber,
v protocol.Version,
) (shortHeaderPacket, error) {
var paddingLen protocol.ByteCount
if pl.length < 4-protocol.ByteCount(pnLen) {
@ -860,7 +859,7 @@ func (p *packetPacker) appendShortHeaderPacket(
// appendPacketPayload serializes the payload of a packet into the raw byte slice.
// It modifies the order of payload.frames.
func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) {
func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.Version) ([]byte, error) {
payloadOffset := len(raw)
if pl.ack != nil {
var err error

View File

@ -53,7 +53,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetU
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
// If any other error occurred when parsing the header, the error is of type headerParseError.
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) {
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) {
var encLevel protocol.EncryptionLevel
var extHdr *wire.ExtendedHeader
var decrypted []byte
@ -125,7 +125,7 @@ func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (prot
return pn, pnLen, kp, decrypted, nil
}
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, []byte, error) {
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
// If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker
@ -187,7 +187,7 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int
}
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) {
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
extHdr, err := unpackLongHeader(hd, hdr, data, v)
if err != nil && err != wire.ErrInvalidReservedBits {
return nil, &headerParseError{err: err}
@ -195,7 +195,7 @@ func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header,
return extHdr, err
}
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) {
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
r := bytes.NewReader(data)
hdrLen := hdr.ParsedLen()

View File

@ -292,10 +292,6 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame)
return newlyRcvdFinalOffset, nil
}
func (s *receiveStream) CloseRemote(offset protocol.ByteCount) {
s.handleStreamFrame(&wire.StreamFrame{Fin: true, Offset: offset})
}
func (s *receiveStream) SetReadDeadline(t time.Time) error {
s.mutex.Lock()
s.deadline = t

View File

@ -74,7 +74,7 @@ func (q *retransmissionQueue) addAppData(f wire.Frame) {
q.appData = append(q.appData, f)
}
func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame {
func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
if len(q.initialCryptoData) > 0 {
f := q.initialCryptoData[0]
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
@ -97,7 +97,7 @@ func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v proto
return f
}
func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame {
func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
if len(q.handshakeCryptoData) > 0 {
f := q.handshakeCryptoData[0]
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
@ -120,7 +120,7 @@ func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v pro
return f
}
func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame {
func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
if len(q.appData) == 0 {
return nil
}

View File

@ -18,7 +18,7 @@ type sendStreamI interface {
SendStream
handleStopSendingFrame(*wire.StopSendingFrame)
hasData() bool
popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (frame ackhandler.StreamFrame, ok, hasMore bool)
popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (frame ackhandler.StreamFrame, ok, hasMore bool)
closeForShutdown(error)
updateSendWindow(protocol.ByteCount)
}
@ -198,7 +198,7 @@ func (s *sendStream) canBufferStreamFrame() bool {
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
// maxBytes is the maximum length this frame (including frame header) will have.
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (af ackhandler.StreamFrame, ok, hasMore bool) {
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) {
s.mutex.Lock()
f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
if f != nil {
@ -215,7 +215,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers
}, true, hasMoreData
}
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) {
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more data to send */) {
if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
return nil, false
}
@ -269,7 +269,7 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun
return f, hasMoreData
}
func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool) {
func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) {
if s.nextFrame != nil {
nextFrame := s.nextFrame
s.nextFrame = nil
@ -304,7 +304,7 @@ func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount,
return f, hasMoreData
}
func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) bool {
func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool {
maxDataLen := f.MaxDataLen(maxBytes, v)
if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
@ -314,7 +314,7 @@ func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxByte
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
}
func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more retransmissions */) {
func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) {
f := s.retransmissionQueue[0]
newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v)
if needsSplit {
@ -404,11 +404,13 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool
}
func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
updated := s.flowController.UpdateSendWindow(limit)
if !updated { // duplicate or reordered MAX_STREAM_DATA frame
return
}
s.mutex.Lock()
hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
s.mutex.Unlock()
s.flowController.UpdateSendWindow(limit)
if hasStreamData {
s.sender.onHasStreamData(s.streamID)
}

View File

@ -7,7 +7,6 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/handshake"
@ -24,15 +23,14 @@ var ErrServerClosed = errors.New("quic: server closed")
// packetHandler handles packets
type packetHandler interface {
handlePacket(receivedPacket)
shutdown()
destroy(error)
getPerspective() protocol.Perspective
closeWithTransportError(qerr.TransportErrorCode)
}
type packetHandlerManager interface {
Get(protocol.ConnectionID) (packetHandler, bool)
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool
Close(error)
connRunner
}
@ -41,11 +39,9 @@ type quicConn interface {
EarlyConnection
earlyConnReady() <-chan struct{}
handlePacket(receivedPacket)
GetVersion() protocol.VersionNumber
getPerspective() protocol.Perspective
run() error
destroy(error)
shutdown()
closeWithTransportError(TransportErrorCode)
}
type zeroRTTQueue struct {
@ -98,10 +94,10 @@ type baseServer struct {
*logging.ConnectionTracer,
uint64,
utils.Logger,
protocol.VersionNumber,
protocol.Version,
) quicConn
closeOnce sync.Once
closeMx sync.Mutex
errorChan chan struct{} // is closed when the server is closed
closeErr error
running chan struct{} // closed as soon as run() returns
@ -111,8 +107,9 @@ type baseServer struct {
connectionRefusedQueue chan rejectedPacket
retryQueue chan rejectedPacket
verifySourceAddress func(net.Addr) bool
connQueue chan quicConn
connQueueLen int32 // to be used as an atomic
tracer *logging.Tracer
@ -240,6 +237,7 @@ func newServer(
onClose func(),
tokenGeneratorKey TokenGeneratorKey,
maxTokenAge time.Duration,
verifySourceAddress func(net.Addr) bool,
disableVersionNegotiation bool,
acceptEarly bool,
) *baseServer {
@ -249,9 +247,10 @@ func newServer(
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
maxTokenAge: maxTokenAge,
verifySourceAddress: verifySourceAddress,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn),
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
@ -322,7 +321,6 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
case <-ctx.Done():
return nil, ctx.Err()
case conn := <-s.connQueue:
atomic.AddInt32(&s.connQueueLen, -1)
return conn, nil
case <-s.errorChan:
return nil, s.closeErr
@ -335,15 +333,19 @@ func (s *baseServer) Close() error {
}
func (s *baseServer) close(e error, notifyOnClose bool) {
s.closeOnce.Do(func() {
s.closeMx.Lock()
if s.closeErr != nil {
s.closeMx.Unlock()
return
}
s.closeErr = e
close(s.errorChan)
<-s.running
s.closeMx.Unlock()
if notifyOnClose {
s.onClose()
}
})
}
// Addr returns the server's network address
@ -542,10 +544,10 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
p.buffer.Release()
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
}
p.buffer.Release()
return errors.New("too short connection ID")
}
@ -560,6 +562,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
var (
token *handshake.Token
retrySrcConnID *protocol.ConnectionID
clientAddrVerified bool
)
origDestConnID := hdr.DestConnectionID
if len(hdr.Token) > 0 {
@ -572,9 +575,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
token = tok
}
}
clientAddrIsValid := s.validateToken(token, p.remoteAddr)
if token != nil && !clientAddrIsValid {
if token != nil {
clientAddrVerified = s.validateToken(token, p.remoteAddr)
if !clientAddrVerified {
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
// We just ignore them, and act as if there was no token on this packet at all.
// This also means we might send a Retry later.
@ -593,7 +596,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
return nil
}
}
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
}
if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress(p.remoteAddr) {
// Retry invalidates all 0-RTT packets sent.
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
@ -605,8 +610,15 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
return nil
}
if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize {
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
config := s.config
if s.config.GetConfigForClient != nil {
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{
RemoteAddr: p.remoteAddr,
AddrVerified: clientAddrVerified,
})
if err != nil {
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
@ -615,24 +627,11 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
return nil
}
connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
s.logger.Debugf("Changing connection ID to %s.", connID)
var conn quicConn
tracingID := nextConnTracingID()
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) {
config := s.config
if s.config.GetConfigForClient != nil {
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr})
if err != nil {
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
return nil, false
}
config = populateConfig(conf)
}
var conn quicConn
tracingID := nextConnTracingID()
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
@ -642,6 +641,11 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
}
connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
s.logger.Debugf("Changing connection ID to %s.", connID)
conn = s.newConn(
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
@ -655,14 +659,23 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
config,
s.tlsConf,
s.tokenGenerator,
clientAddrIsValid,
clientAddrVerified,
tracer,
tracingID,
s.logger,
hdr.Version,
)
conn.handlePacket(p)
// Adding the connection will fail if the client's chosen Destination Connection ID is already in use.
// This is very unlikely: Even if an attacker chooses a connection ID that's already in use,
// under normal circumstances the packet would just be routed to that connection.
// The only time this collision will occur if we receive the two Initial packets at the same time.
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
delete(s.zeroRTTQueues, hdr.DestConnectionID)
conn.closeWithTransportError(qerr.ConnectionRefused)
return nil
}
// Pass queued 0-RTT to the newly established connection.
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
for _, p := range q.packets {
conn.handlePacket(p)
@ -670,56 +683,43 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
return conn, true
}); !added {
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
return nil
}
go conn.run()
go s.handleNewConn(conn)
if conn == nil {
p.buffer.Release()
return nil
go func() {
if completed := s.handleNewConn(conn); !completed {
return
}
select {
case s.connQueue <- conn:
default:
conn.closeWithTransportError(ConnectionRefused)
}
}()
return nil
}
func (s *baseServer) handleNewConn(conn quicConn) {
connCtx := conn.Context()
func (s *baseServer) handleNewConn(conn quicConn) bool {
if s.acceptEarlyConns {
// wait until the early connection is ready, the handshake fails, or the server is closed
select {
case <-s.errorChan:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
return
conn.closeWithTransportError(ConnectionRefused)
return false
case <-conn.Context().Done():
return false
case <-conn.earlyConnReady():
case <-connCtx.Done():
return
return true
}
} else {
// wait until the handshake is complete (or fails)
}
// wait until the handshake completes, fails, or the server is closed
select {
case <-s.errorChan:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
return
conn.closeWithTransportError(ConnectionRefused)
return false
case <-conn.Context().Done():
return false
case <-conn.HandshakeComplete():
case <-connCtx.Done():
return
}
}
atomic.AddInt32(&s.connQueueLen, 1)
select {
case s.connQueue <- conn:
// blocks until the connection is accepted
case <-connCtx.Done():
atomic.AddInt32(&s.connQueueLen, -1)
// don't pass connections that were already closed to Accept()
return true
}
}

View File

@ -60,7 +60,7 @@ type streamI interface {
// for sending
hasData() bool
handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool)
popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool)
updateSendWindow(protocol.ByteCount)
}

View File

@ -221,7 +221,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
// unsigned int ipi6_ifindex; /* send/recv interface index */
// };
if len(body) == 20 {
p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16]))
p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16])).Unmap()
p.info.ifIndex = binary.LittleEndian.Uint32(body[16:])
} else {
invalidCmsgOnceV6.Do(func() {

View File

@ -41,7 +41,8 @@ type Transport struct {
Conn net.PacketConn
// The length of the connection ID in bytes.
// It can be 0, or any value between 4 and 18.
// It can be any value between 1 and 20.
// Due to the increased risk of collisions, it is not recommended to use connection IDs shorter than 4 bytes.
// If unset, a 4 byte connection ID will be used.
ConnectionIDLength int
@ -77,7 +78,19 @@ type Transport struct {
// It has no effect for clients.
DisableVersionNegotiationPackets bool
// VerifySourceAddress decides if a connection attempt originating from unvalidated source
// addresses first needs to go through source address validation using QUIC's Retry mechanism,
// as described in RFC 9000 section 8.1.2.
// Note that the address passed to this callback is unvalidated, and might be spoofed in case
// of an attack.
// Validating the source address adds one additional network roundtrip to the handshake,
// and should therefore only be used if a suspiciously high number of incoming connection is recorded.
// For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable
// implementation of this callback (negating its return value).
VerifySourceAddress func(net.Addr) bool
// A Tracer traces events that don't belong to a single QUIC connection.
// Tracer.Close is called when the transport is closed.
Tracer *logging.Tracer
handlerMap packetHandlerManager
@ -147,7 +160,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
if t.server != nil {
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
conf = populateConfig(conf)
if err := t.init(false); err != nil {
return nil, err
}
@ -161,6 +174,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
t.closeServer,
*t.TokenGeneratorKey,
t.MaxTokenAge,
t.VerifySourceAddress,
t.DisableVersionNegotiationPackets,
allow0RTT,
)
@ -323,6 +337,9 @@ func (t *Transport) close(e error) {
if t.server != nil {
t.server.close(e, false)
}
if t.Tracer != nil && t.Tracer.Close != nil {
t.Tracer.Close()
}
t.closed = true
}
@ -379,13 +396,21 @@ func (t *Transport) handlePacket(p receivedPacket) {
return
}
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
return
}
// If there's a connection associated with the connection ID, pass the packet there.
if handler, ok := t.handlerMap.Get(connID); ok {
handler.handlePacket(p)
return
}
// RFC 9000 section 10.3.1 requires that the stateless reset detection logic is run for both
// packets that cannot be associated with any connections, and for packets that can't be decrypted.
// We deviate from the RFC and ignore the latter: If a packet's connection ID is associated with an
// existing connection, it is dropped there if if it can't be decrypted.
// Stateless resets use random connection IDs, and at reasonable connection ID lengths collisions are
// exceedingly rare. In the unlikely event that a stateless reset is misrouted to an existing connection,
// it is to be expected that the next stateless reset will be correctly detected.
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
return
}
if !wire.IsLongHeaderPacket(p.data[0]) {
t.maybeSendStatelessReset(p)
return

37
vendor/go.uber.org/mock/CONTRIBUTORS generated vendored
View File

@ -1,37 +0,0 @@
# This is the official list of people who can contribute (and typically
# have contributed) code to the gomock repository.
# The AUTHORS file lists the copyright holders; this file
# lists people. For example, Google employees are listed here
# but not in AUTHORS, because Google holds the copyright.
#
# The submission process automatically checks to make sure
# that people submitting code are listed in this file (by email address).
#
# Names should be added to this file only after verifying that
# the individual or the individual's organization has agreed to
# the appropriate Contributor License Agreement, found here:
#
# http://code.google.com/legal/individual-cla-v1.0.html
# http://code.google.com/legal/corporate-cla-v1.0.html
#
# The agreement for individuals can be filled out on the web.
#
# When adding J Random Contributor's name to this file,
# either J's name or J's organization's name should be
# added to the AUTHORS file, depending on whether the
# individual or corporate CLA was used.
# Names should be added to this file like so:
# Name <email address>
#
# An entry with two email addresses specifies that the
# first address should be used in the submit logs and
# that the second address should be recognized as the
# same person when interacting with Rietveld.
# Please keep the list sorted.
Aaron Jacobs <jacobsa@google.com> <aaronjjacobs@gmail.com>
Alex Reece <awreece@gmail.com>
David Symonds <dsymonds@golang.org>
Ryan Barrett <ryanb@google.com>

View File

@ -11,8 +11,10 @@
package main
import (
"errors"
"fmt"
"go/ast"
"go/token"
"strings"
"go.uber.org/mock/mockgen/model"
@ -98,6 +100,16 @@ func (p *fileParser) parseGenericMethod(field *ast.Field, it *namedInterface, if
case *ast.IndexListExpr:
indices = v.Indices
typ = v.X
case *ast.UnaryExpr:
if v.Op == token.TILDE {
return nil, errConstraintInterface
}
return nil, fmt.Errorf("~T may only appear as constraint for %T", field.Type)
case *ast.BinaryExpr:
if v.Op == token.OR {
return nil, errConstraintInterface
}
return nil, fmt.Errorf("A|B may only appear as constraint for %T", field.Type)
default:
return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
}
@ -114,3 +126,5 @@ func (p *fileParser) parseGenericMethod(field *ast.Field, it *namedInterface, if
return p.parseMethod(nf, it, iface, pkg, tps)
}
var errConstraintInterface = errors.New("interface contains constraints")

View File

@ -31,6 +31,7 @@ import (
"os/exec"
"path"
"path/filepath"
"runtime"
"sort"
"strconv"
"strings"
@ -314,8 +315,14 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
}
g.p("//")
g.p("// Generated by this command:")
g.p("//")
// only log the name of the executable, not the full path
g.p("// %v", strings.Join(append([]string{filepath.Base(os.Args[0])}, os.Args[1:]...), " "))
name := filepath.Base(os.Args[0])
if runtime.GOOS == "windows" {
name = strings.TrimSuffix(name, ".exe")
}
g.p("//\t%v", strings.Join(append([]string{name}, os.Args[1:]...), " "))
g.p("//")
// Get all required imports, and generate unique names for them all.
im := pkg.Imports()
@ -371,7 +378,7 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
}
i := 0
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() {
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() || pkgName == "any" {
pkgName = base + strconv.Itoa(i)
i++
}
@ -386,6 +393,10 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
}
if *writePkgComment {
// Ensure there's an empty line before the package to follow the recommendations:
// https://github.com/golang/go/wiki/CodeReviewComments#package-comments
g.p("")
g.p("// Package %v is a generated GoMock package.", outputPkgName)
}
g.p("package %v", outputPkgName)
@ -508,7 +519,7 @@ func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface,
g.p("")
_ = g.GenerateMockMethod(mockType, m, pkgOverride, shortTp)
g.p("")
_ = g.GenerateMockRecorderMethod(intf, mockType, m, shortTp, typed)
_ = g.GenerateMockRecorderMethod(intf, m, shortTp, typed)
if typed {
g.p("")
_ = g.GenerateMockReturnCallMethod(intf, m, pkgOverride, longTp, shortTp)
@ -596,7 +607,8 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
return nil
}
func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType string, m *model.Method, shortTp string, typed bool) error {
func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, m *model.Method, shortTp string, typed bool) error {
mockType := g.mockName(intf.Name)
argNames := g.getArgNames(m, true)
var argString string
@ -621,7 +633,7 @@ func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType s
g.p("// %v indicates an expected call of %v.", m.Name, m.Name)
if typed {
g.p("func (%s *%vMockRecorder%v) %v(%v) *%s%sCall%s {", idRecv, mockType, shortTp, m.Name, argString, intf.Name, m.Name, shortTp)
g.p("func (%s *%vMockRecorder%v) %v(%v) *%s%sCall%s {", idRecv, mockType, shortTp, m.Name, argString, mockType, m.Name, shortTp)
} else {
g.p("func (%s *%vMockRecorder%v) %v(%v) *gomock.Call {", idRecv, mockType, shortTp, m.Name, argString)
}
@ -650,7 +662,7 @@ func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType s
}
if typed {
g.p(`call := %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
g.p(`return &%s%sCall%s{Call: call}`, intf.Name, m.Name, shortTp)
g.p(`return &%s%sCall%s{Call: call}`, mockType, m.Name, shortTp)
} else {
g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
}
@ -661,6 +673,7 @@ func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType s
}
func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model.Method, pkgOverride, longTp, shortTp string) error {
mockType := g.mockName(intf.Name)
argNames := g.getArgNames(m, true /* in */)
retNames := g.getArgNames(m, false /* out */)
argTypes := g.getArgTypes(m, pkgOverride, true /* in */)
@ -683,10 +696,10 @@ func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model
ia := newIdentifierAllocator(argNames)
idRecv := ia.allocateIdentifier("c")
recvStructName := intf.Name + m.Name
recvStructName := mockType + m.Name
g.p("// %s%sCall wrap *gomock.Call", intf.Name, m.Name)
g.p("type %s%sCall%s struct{", intf.Name, m.Name, longTp)
g.p("// %s%sCall wrap *gomock.Call", mockType, m.Name)
g.p("type %s%sCall%s struct{", mockType, m.Name, longTp)
g.in()
g.p("*gomock.Call")
g.out()

View File

@ -305,7 +305,7 @@ type PredeclaredType string
func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) }
func (pt PredeclaredType) addImports(map[string]bool) {}
// TypeParametersType contains type paramters for a NamedType.
// TypeParametersType contains type parameters for a NamedType.
type TypeParametersType struct {
TypeParameters []Type
}

View File

@ -232,6 +232,9 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag
continue
}
i, err := p.parseInterface(ni.name.String(), importPath, ni)
if errors.Is(err, errConstraintInterface) {
continue
}
if err != nil {
return nil, err
}

View File

@ -187,6 +187,7 @@ type reflectData struct {
// gob encoding of a model.Package to standard output.
// JSON doesn't work because of the model.Type interface.
var reflectProgram = template.Must(template.New("program").Parse(`
// Code generated by MockGen. DO NOT EDIT.
package main
import (

4
vendor/modules.txt vendored
View File

@ -230,7 +230,7 @@ github.com/prometheus/common/model
github.com/prometheus/procfs
github.com/prometheus/procfs/internal/fs
github.com/prometheus/procfs/internal/util
# github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6
# github.com/quic-go/quic-go v0.42.0
## explicit; go 1.21
github.com/quic-go/quic-go
github.com/quic-go/quic-go/internal/ackhandler
@ -313,7 +313,7 @@ go.opentelemetry.io/proto/otlp/trace/v1
go.uber.org/automaxprocs/internal/cgroups
go.uber.org/automaxprocs/internal/runtime
go.uber.org/automaxprocs/maxprocs
# go.uber.org/mock v0.3.0
# go.uber.org/mock v0.4.0
## explicit; go 1.20
go.uber.org/mock/mockgen
go.uber.org/mock/mockgen/model