TUN-8371: Bump quic-go to v0.42.0

## Summary
We discovered that we were being impacted by a bug in quic-go,
that could create deadlocks and not close connections.

This commit bumps quic-go to the version that contains the fix
to prevent that from happening.
This commit is contained in:
João "Pisco" Fernandes 2024-04-18 18:26:01 +01:00 committed by chungthuang
parent 5e5f2f4d8c
commit 84833011ec
79 changed files with 730 additions and 763 deletions

4
go.mod
View File

@ -25,7 +25,7 @@ require (
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.13.0 github.com/prometheus/client_golang v1.13.0
github.com/prometheus/client_model v0.2.0 github.com/prometheus/client_model v0.2.0
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6 github.com/quic-go/quic-go v0.42.0
github.com/rs/zerolog v1.20.0 github.com/rs/zerolog v1.20.0
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
github.com/urfave/cli/v2 v2.3.0 github.com/urfave/cli/v2 v2.3.0
@ -84,7 +84,7 @@ require (
github.com/prometheus/procfs v0.8.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect
go.opentelemetry.io/otel/metric v1.21.0 // indirect go.opentelemetry.io/otel/metric v1.21.0 // indirect
go.uber.org/mock v0.3.0 // indirect go.uber.org/mock v0.4.0 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.11.0 // indirect golang.org/x/mod v0.11.0 // indirect
golang.org/x/oauth2 v0.13.0 // indirect golang.org/x/oauth2 v0.13.0 // indirect

10
go.sum
View File

@ -322,8 +322,8 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo= github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4= github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6 h1:OI4WiysowCcxLtcZMGBZildo12di3ljcMN4vWdUQpoU= github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA= github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
@ -381,8 +381,8 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0= go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q= go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -552,6 +552,8 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=

View File

@ -2,16 +2,6 @@ run:
skip-files: skip-files:
- internal/handshake/cipher_suite.go - internal/handshake/cipher_suite.go
linters-settings: linters-settings:
depguard:
rules:
qtls:
list-mode: lax
files:
- "!internal/qtls/**"
- "$all"
deny:
- pkg: github.com/quic-go/qtls-go1-20
desc: "importing qtls only allowed in internal/qtls"
misspell: misspell:
ignore-words: ignore-words:
- ect - ect
@ -20,7 +10,6 @@ linters:
disable-all: true disable-all: true
enable: enable:
- asciicheck - asciicheck
- depguard
- exhaustive - exhaustive
- exportloopref - exportloopref
- goimports - goimports

View File

@ -183,26 +183,20 @@ quic-go logs a wide range of events defined in [draft-ietf-quic-qlog-quic-events
qlog files can be processed by a number of 3rd-party tools. [qviz](https://qvis.quictools.info/) has proven very useful for debugging all kinds of QUIC connection failures. qlog files can be processed by a number of 3rd-party tools. [qviz](https://qvis.quictools.info/) has proven very useful for debugging all kinds of QUIC connection failures.
qlog is activated by setting a `Tracer` callback on the `Config`. It is called as soon as quic-go decides to starts the QUIC handshake on a new connection. qlog can be activated by setting the `Tracer` callback on the `Config`. It is called as soon as quic-go decides to start the QUIC handshake on a new connection.
A useful implementation of this callback could look like this: `qlog.DefaultTracer` provides a tracer implementation which writes qlog files to a directory specified by the `QLOGDIR` environment variable, if set.
The default qlog tracer can be used like this:
```go ```go
quic.Config{ quic.Config{
Tracer: func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { Tracer: qlog.DefaultTracer,
role := "server"
if p == logging.PerspectiveClient {
role = "client"
}
filename := fmt.Sprintf("./log_%s_%s.qlog", connID, role)
f, err := os.Create(filename)
// handle the error
return qlog.NewConnectionTracer(f, p, connID)
}
} }
``` ```
This implementation of the callback creates a new qlog file in the current directory named `log_<client / server>_<QUIC connection ID>.qlog`. This example creates a new qlog file under `<QLOGDIR>/<Original Destination Connection ID>_<Vantage Point>.qlog`, e.g. `qlogs/2e0407da_client.qlog`.
For custom qlog behavior, `qlog.NewConnectionTracer` can be used.
## Using HTTP/3 ## Using HTTP/3
### As a server ### As a server
@ -232,11 +226,13 @@ http.Client{
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) | | [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) |
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) | | [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) |
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | | [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) |
| [frp](https://github.com/fatedier/frp) | A fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet | ![GitHub Repo stars](https://img.shields.io/github/stars/fatedier/frp?style=flat-square) |
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) | | [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) |
| [gost](https://github.com/go-gost/gost) | A simple security tunnel written in Go | ![GitHub Repo stars](https://img.shields.io/github/stars/go-gost/gost?style=flat-square) | | [gost](https://github.com/go-gost/gost) | A simple security tunnel written in Go | ![GitHub Repo stars](https://img.shields.io/github/stars/go-gost/gost?style=flat-square) |
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) | | [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) |
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) | | [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | | [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
| [RoadRunner](https://github.com/roadrunner-server/roadrunner) | High-performance PHP application server, process manager written in Go and powered with plugins | ![GitHub Repo stars](https://img.shields.io/github/stars/roadrunner-server/roadrunner?style=flat-square) |
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | | [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) |
| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) | | [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) |
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) | | [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) |

View File

@ -28,7 +28,7 @@ type client struct {
initialPacketNumber protocol.PacketNumber initialPacketNumber protocol.PacketNumber
hasNegotiatedVersion bool hasNegotiatedVersion bool
version protocol.VersionNumber version protocol.Version
handshakeChan chan struct{} handshakeChan chan struct{}
@ -232,7 +232,7 @@ func (c *client) dial(ctx context.Context) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
c.conn.shutdown() c.conn.destroy(nil)
return context.Cause(ctx) return context.Cause(ctx)
case err := <-errorChan: case err := <-errorChan:
return err return err

View File

@ -4,7 +4,6 @@ import (
"math/bits" "math/bits"
"net" "net"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
) )
@ -12,9 +11,8 @@ import (
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame, // When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
// with an exponential backoff. // with an exponential backoff.
type closedLocalConn struct { type closedLocalConn struct {
counter uint32 counter uint32
perspective protocol.Perspective logger utils.Logger
logger utils.Logger
sendPacket func(net.Addr, packetInfo) sendPacket func(net.Addr, packetInfo)
} }
@ -22,11 +20,10 @@ type closedLocalConn struct {
var _ packetHandler = &closedLocalConn{} var _ packetHandler = &closedLocalConn{}
// newClosedLocalConn creates a new closedLocalConn and runs it. // newClosedLocalConn creates a new closedLocalConn and runs it.
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logger) packetHandler {
return &closedLocalConn{ return &closedLocalConn{
sendPacket: sendPacket, sendPacket: sendPacket,
perspective: pers, logger: logger,
logger: logger,
} }
} }
@ -41,24 +38,20 @@ func (c *closedLocalConn) handlePacket(p receivedPacket) {
c.sendPacket(p.remoteAddr, p.info) c.sendPacket(p.remoteAddr, p.info)
} }
func (c *closedLocalConn) shutdown() {} func (c *closedLocalConn) destroy(error) {}
func (c *closedLocalConn) destroy(error) {} func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {}
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
// A closedRemoteConn is a connection that was closed remotely. // A closedRemoteConn is a connection that was closed remotely.
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE. // For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
// We can just ignore those packets. // We can just ignore those packets.
type closedRemoteConn struct { type closedRemoteConn struct{}
perspective protocol.Perspective
}
var _ packetHandler = &closedRemoteConn{} var _ packetHandler = &closedRemoteConn{}
func newClosedRemoteConn(pers protocol.Perspective) packetHandler { func newClosedRemoteConn() packetHandler {
return &closedRemoteConn{perspective: pers} return &closedRemoteConn{}
} }
func (s *closedRemoteConn) handlePacket(receivedPacket) {} func (c *closedRemoteConn) handlePacket(receivedPacket) {}
func (s *closedRemoteConn) shutdown() {} func (c *closedRemoteConn) destroy(error) {}
func (s *closedRemoteConn) destroy(error) {} func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {}
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }

View File

@ -5,6 +5,8 @@ coverage:
- interop/ - interop/
- internal/handshake/cipher_suite.go - internal/handshake/cipher_suite.go
- internal/utils/linkedlist/linkedlist.go - internal/utils/linkedlist/linkedlist.go
- internal/testdata
- testutils/
- fuzzing/ - fuzzing/
- metrics/ - metrics/
status: status:

View File

@ -2,7 +2,6 @@ package quic
import ( import (
"fmt" "fmt"
"net"
"time" "time"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
@ -49,16 +48,6 @@ func validateConfig(config *Config) error {
return nil return nil
} }
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config {
config = populateConfig(config)
if config.RequireAddressValidation == nil {
config.RequireAddressValidation = func(net.Addr) bool { return false }
}
return config
}
// populateConfig populates fields in the quic.Config with their default values, if none are set // populateConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil // it may be called with nil
func populateConfig(config *Config) *Config { func populateConfig(config *Config) *Config {
@ -111,7 +100,6 @@ func populateConfig(config *Config) *Config {
Versions: versions, Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout, HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout, MaxIdleTimeout: idleTimeout,
RequireAddressValidation: config.RequireAddressValidation,
KeepAlivePeriod: config.KeepAlivePeriod, KeepAlivePeriod: config.KeepAlivePeriod,
InitialStreamReceiveWindow: initialStreamReceiveWindow, InitialStreamReceiveWindow: initialStreamReceiveWindow,
MaxStreamReceiveWindow: maxStreamReceiveWindow, MaxStreamReceiveWindow: maxStreamReceiveWindow,

View File

@ -19,7 +19,7 @@ type connIDGenerator struct {
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
removeConnectionID func(protocol.ConnectionID) removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID) retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte) replaceWithClosed func([]protocol.ConnectionID, []byte)
queueControlFrame func(wire.Frame) queueControlFrame func(wire.Frame)
} }
@ -30,7 +30,7 @@ func newConnIDGenerator(
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
removeConnectionID func(protocol.ConnectionID), removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte), replaceWithClosed func([]protocol.ConnectionID, []byte),
queueControlFrame func(wire.Frame), queueControlFrame func(wire.Frame),
generator ConnectionIDGenerator, generator ConnectionIDGenerator,
) *connIDGenerator { ) *connIDGenerator {
@ -126,7 +126,7 @@ func (m *connIDGenerator) RemoveAll() {
} }
} }
func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) { func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
if m.initialClientDestConnID != nil { if m.initialClientDestConnID != nil {
connIDs = append(connIDs, *m.initialClientDestConnID) connIDs = append(connIDs, *m.initialClientDestConnID)
@ -134,5 +134,5 @@ func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose
for _, connID := range m.activeSrcConnIDs { for _, connID := range m.activeSrcConnIDs {
connIDs = append(connIDs, connID) connIDs = append(connIDs, connID)
} }
m.replaceWithClosed(connIDs, pers, connClose) m.replaceWithClosed(connIDs, connClose)
} }

View File

@ -25,7 +25,7 @@ import (
) )
type unpacker interface { type unpacker interface {
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error)
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
} }
@ -93,7 +93,7 @@ type connRunner interface {
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
Retire(protocol.ConnectionID) Retire(protocol.ConnectionID)
Remove(protocol.ConnectionID) Remove(protocol.ConnectionID)
ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte) ReplaceWithClosed([]protocol.ConnectionID, []byte)
AddResetToken(protocol.StatelessResetToken, packetHandler) AddResetToken(protocol.StatelessResetToken, packetHandler)
RemoveResetToken(protocol.StatelessResetToken) RemoveResetToken(protocol.StatelessResetToken)
} }
@ -106,7 +106,7 @@ type closeError struct {
type errCloseForRecreating struct { type errCloseForRecreating struct {
nextPacketNumber protocol.PacketNumber nextPacketNumber protocol.PacketNumber
nextVersion protocol.VersionNumber nextVersion protocol.Version
} }
func (e *errCloseForRecreating) Error() string { func (e *errCloseForRecreating) Error() string {
@ -128,7 +128,7 @@ type connection struct {
srcConnIDLen int srcConnIDLen int
perspective protocol.Perspective perspective protocol.Perspective
version protocol.VersionNumber version protocol.Version
config *Config config *Config
conn sendConn conn sendConn
@ -177,6 +177,7 @@ type connection struct {
earlyConnReadyChan chan struct{} earlyConnReadyChan chan struct{}
sentFirstPacket bool sentFirstPacket bool
droppedInitialKeys bool
handshakeComplete bool handshakeComplete bool
handshakeConfirmed bool handshakeConfirmed bool
@ -235,7 +236,7 @@ var newConnection = func(
tracer *logging.ConnectionTracer, tracer *logging.ConnectionTracer,
tracingID uint64, tracingID uint64,
logger utils.Logger, logger utils.Logger,
v protocol.VersionNumber, v protocol.Version,
) quicConn { ) quicConn {
s := &connection{ s := &connection{
conn: conn, conn: conn,
@ -348,7 +349,7 @@ var newClientConnection = func(
tracer *logging.ConnectionTracer, tracer *logging.ConnectionTracer,
tracingID uint64, tracingID uint64,
logger utils.Logger, logger utils.Logger,
v protocol.VersionNumber, v protocol.Version,
) quicConn { ) quicConn {
s := &connection{ s := &connection{
conn: conn, conn: conn,
@ -453,7 +454,7 @@ func (s *connection) preSetup() {
s.handshakeStream = newCryptoStream() s.handshakeStream = newCryptoStream()
s.sendQueue = newSendQueue(s.conn) s.sendQueue = newSendQueue(s.conn)
s.retransmissionQueue = newRetransmissionQueue() s.retransmissionQueue = newRetransmissionQueue()
s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) s.frameParser = *wire.NewFrameParser(s.config.EnableDatagrams)
s.rttStats = &utils.RTTStats{} s.rttStats = &utils.RTTStats{}
s.connFlowController = flowcontrol.NewConnectionFlowController( s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ByteCount(s.config.InitialConnectionReceiveWindow), protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
@ -520,6 +521,9 @@ func (s *connection) run() error {
runLoop: runLoop:
for { for {
if s.framer.QueuedTooManyControlFrames() {
s.closeLocal(&qerr.TransportError{ErrorCode: InternalError})
}
// Close immediately if requested // Close immediately if requested
select { select {
case closeErr = <-s.closeChan: case closeErr = <-s.closeChan:
@ -1148,7 +1152,7 @@ func (s *connection) handleUnpackedLongHeaderPacket(
if !s.receivedFirstPacket { if !s.receivedFirstPacket {
s.receivedFirstPacket = true s.receivedFirstPacket = true
if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil { if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil {
var clientVersions, serverVersions []protocol.VersionNumber var clientVersions, serverVersions []protocol.Version
switch s.perspective { switch s.perspective {
case protocol.PerspectiveClient: case protocol.PerspectiveClient:
clientVersions = s.config.Versions clientVersions = s.config.Versions
@ -1185,7 +1189,8 @@ func (s *connection) handleUnpackedLongHeaderPacket(
} }
} }
if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake { if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake &&
!s.droppedInitialKeys {
// On the server side, Initial keys are dropped as soon as the first Handshake packet is received. // On the server side, Initial keys are dropped as soon as the first Handshake packet is received.
// See Section 4.9.1 of RFC 9001. // See Section 4.9.1 of RFC 9001.
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
@ -1572,13 +1577,6 @@ func (s *connection) closeRemote(e error) {
}) })
} }
// Close the connection. It sends a NO_ERROR application error.
// It waits until the run loop has stopped before returning
func (s *connection) shutdown() {
s.closeLocal(nil)
<-s.ctx.Done()
}
func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error { func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error {
s.closeLocal(&qerr.ApplicationError{ s.closeLocal(&qerr.ApplicationError{
ErrorCode: code, ErrorCode: code,
@ -1588,6 +1586,11 @@ func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) erro
return nil return nil
} }
func (s *connection) closeWithTransportError(code TransportErrorCode) {
s.closeLocal(&qerr.TransportError{ErrorCode: code})
<-s.ctx.Done()
}
func (s *connection) handleCloseError(closeErr *closeError) { func (s *connection) handleCloseError(closeErr *closeError) {
e := closeErr.err e := closeErr.err
if e == nil { if e == nil {
@ -1632,7 +1635,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
// If this is a remote close we're done here // If this is a remote close we're done here
if closeErr.remote { if closeErr.remote {
s.connIDGenerator.ReplaceWithClosed(s.perspective, nil) s.connIDGenerator.ReplaceWithClosed(nil)
return return
} }
if closeErr.immediate { if closeErr.immediate {
@ -1649,7 +1652,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
if err != nil { if err != nil {
s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err)
} }
s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket) s.connIDGenerator.ReplaceWithClosed(connClosePacket)
} }
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error { func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error {
@ -1661,6 +1664,7 @@ func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) erro
//nolint:exhaustive // only Initial and 0-RTT need special treatment //nolint:exhaustive // only Initial and 0-RTT need special treatment
switch encLevel { switch encLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
s.droppedInitialKeys = true
s.cryptoStreamHandler.DiscardInitialKeys() s.cryptoStreamHandler.DiscardInitialKeys()
case protocol.Encryption0RTT: case protocol.Encryption0RTT:
s.streamsMap.ResetFor0RTT() s.streamsMap.ResetFor0RTT()
@ -2077,7 +2081,8 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn prot
largestAcked = p.ack.LargestAcked() largestAcked = p.ack.LargestAcked()
} }
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false) s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false)
if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake { if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake &&
!s.droppedInitialKeys {
// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
// See Section 4.9.1 of RFC 9001. // See Section 4.9.1 of RFC 9001.
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
@ -2377,11 +2382,7 @@ func (s *connection) RemoteAddr() net.Addr {
return s.conn.RemoteAddr() return s.conn.RemoteAddr()
} }
func (s *connection) getPerspective() protocol.Perspective { func (s *connection) GetVersion() protocol.Version {
return s.perspective
}
func (s *connection) GetVersion() protocol.VersionNumber {
return s.version return s.version
} }

View File

@ -15,15 +15,25 @@ type framer interface {
HasData() bool HasData() bool
QueueControlFrame(wire.Frame) QueueControlFrame(wire.Frame)
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)
AddActiveStream(protocol.StreamID) AddActiveStream(protocol.StreamID)
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)
Handle0RTTRejection() error Handle0RTTRejection() error
// QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length.
// This is a hack.
// It is easier to implement than propagating an error return value in QueueControlFrame.
// The correct solution would be to queue frames with their respective structs.
// See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames.
QueuedTooManyControlFrames() bool
} }
const maxPathResponses = 256 const (
maxPathResponses = 256
maxControlFrames = 16 << 10
)
type framerI struct { type framerI struct {
mutex sync.Mutex mutex sync.Mutex
@ -33,9 +43,10 @@ type framerI struct {
activeStreams map[protocol.StreamID]struct{} activeStreams map[protocol.StreamID]struct{}
streamQueue ringbuffer.RingBuffer[protocol.StreamID] streamQueue ringbuffer.RingBuffer[protocol.StreamID]
controlFrameMutex sync.Mutex controlFrameMutex sync.Mutex
controlFrames []wire.Frame controlFrames []wire.Frame
pathResponses []*wire.PathResponseFrame pathResponses []*wire.PathResponseFrame
queuedTooManyControlFrames bool
} }
var _ framer = &framerI{} var _ framer = &framerI{}
@ -73,10 +84,15 @@ func (f *framerI) QueueControlFrame(frame wire.Frame) {
f.pathResponses = append(f.pathResponses, pr) f.pathResponses = append(f.pathResponses, pr)
return return
} }
// This is a hack.
if len(f.controlFrames) >= maxControlFrames {
f.queuedTooManyControlFrames = true
return
}
f.controlFrames = append(f.controlFrames, frame) f.controlFrames = append(f.controlFrames, frame)
} }
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) {
f.controlFrameMutex.Lock() f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock() defer f.controlFrameMutex.Unlock()
@ -105,6 +121,10 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol
return frames, length return frames, length
} }
func (f *framerI) QueuedTooManyControlFrames() bool {
return f.queuedTooManyControlFrames
}
func (f *framerI) AddActiveStream(id protocol.StreamID) { func (f *framerI) AddActiveStream(id protocol.StreamID) {
f.mutex.Lock() f.mutex.Lock()
if _, ok := f.activeStreams[id]; !ok { if _, ok := f.activeStreams[id]; !ok {
@ -114,7 +134,7 @@ func (f *framerI) AddActiveStream(id protocol.StreamID) {
f.mutex.Unlock() f.mutex.Unlock()
} }
func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) {
startLen := len(frames) startLen := len(frames)
var length protocol.ByteCount var length protocol.ByteCount
f.mutex.Lock() f.mutex.Lock()

View File

@ -16,8 +16,12 @@ import (
// The StreamID is the ID of a QUIC stream. // The StreamID is the ID of a QUIC stream.
type StreamID = protocol.StreamID type StreamID = protocol.StreamID
// A Version is a QUIC version number.
type Version = protocol.Version
// A VersionNumber is a QUIC version number. // A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber // Deprecated: VersionNumber was renamed to Version.
type VersionNumber = Version
const ( const (
// Version1 is RFC 9000 // Version1 is RFC 9000
@ -159,6 +163,9 @@ type Connection interface {
OpenStream() (Stream, error) OpenStream() (Stream, error)
// OpenStreamSync opens a new bidirectional QUIC stream. // OpenStreamSync opens a new bidirectional QUIC stream.
// It blocks until a new stream can be opened. // It blocks until a new stream can be opened.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
// If the error is non-nil, it satisfies the net.Error interface. // If the error is non-nil, it satisfies the net.Error interface.
// If the connection was closed due to a timeout, Timeout() will be true. // If the connection was closed due to a timeout, Timeout() will be true.
OpenStreamSync(context.Context) (Stream, error) OpenStreamSync(context.Context) (Stream, error)
@ -255,7 +262,7 @@ type Config struct {
GetConfigForClient func(info *ClientHelloInfo) (*Config, error) GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
// The QUIC versions that can be negotiated. // The QUIC versions that can be negotiated.
// If not set, it uses all versions available. // If not set, it uses all versions available.
Versions []VersionNumber Versions []Version
// HandshakeIdleTimeout is the idle timeout before completion of the handshake. // HandshakeIdleTimeout is the idle timeout before completion of the handshake.
// If we don't receive any packet from the peer within this time, the connection attempt is aborted. // If we don't receive any packet from the peer within this time, the connection attempt is aborted.
// Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted. // Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted.
@ -267,11 +274,6 @@ type Config struct {
// If the timeout is exceeded, the connection is closed. // If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 30 seconds. // If this value is zero, the timeout is set to 30 seconds.
MaxIdleTimeout time.Duration MaxIdleTimeout time.Duration
// RequireAddressValidation determines if a QUIC Retry packet is sent.
// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
// If not set, every client is forced to prove its remote address.
RequireAddressValidation func(net.Addr) bool
// The TokenStore stores tokens received from the server. // The TokenStore stores tokens received from the server.
// Tokens are used to skip address validation on future connection attempts. // Tokens are used to skip address validation on future connection attempts.
// The key used to store tokens is the ServerName from the tls.Config, if set // The key used to store tokens is the ServerName from the tls.Config, if set
@ -331,8 +333,15 @@ type Config struct {
Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer
} }
// ClientHelloInfo contains information about an incoming connection attempt.
type ClientHelloInfo struct { type ClientHelloInfo struct {
// RemoteAddr is the remote address on the Initial packet.
// Unless AddrVerified is set, the address is not yet verified, and could be a spoofed IP address.
RemoteAddr net.Addr RemoteAddr net.Addr
// AddrVerified says if the remote address was verified using QUIC's Retry mechanism.
// Note that the Retry mechanism costs one network roundtrip,
// and is not performed unless Transport.MaxUnvalidatedHandshakes is surpassed.
AddrVerified bool
} }
// ConnectionState records basic details about a QUIC connection // ConnectionState records basic details about a QUIC connection
@ -347,7 +356,7 @@ type ConnectionState struct {
// Used0RTT says if 0-RTT resumption was used. // Used0RTT says if 0-RTT resumption was used.
Used0RTT bool Used0RTT bool
// Version is the QUIC version of the QUIC connection. // Version is the QUIC version of the QUIC connection.
Version VersionNumber Version Version
// GSO says if generic segmentation offload is used // GSO says if generic segmentation offload is used
GSO bool GSO bool
} }

View File

@ -20,5 +20,5 @@ func NewAckHandler(
logger utils.Logger, logger utils.Logger,
) (SentPacketHandler, ReceivedPacketHandler) { ) (SentPacketHandler, ReceivedPacketHandler) {
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger) sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger)
return sph, newReceivedPacketHandler(sph, rttStats, logger) return sph, newReceivedPacketHandler(sph, logger)
} }

View File

@ -14,23 +14,19 @@ type receivedPacketHandler struct {
initialPackets *receivedPacketTracker initialPackets *receivedPacketTracker
handshakePackets *receivedPacketTracker handshakePackets *receivedPacketTracker
appDataPackets *receivedPacketTracker appDataPackets appDataReceivedPacketTracker
lowest1RTTPacket protocol.PacketNumber lowest1RTTPacket protocol.PacketNumber
} }
var _ ReceivedPacketHandler = &receivedPacketHandler{} var _ ReceivedPacketHandler = &receivedPacketHandler{}
func newReceivedPacketHandler( func newReceivedPacketHandler(sentPackets sentPacketTracker, logger utils.Logger) ReceivedPacketHandler {
sentPackets sentPacketTracker,
rttStats *utils.RTTStats,
logger utils.Logger,
) ReceivedPacketHandler {
return &receivedPacketHandler{ return &receivedPacketHandler{
sentPackets: sentPackets, sentPackets: sentPackets,
initialPackets: newReceivedPacketTracker(rttStats, logger), initialPackets: newReceivedPacketTracker(),
handshakePackets: newReceivedPacketTracker(rttStats, logger), handshakePackets: newReceivedPacketTracker(),
appDataPackets: newReceivedPacketTracker(rttStats, logger), appDataPackets: *newAppDataReceivedPacketTracker(logger),
lowest1RTTPacket: protocol.InvalidPacketNumber, lowest1RTTPacket: protocol.InvalidPacketNumber,
} }
} }
@ -88,41 +84,28 @@ func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
} }
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
var initialAlarm, handshakeAlarm time.Time return h.appDataPackets.GetAlarmTimeout()
if h.initialPackets != nil {
initialAlarm = h.initialPackets.GetAlarmTimeout()
}
if h.handshakePackets != nil {
handshakeAlarm = h.handshakePackets.GetAlarmTimeout()
}
oneRTTAlarm := h.appDataPackets.GetAlarmTimeout()
return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
} }
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame { func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame {
var ack *wire.AckFrame
//nolint:exhaustive // 0-RTT packets can't contain ACK frames. //nolint:exhaustive // 0-RTT packets can't contain ACK frames.
switch encLevel { switch encLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
if h.initialPackets != nil { if h.initialPackets != nil {
ack = h.initialPackets.GetAckFrame(onlyIfQueued) return h.initialPackets.GetAckFrame()
} }
return nil
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
if h.handshakePackets != nil { if h.handshakePackets != nil {
ack = h.handshakePackets.GetAckFrame(onlyIfQueued) return h.handshakePackets.GetAckFrame()
} }
return nil
case protocol.Encryption1RTT: case protocol.Encryption1RTT:
// 0-RTT packets can't contain ACK frames
return h.appDataPackets.GetAckFrame(onlyIfQueued) return h.appDataPackets.GetAckFrame(onlyIfQueued)
default: default:
// 0-RTT packets can't contain ACK frames
return nil return nil
} }
// For Initial and Handshake ACKs, the delay time is ignored by the receiver.
// Set it to 0 in order to save bytes.
if ack != nil {
ack.DelayTime = 0
}
return ack
} }
func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool { func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {

View File

@ -9,40 +9,19 @@ import (
"github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/internal/wire"
) )
// number of ack-eliciting packets received before sending an ack. // The receivedPacketTracker tracks packets for the Initial and Handshake packet number space.
const packetsBeforeAck = 2 // Every received packet is acknowledged immediately.
type receivedPacketTracker struct { type receivedPacketTracker struct {
largestObserved protocol.PacketNumber ect0, ect1, ecnce uint64
ignoreBelow protocol.PacketNumber
largestObservedRcvdTime time.Time
ect0, ect1, ecnce uint64
packetHistory *receivedPacketHistory packetHistory receivedPacketHistory
maxAckDelay time.Duration
rttStats *utils.RTTStats
lastAck *wire.AckFrame
hasNewAck bool // true as soon as we received an ack-eliciting new packet hasNewAck bool // true as soon as we received an ack-eliciting new packet
ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets
ackElicitingPacketsReceivedSinceLastAck int
ackAlarm time.Time
lastAck *wire.AckFrame
logger utils.Logger
} }
func newReceivedPacketTracker( func newReceivedPacketTracker() *receivedPacketTracker {
rttStats *utils.RTTStats, return &receivedPacketTracker{packetHistory: *newReceivedPacketHistory()}
logger utils.Logger,
) *receivedPacketTracker {
return &receivedPacketTracker{
packetHistory: newReceivedPacketHistory(),
maxAckDelay: protocol.MaxAckDelay,
rttStats: rttStats,
logger: logger,
}
} }
func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error { func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
@ -50,16 +29,6 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn) return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn)
} }
isMissing := h.isMissing(pn)
if pn >= h.largestObserved {
h.largestObserved = pn
h.largestObservedRcvdTime = rcvTime
}
if ackEliciting {
h.hasNewAck = true
h.maybeQueueACK(pn, rcvTime, ecn, isMissing)
}
//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE. //nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE.
switch ecn { switch ecn {
case protocol.ECT0: case protocol.ECT0:
@ -69,12 +38,99 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
case protocol.ECNCE: case protocol.ECNCE:
h.ecnce++ h.ecnce++
} }
if !ackEliciting {
return nil
}
h.hasNewAck = true
return nil
}
func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame {
if !h.hasNewAck {
return nil
}
// This function always returns the same ACK frame struct, filled with the most recent values.
ack := h.lastAck
if ack == nil {
ack = &wire.AckFrame{}
}
ack.Reset()
ack.ECT0 = h.ect0
ack.ECT1 = h.ect1
ack.ECNCE = h.ecnce
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
h.lastAck = ack
h.hasNewAck = false
return ack
}
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
return h.packetHistory.IsPotentiallyDuplicate(pn)
}
// number of ack-eliciting packets received before sending an ACK
const packetsBeforeAck = 2
// The appDataReceivedPacketTracker tracks packets received in the Application Data packet number space.
// It waits until at least 2 packets were received before queueing an ACK, or until the max_ack_delay was reached.
type appDataReceivedPacketTracker struct {
receivedPacketTracker
largestObservedRcvdTime time.Time
largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber
maxAckDelay time.Duration
ackQueued bool // true if we need send a new ACK
ackElicitingPacketsReceivedSinceLastAck int
ackAlarm time.Time
logger utils.Logger
}
func newAppDataReceivedPacketTracker(logger utils.Logger) *appDataReceivedPacketTracker {
h := &appDataReceivedPacketTracker{
receivedPacketTracker: *newReceivedPacketTracker(),
maxAckDelay: protocol.MaxAckDelay,
logger: logger,
}
return h
}
func (h *appDataReceivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
if err := h.receivedPacketTracker.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil {
return err
}
if pn >= h.largestObserved {
h.largestObserved = pn
h.largestObservedRcvdTime = rcvTime
}
if !ackEliciting {
return nil
}
h.ackElicitingPacketsReceivedSinceLastAck++
isMissing := h.isMissing(pn)
if !h.ackQueued && h.shouldQueueACK(pn, ecn, isMissing) {
h.ackQueued = true
h.ackAlarm = time.Time{} // cancel the ack alarm
}
if !h.ackQueued {
// No ACK queued, but we'll need to acknowledge the packet after max_ack_delay.
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
}
}
return nil return nil
} }
// IgnoreBelow sets a lower limit for acknowledging packets. // IgnoreBelow sets a lower limit for acknowledging packets.
// Packets with packet numbers smaller than p will not be acked. // Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) { func (h *appDataReceivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
if pn <= h.ignoreBelow { if pn <= h.ignoreBelow {
return return
} }
@ -86,14 +142,14 @@ func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
} }
// isMissing says if a packet was reported missing in the last ACK. // isMissing says if a packet was reported missing in the last ACK.
func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool { func (h *appDataReceivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil || p < h.ignoreBelow { if h.lastAck == nil || p < h.ignoreBelow {
return false return false
} }
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p) return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
} }
func (h *receivedPacketTracker) hasNewMissingPackets() bool { func (h *appDataReceivedPacketTracker) hasNewMissingPackets() bool {
if h.lastAck == nil { if h.lastAck == nil {
return false return false
} }
@ -101,31 +157,21 @@ func (h *receivedPacketTracker) hasNewMissingPackets() bool {
return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1 return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1
} }
// maybeQueueACK queues an ACK, if necessary. func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool {
func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime time.Time, ecn protocol.ECN, wasMissing bool) {
// always acknowledge the first packet // always acknowledge the first packet
if h.lastAck == nil { if h.lastAck == nil {
if !h.ackQueued { h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") return true
}
h.ackQueued = true
return
} }
if h.ackQueued {
return
}
h.ackElicitingPacketsReceivedSinceLastAck++
// Send an ACK if this packet was reported missing in an ACK sent before. // Send an ACK if this packet was reported missing in an ACK sent before.
// Ack decimation with reordering relies on the timer to send an ACK, but if // Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately. // missing packets we reported in the previous ACK, send an ACK immediately.
if wasMissing { if wasMissing {
if h.logger.Debug() { if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn) h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn)
} }
h.ackQueued = true return true
} }
// send an ACK every 2 ack-eliciting packets // send an ACK every 2 ack-eliciting packets
@ -133,68 +179,42 @@ func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime
if h.logger.Debug() { if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck) h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck)
} }
h.ackQueued = true return true
} else if h.ackAlarm.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
}
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
} }
// queue an ACK if there are new missing packets to report // queue an ACK if there are new missing packets to report
if h.hasNewMissingPackets() { if h.hasNewMissingPackets() {
h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.") h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.")
h.ackQueued = true return true
} }
// queue an ACK if the packet was ECN-CE marked // queue an ACK if the packet was ECN-CE marked
if ecn == protocol.ECNCE { if ecn == protocol.ECNCE {
h.logger.Debugf("\tQueuing ACK because the packet was ECN-CE marked.") h.logger.Debugf("\tQueuing ACK because the packet was ECN-CE marked.")
h.ackQueued = true return true
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
} }
return false
} }
func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { func (h *appDataReceivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
if !h.hasNewAck {
return nil
}
now := time.Now() now := time.Now()
if onlyIfQueued { if onlyIfQueued && !h.ackQueued {
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) { if h.ackAlarm.IsZero() || h.ackAlarm.After(now) {
return nil return nil
} }
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() { if h.logger.Debug() && !h.ackAlarm.IsZero() {
h.logger.Debugf("Sending ACK because the ACK timer expired.") h.logger.Debugf("Sending ACK because the ACK timer expired.")
} }
} }
ack := h.receivedPacketTracker.GetAckFrame()
// This function always returns the same ACK frame struct, filled with the most recent values.
ack := h.lastAck
if ack == nil { if ack == nil {
ack = &wire.AckFrame{} return nil
} }
ack.Reset()
ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime)) ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime))
ack.ECT0 = h.ect0
ack.ECT1 = h.ect1
ack.ECNCE = h.ecnce
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false h.ackQueued = false
h.hasNewAck = false h.ackAlarm = time.Time{}
h.ackElicitingPacketsReceivedSinceLastAck = 0 h.ackElicitingPacketsReceivedSinceLastAck = 0
return ack return ack
} }
func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm } func (h *appDataReceivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
return h.packetHistory.IsPotentiallyDuplicate(pn)
}

View File

@ -48,10 +48,12 @@ func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
} }
// UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame. // UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame.
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) (updated bool) {
if offset > c.sendWindow { if offset > c.sendWindow {
c.sendWindow = offset c.sendWindow = offset
return true
} }
return false
} }
func (c *baseFlowController) sendWindowSize() protocol.ByteCount { func (c *baseFlowController) sendWindowSize() protocol.ByteCount {

View File

@ -5,7 +5,7 @@ import "github.com/quic-go/quic-go/internal/protocol"
type flowController interface { type flowController interface {
// for sending // for sending
SendWindowSize() protocol.ByteCount SendWindowSize() protocol.ByteCount
UpdateSendWindow(protocol.ByteCount) UpdateSendWindow(protocol.ByteCount) (updated bool)
AddBytesSent(protocol.ByteCount) AddBytesSent(protocol.ByteCount)
// for receiving // for receiving
AddBytesRead(protocol.ByteCount) AddBytesRead(protocol.ByteCount)
@ -16,12 +16,11 @@ type flowController interface {
// A StreamFlowController is a flow controller for a QUIC stream. // A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface { type StreamFlowController interface {
flowController flowController
// for receiving // UpdateHighestReceived is called when a new highest offset is received
// UpdateHighestReceived should be called when a new highest offset is received
// final has to be to true if this is the final offset of the stream, // final has to be to true if this is the final offset of the stream,
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame // as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
UpdateHighestReceived(offset protocol.ByteCount, final bool) error UpdateHighestReceived(offset protocol.ByteCount, final bool) error
// Abandon should be called when reading from the stream is aborted early, // Abandon is called when reading from the stream is aborted early,
// and there won't be any further calls to AddBytesRead. // and there won't be any further calls to AddBytesRead.
Abandon() Abandon()
} }

View File

@ -1,13 +1,12 @@
package handshake package handshake
import ( import (
"crypto/cipher"
"encoding/binary" "encoding/binary"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
) )
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD { func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.Version) *xorNonceAEAD {
keyLabel := hkdfLabelKeyV1 keyLabel := hkdfLabelKeyV1
ivLabel := hkdfLabelIVV1 ivLabel := hkdfLabelIVV1
if v == protocol.Version2 { if v == protocol.Version2 {
@ -20,28 +19,26 @@ func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumb
} }
type longHeaderSealer struct { type longHeaderSealer struct {
aead cipher.AEAD aead *xorNonceAEAD
headerProtector headerProtector headerProtector headerProtector
nonceBuf [8]byte
// use a single slice to avoid allocations
nonceBuf []byte
} }
var _ LongHeaderSealer = &longHeaderSealer{} var _ LongHeaderSealer = &longHeaderSealer{}
func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer { func newLongHeaderSealer(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderSealer {
if aead.NonceSize() != 8 {
panic("unexpected nonce size")
}
return &longHeaderSealer{ return &longHeaderSealer{
aead: aead, aead: aead,
headerProtector: headerProtector, headerProtector: headerProtector,
nonceBuf: make([]byte, aead.NonceSize()),
} }
} }
func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn)) binary.BigEndian.PutUint64(s.nonceBuf[:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13. return s.aead.Seal(dst, s.nonceBuf[:], src, ad)
// It uses the nonce provided here and XOR it with the IV.
return s.aead.Seal(dst, s.nonceBuf, src, ad)
} }
func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
@ -53,21 +50,23 @@ func (s *longHeaderSealer) Overhead() int {
} }
type longHeaderOpener struct { type longHeaderOpener struct {
aead cipher.AEAD aead *xorNonceAEAD
headerProtector headerProtector headerProtector headerProtector
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
// use a single slice to avoid allocations // use a single array to avoid allocations
nonceBuf []byte nonceBuf [8]byte
} }
var _ LongHeaderOpener = &longHeaderOpener{} var _ LongHeaderOpener = &longHeaderOpener{}
func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener { func newLongHeaderOpener(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderOpener {
if aead.NonceSize() != 8 {
panic("unexpected nonce size")
}
return &longHeaderOpener{ return &longHeaderOpener{
aead: aead, aead: aead,
headerProtector: headerProtector, headerProtector: headerProtector,
nonceBuf: make([]byte, aead.NonceSize()),
} }
} }
@ -76,10 +75,8 @@ func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wire
} }
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) binary.BigEndian.PutUint64(o.nonceBuf[:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13. dec, err := o.aead.Open(dst, o.nonceBuf[:], src, ad)
// It uses the nonce provided here and XOR it with the IV.
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
if err == nil { if err == nil {
o.highestRcvdPN = max(o.highestRcvdPN, pn) o.highestRcvdPN = max(o.highestRcvdPN, pn)
} else { } else {

View File

@ -18,7 +18,7 @@ type cipherSuite struct {
ID uint16 ID uint16
Hash crypto.Hash Hash crypto.Hash
KeyLen int KeyLen int
AEAD func(key, nonceMask []byte) cipher.AEAD AEAD func(key, nonceMask []byte) *xorNonceAEAD
} }
func (s cipherSuite) IVLen() int { return aeadNonceLength } func (s cipherSuite) IVLen() int { return aeadNonceLength }
@ -36,7 +36,7 @@ func getCipherSuite(id uint16) *cipherSuite {
} }
} }
func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD { func aeadAESGCMTLS13(key, nonceMask []byte) *xorNonceAEAD {
if len(nonceMask) != aeadNonceLength { if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length") panic("tls: internal error: wrong nonce length")
} }
@ -54,7 +54,7 @@ func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
return ret return ret
} }
func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD { func aeadChaCha20Poly1305(key, nonceMask []byte) *xorNonceAEAD {
if len(nonceMask) != aeadNonceLength { if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length") panic("tls: internal error: wrong nonce length")
} }

View File

@ -8,7 +8,6 @@ import (
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -33,7 +32,7 @@ type cryptoSetup struct {
events []Event events []Event
version protocol.VersionNumber version protocol.Version
ourParams *wire.TransportParameters ourParams *wire.TransportParameters
peerParams *wire.TransportParameters peerParams *wire.TransportParameters
@ -48,8 +47,6 @@ type cryptoSetup struct {
perspective protocol.Perspective perspective protocol.Perspective
mutex sync.Mutex // protects all members below
handshakeCompleteTime time.Time handshakeCompleteTime time.Time
zeroRTTOpener LongHeaderOpener // only set for the server zeroRTTOpener LongHeaderOpener // only set for the server
@ -79,7 +76,7 @@ func NewCryptoSetupClient(
rttStats *utils.RTTStats, rttStats *utils.RTTStats,
tracer *logging.ConnectionTracer, tracer *logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
version protocol.VersionNumber, version protocol.Version,
) CryptoSetup { ) CryptoSetup {
cs := newCryptoSetup( cs := newCryptoSetup(
connID, connID,
@ -114,7 +111,7 @@ func NewCryptoSetupServer(
rttStats *utils.RTTStats, rttStats *utils.RTTStats,
tracer *logging.ConnectionTracer, tracer *logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
version protocol.VersionNumber, version protocol.Version,
) CryptoSetup { ) CryptoSetup {
cs := newCryptoSetup( cs := newCryptoSetup(
connID, connID,
@ -172,7 +169,7 @@ func newCryptoSetup(
tracer *logging.ConnectionTracer, tracer *logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
perspective protocol.Perspective, perspective protocol.Perspective,
version protocol.VersionNumber, version protocol.Version,
) *cryptoSetup { ) *cryptoSetup {
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
if tracer != nil && tracer.UpdatedKeyFromTLS != nil { if tracer != nil && tracer.UpdatedKeyFromTLS != nil {
@ -269,10 +266,10 @@ func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) {
case tls.QUICNoEvent: case tls.QUICNoEvent:
return true, nil return true, nil
case tls.QUICSetReadSecret: case tls.QUICSetReadSecret:
h.SetReadKey(ev.Level, ev.Suite, ev.Data) h.setReadKey(ev.Level, ev.Suite, ev.Data)
return false, nil return false, nil
case tls.QUICSetWriteSecret: case tls.QUICSetWriteSecret:
h.SetWriteKey(ev.Level, ev.Suite, ev.Data) h.setWriteKey(ev.Level, ev.Suite, ev.Data)
return false, nil return false, nil
case tls.QUICTransportParameters: case tls.QUICTransportParameters:
return false, h.handleTransportParameters(ev.Data) return false, h.handleTransportParameters(ev.Data)
@ -434,19 +431,16 @@ func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bo
func (h *cryptoSetup) rejected0RTT() { func (h *cryptoSetup) rejected0RTT() {
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.") h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
h.mutex.Lock()
had0RTTKeys := h.zeroRTTSealer != nil had0RTTKeys := h.zeroRTTSealer != nil
h.zeroRTTSealer = nil h.zeroRTTSealer = nil
h.mutex.Unlock()
if had0RTTKeys { if had0RTTKeys {
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys}) h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
} }
} }
func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID) suite := getCipherSuite(suiteID)
h.mutex.Lock()
//nolint:exhaustive // The TLS stack doesn't export Initial keys. //nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el { switch el {
case tls.QUICEncryptionLevelEarly: case tls.QUICEncryptionLevelEarly:
@ -478,16 +472,14 @@ func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra
default: default:
panic("unexpected read encryption level") panic("unexpected read encryption level")
} }
h.mutex.Unlock()
h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
} }
} }
func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID) suite := getCipherSuite(suiteID)
h.mutex.Lock()
//nolint:exhaustive // The TLS stack doesn't export Initial keys. //nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el { switch el {
case tls.QUICEncryptionLevelEarly: case tls.QUICEncryptionLevelEarly:
@ -498,7 +490,6 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
createAEAD(suite, trafficSecret, h.version), createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version), newHeaderProtector(suite, trafficSecret, true, h.version),
) )
h.mutex.Unlock()
if h.logger.Debug() { if h.logger.Debug() {
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
} }
@ -533,7 +524,6 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
default: default:
panic("unexpected write encryption level") panic("unexpected write encryption level")
} }
h.mutex.Unlock()
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
} }
@ -555,11 +545,9 @@ func (h *cryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) {
} }
func (h *cryptoSetup) DiscardInitialKeys() { func (h *cryptoSetup) DiscardInitialKeys() {
h.mutex.Lock()
dropped := h.initialOpener != nil dropped := h.initialOpener != nil
h.initialOpener = nil h.initialOpener = nil
h.initialSealer = nil h.initialSealer = nil
h.mutex.Unlock()
if dropped { if dropped {
h.logger.Debugf("Dropping Initial keys.") h.logger.Debugf("Dropping Initial keys.")
} }
@ -574,22 +562,17 @@ func (h *cryptoSetup) SetHandshakeConfirmed() {
h.aead.SetHandshakeConfirmed() h.aead.SetHandshakeConfirmed()
// drop Handshake keys // drop Handshake keys
var dropped bool var dropped bool
h.mutex.Lock()
if h.handshakeOpener != nil { if h.handshakeOpener != nil {
h.handshakeOpener = nil h.handshakeOpener = nil
h.handshakeSealer = nil h.handshakeSealer = nil
dropped = true dropped = true
} }
h.mutex.Unlock()
if dropped { if dropped {
h.logger.Debugf("Dropping Handshake keys.") h.logger.Debugf("Dropping Handshake keys.")
} }
} }
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.initialSealer == nil { if h.initialSealer == nil {
return nil, ErrKeysDropped return nil, ErrKeysDropped
} }
@ -597,9 +580,6 @@ func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
} }
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTSealer == nil { if h.zeroRTTSealer == nil {
return nil, ErrKeysDropped return nil, ErrKeysDropped
} }
@ -607,9 +587,6 @@ func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
} }
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.handshakeSealer == nil { if h.handshakeSealer == nil {
if h.initialSealer == nil { if h.initialSealer == nil {
return nil, ErrKeysDropped return nil, ErrKeysDropped
@ -620,9 +597,6 @@ func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
} }
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if !h.has1RTTSealer { if !h.has1RTTSealer {
return nil, ErrKeysNotYetAvailable return nil, ErrKeysNotYetAvailable
} }
@ -630,9 +604,6 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
} }
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.initialOpener == nil { if h.initialOpener == nil {
return nil, ErrKeysDropped return nil, ErrKeysDropped
} }
@ -640,9 +611,6 @@ func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
} }
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener == nil { if h.zeroRTTOpener == nil {
if h.initialOpener != nil { if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable return nil, ErrKeysNotYetAvailable
@ -654,9 +622,6 @@ func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
} }
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.handshakeOpener == nil { if h.handshakeOpener == nil {
if h.initialOpener != nil { if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable return nil, ErrKeysNotYetAvailable
@ -668,9 +633,6 @@ func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
} }
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
h.zeroRTTOpener = nil h.zeroRTTOpener = nil
h.logger.Debugf("Dropping 0-RTT keys.") h.logger.Debugf("Dropping 0-RTT keys.")

View File

@ -17,14 +17,14 @@ type headerProtector interface {
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
} }
func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string { func hkdfHeaderProtectionLabel(v protocol.Version) string {
if v == protocol.Version2 { if v == protocol.Version2 {
return "quicv2 hp" return "quicv2 hp"
} }
return "quic hp" return "quic hp"
} }
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector { func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.Version) headerProtector {
hkdfLabel := hkdfHeaderProtectionLabel(v) hkdfLabel := hkdfHeaderProtectionLabel(v)
switch suite.ID { switch suite.ID {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
@ -37,7 +37,7 @@ func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader b
} }
type aesHeaderProtector struct { type aesHeaderProtector struct {
mask []byte mask [16]byte // AES always has a 16 byte block size
block cipher.Block block cipher.Block
isLongHeader bool isLongHeader bool
} }
@ -52,7 +52,6 @@ func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeade
} }
return &aesHeaderProtector{ return &aesHeaderProtector{
block: block, block: block,
mask: make([]byte, block.BlockSize()),
isLongHeader: isLongHeader, isLongHeader: isLongHeader,
} }
} }
@ -69,7 +68,7 @@ func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []by
if len(sample) != len(p.mask) { if len(sample) != len(p.mask) {
panic("invalid sample size") panic("invalid sample size")
} }
p.block.Encrypt(p.mask, sample) p.block.Encrypt(p.mask[:], sample)
if p.isLongHeader { if p.isLongHeader {
*firstByte ^= p.mask[0] & 0xf *firstByte ^= p.mask[0] & 0xf
} else { } else {

View File

@ -7,7 +7,7 @@ import (
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
) )
// hkdfExpandLabel HKDF expands a label. // hkdfExpandLabel HKDF expands a label as defined in RFC 8446, section 7.1.
// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the // Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
// hkdfExpandLabel in the standard library. // hkdfExpandLabel in the standard library.
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte { func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {

View File

@ -21,7 +21,7 @@ const (
hkdfLabelIVV2 = "quicv2 iv" hkdfLabelIVV2 = "quicv2 iv"
) )
func getSalt(v protocol.VersionNumber) []byte { func getSalt(v protocol.Version) []byte {
if v == protocol.Version2 { if v == protocol.Version2 {
return quicSaltV2 return quicSaltV2
} }
@ -31,7 +31,7 @@ func getSalt(v protocol.VersionNumber) []byte {
var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256) var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256)
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption. // NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) { func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.Version) (LongHeaderSealer, LongHeaderOpener) {
clientSecret, serverSecret := computeSecrets(connID, v) clientSecret, serverSecret := computeSecrets(connID, v)
var mySecret, otherSecret []byte var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient { if pers == protocol.PerspectiveClient {
@ -51,14 +51,14 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v))) newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))
} }
func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) { func computeSecrets(connID protocol.ConnectionID, v protocol.Version) (clientSecret, serverSecret []byte) {
initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v)) initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v))
clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
return return
} }
func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) { func computeInitialKeyAndIV(secret []byte, v protocol.Version) (key, iv []byte) {
keyLabel := hkdfLabelKeyV1 keyLabel := hkdfLabelKeyV1
ivLabel := hkdfLabelIVV1 ivLabel := hkdfLabelIVV1
if v == protocol.Version2 { if v == protocol.Version2 {

View File

@ -40,7 +40,7 @@ var (
) )
// GetRetryIntegrityTag calculates the integrity tag on a Retry packet // GetRetryIntegrityTag calculates the integrity tag on a Retry packet
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte { func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.Version) *[16]byte {
retryMutex.Lock() retryMutex.Lock()
defer retryMutex.Unlock() defer retryMutex.Unlock()

View File

@ -59,7 +59,7 @@ type updatableAEAD struct {
tracer *logging.ConnectionTracer tracer *logging.ConnectionTracer
logger utils.Logger logger utils.Logger
version protocol.VersionNumber version protocol.Version
// use a single slice to avoid allocations // use a single slice to avoid allocations
nonceBuf []byte nonceBuf []byte
@ -70,7 +70,7 @@ var (
_ ShortHeaderSealer = &updatableAEAD{} _ ShortHeaderSealer = &updatableAEAD{}
) )
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.Version) *updatableAEAD {
return &updatableAEAD{ return &updatableAEAD{
firstPacketNumber: protocol.InvalidPacketNumber, firstPacketNumber: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber,
@ -133,7 +133,7 @@ func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) {
// SetWriteKey sets the write key. // SetWriteKey sets the write key.
// For the client, this function is called after SetReadKey. // For the client, this function is called after SetReadKey.
// For the server, this function is called before SetWriteKey. // For the server, this function is called before SetReadKey.
func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) { func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
a.sendAEAD = createAEAD(suite, trafficSecret, a.version) a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)

View File

@ -17,9 +17,9 @@ func (p Perspective) Opposite() Perspective {
func (p Perspective) String() string { func (p Perspective) String() string {
switch p { switch p {
case PerspectiveServer: case PerspectiveServer:
return "Server" return "server"
case PerspectiveClient: case PerspectiveClient:
return "Client" return "client"
default: default:
return "invalid perspective" return "invalid perspective"
} }

View File

@ -1,14 +1,17 @@
package protocol package protocol
import ( import (
"crypto/rand"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"math" "math"
"sync"
"time"
"golang.org/x/exp/rand"
) )
// VersionNumber is a version number as int // Version is a version number as int
type VersionNumber uint32 type Version uint32
// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions // gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
const ( const (
@ -18,22 +21,22 @@ const (
// The version numbers, making grepping easier // The version numbers, making grepping easier
const ( const (
VersionUnknown VersionNumber = math.MaxUint32 VersionUnknown Version = math.MaxUint32
versionDraft29 VersionNumber = 0xff00001d // draft-29 used to be a widely deployed version versionDraft29 Version = 0xff00001d // draft-29 used to be a widely deployed version
Version1 VersionNumber = 0x1 Version1 Version = 0x1
Version2 VersionNumber = 0x6b3343cf Version2 Version = 0x6b3343cf
) )
// SupportedVersions lists the versions that the server supports // SupportedVersions lists the versions that the server supports
// must be in sorted descending order // must be in sorted descending order
var SupportedVersions = []VersionNumber{Version1, Version2} var SupportedVersions = []Version{Version1, Version2}
// IsValidVersion says if the version is known to quic-go // IsValidVersion says if the version is known to quic-go
func IsValidVersion(v VersionNumber) bool { func IsValidVersion(v Version) bool {
return v == Version1 || IsSupportedVersion(SupportedVersions, v) return v == Version1 || IsSupportedVersion(SupportedVersions, v)
} }
func (vn VersionNumber) String() string { func (vn Version) String() string {
//nolint:exhaustive //nolint:exhaustive
switch vn { switch vn {
case VersionUnknown: case VersionUnknown:
@ -52,16 +55,16 @@ func (vn VersionNumber) String() string {
} }
} }
func (vn VersionNumber) isGQUIC() bool { func (vn Version) isGQUIC() bool {
return vn > gquicVersion0 && vn <= maxGquicVersion return vn > gquicVersion0 && vn <= maxGquicVersion
} }
func (vn VersionNumber) toGQUICVersion() int { func (vn Version) toGQUICVersion() int {
return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10) return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
} }
// IsSupportedVersion returns true if the server supports this version // IsSupportedVersion returns true if the server supports this version
func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { func IsSupportedVersion(supported []Version, v Version) bool {
for _, t := range supported { for _, t := range supported {
if t == v { if t == v {
return true return true
@ -74,7 +77,7 @@ func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
// ours is a slice of versions that we support, sorted by our preference (descending) // ours is a slice of versions that we support, sorted by our preference (descending)
// theirs is a slice of versions offered by the peer. The order does not matter. // theirs is a slice of versions offered by the peer. The order does not matter.
// The bool returned indicates if a matching version was found. // The bool returned indicates if a matching version was found.
func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) { func ChooseSupportedVersion(ours, theirs []Version) (Version, bool) {
for _, ourVer := range ours { for _, ourVer := range ours {
for _, theirVer := range theirs { for _, theirVer := range theirs {
if ourVer == theirVer { if ourVer == theirVer {
@ -85,19 +88,25 @@ func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool)
return 0, false return 0, false
} }
// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a) var (
func generateReservedVersion() VersionNumber { versionNegotiationMx sync.Mutex
b := make([]byte, 4) versionNegotiationRand = rand.New(rand.NewSource(uint64(time.Now().UnixNano())))
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything )
return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa)
// generateReservedVersion generates a reserved version (v & 0x0f0f0f0f == 0x0a0a0a0a)
func generateReservedVersion() Version {
var b [4]byte
_, _ = versionNegotiationRand.Read(b[:]) // ignore the error here. Failure to read random data doesn't break anything
return Version((binary.BigEndian.Uint32(b[:]) | 0x0a0a0a0a) & 0xfafafafa)
} }
// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position // GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position.
func GetGreasedVersions(supported []VersionNumber) []VersionNumber { // It doesn't modify the supported slice.
b := make([]byte, 1) func GetGreasedVersions(supported []Version) []Version {
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything versionNegotiationMx.Lock()
randPos := int(b[0]) % (len(supported) + 1) defer versionNegotiationMx.Unlock()
greased := make([]VersionNumber, len(supported)+1) randPos := rand.Intn(len(supported) + 1)
greased := make([]Version, len(supported)+1)
copy(greased, supported[:randPos]) copy(greased, supported[:randPos])
greased[randPos] = generateReservedVersion() greased[randPos] = generateReservedVersion()
copy(greased[randPos+1:], supported[randPos:]) copy(greased[randPos+1:], supported[randPos:])

View File

@ -101,8 +101,8 @@ func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.Err
// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version. // A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version.
type VersionNegotiationError struct { type VersionNegotiationError struct {
Ours []protocol.VersionNumber Ours []protocol.Version
Theirs []protocol.VersionNumber Theirs []protocol.Version
} }
func (e *VersionNegotiationError) Error() string { func (e *VersionNegotiationError) Error() string {

View File

@ -1,23 +1,11 @@
package qtls package qtls
import ( import (
"crypto"
"crypto/cipher"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"unsafe" "unsafe"
) )
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13 //go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13
var cipherSuitesTLS13 []unsafe.Pointer var cipherSuitesTLS13 []unsafe.Pointer

View File

@ -1,12 +1,12 @@
//go:build go1.21
package qtls package qtls
import ( import (
"crypto/tls" "crypto/tls"
"sync"
) )
type clientSessionCache struct { type clientSessionCache struct {
mx sync.Mutex
getData func(earlyData bool) []byte getData func(earlyData bool) []byte
setData func(data []byte, earlyData bool) (allowEarlyData bool) setData func(data []byte, earlyData bool) (allowEarlyData bool)
wrapped tls.ClientSessionCache wrapped tls.ClientSessionCache
@ -14,7 +14,10 @@ type clientSessionCache struct {
var _ tls.ClientSessionCache = &clientSessionCache{} var _ tls.ClientSessionCache = &clientSessionCache{}
func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) { func (c *clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
c.mx.Lock()
defer c.mx.Unlock()
if cs == nil { if cs == nil {
c.wrapped.Put(key, nil) c.wrapped.Put(key, nil)
return return
@ -34,7 +37,10 @@ func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
c.wrapped.Put(key, newCS) c.wrapped.Put(key, newCS)
} }
func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) { func (c *clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
c.mx.Lock()
defer c.mx.Unlock()
cs, ok := c.wrapped.Get(key) cs, ok := c.wrapped.Get(key)
if !ok || cs == nil { if !ok || cs == nil {
return cs, ok return cs, ok

View File

@ -27,18 +27,6 @@ func MinTime(a, b time.Time) time.Time {
return a return a
} }
// MinNonZeroTime returns the earlist time that is not time.Time{}
// If both a and b are time.Time{}, it returns time.Time{}
func MinNonZeroTime(a, b time.Time) time.Time {
if a.IsZero() {
return b
}
if b.IsZero() {
return a
}
return MinTime(a, b)
}
// MaxTime returns the later time // MaxTime returns the later time
func MaxTime(a, b time.Time) time.Time { func MaxTime(a, b time.Time) time.Time {
if a.After(b) { if a.After(b) {

View File

@ -22,7 +22,7 @@ type AckFrame struct {
} }
// parseAckFrame reads an ACK frame // parseAckFrame reads an ACK frame
func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) error { func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.Version) error {
ecn := typ == ackECNFrameType ecn := typ == ackECNFrameType
la, err := quicvarint.Read(r) la, err := quicvarint.Read(r)
@ -110,7 +110,7 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen
} }
// Append appends an ACK frame. // Append appends an ACK frame.
func (f *AckFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *AckFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
if hasECN { if hasECN {
b = append(b, ackECNFrameType) b = append(b, ackECNFrameType)
@ -143,7 +143,7 @@ func (f *AckFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
} }
// Length of a written frame // Length of a written frame
func (f *AckFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { func (f *AckFrame) Length(_ protocol.Version) protocol.ByteCount {
largestAcked := f.AckRanges[0].Largest largestAcked := f.AckRanges[0].Largest
numRanges := f.numEncodableAckRanges() numRanges := f.numEncodableAckRanges()

View File

@ -16,7 +16,7 @@ type ConnectionCloseFrame struct {
ReasonPhrase string ReasonPhrase string
} }
func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*ConnectionCloseFrame, error) { func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, error) {
f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType} f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType}
ec, err := quicvarint.Read(r) ec, err := quicvarint.Read(r)
if err != nil { if err != nil {
@ -53,7 +53,7 @@ func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNu
} }
// Length of a written frame // Length of a written frame
func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount { func (f *ConnectionCloseFrame) Length(protocol.Version) protocol.ByteCount {
length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
if !f.IsApplicationError { if !f.IsApplicationError {
length += quicvarint.Len(f.FrameType) // for the frame type length += quicvarint.Len(f.FrameType) // for the frame type
@ -61,7 +61,7 @@ func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount
return length return length
} }
func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
if f.IsApplicationError { if f.IsApplicationError {
b = append(b, applicationCloseFrameType) b = append(b, applicationCloseFrameType)
} else { } else {

View File

@ -14,7 +14,7 @@ type CryptoFrame struct {
Data []byte Data []byte
} }
func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) { func parseCryptoFrame(r *bytes.Reader, _ protocol.Version) (*CryptoFrame, error) {
frame := &CryptoFrame{} frame := &CryptoFrame{}
offset, err := quicvarint.Read(r) offset, err := quicvarint.Read(r)
if err != nil { if err != nil {
@ -38,7 +38,7 @@ func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame,
return frame, nil return frame, nil
} }
func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, cryptoFrameType) b = append(b, cryptoFrameType)
b = quicvarint.Append(b, uint64(f.Offset)) b = quicvarint.Append(b, uint64(f.Offset))
b = quicvarint.Append(b, uint64(len(f.Data))) b = quicvarint.Append(b, uint64(len(f.Data)))
@ -47,7 +47,7 @@ func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error)
} }
// Length of a written frame // Length of a written frame
func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { func (f *CryptoFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data)) return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data))
} }
@ -71,7 +71,7 @@ func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount
// The frame might not be split if: // The frame might not be split if:
// * the size is large enough to fit the whole frame // * the size is large enough to fit the whole frame
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. // * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*CryptoFrame, bool /* was splitting required */) { func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*CryptoFrame, bool /* was splitting required */) {
if f.Length(version) <= maxSize { if f.Length(version) <= maxSize {
return nil, false return nil, false
} }

View File

@ -12,7 +12,7 @@ type DataBlockedFrame struct {
MaximumData protocol.ByteCount MaximumData protocol.ByteCount
} }
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) { func parseDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*DataBlockedFrame, error) {
offset, err := quicvarint.Read(r) offset, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -20,12 +20,12 @@ func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBloc
return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil
} }
func (f *DataBlockedFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) { func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
b = append(b, dataBlockedFrameType) b = append(b, dataBlockedFrameType)
return quicvarint.Append(b, uint64(f.MaximumData)), nil return quicvarint.Append(b, uint64(f.MaximumData)), nil
} }
// Length of a written frame // Length of a written frame
func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { func (f *DataBlockedFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.MaximumData)) return 1 + quicvarint.Len(uint64(f.MaximumData))
} }

View File

@ -20,7 +20,7 @@ type DatagramFrame struct {
Data []byte Data []byte
} }
func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*DatagramFrame, error) { func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*DatagramFrame, error) {
f := &DatagramFrame{} f := &DatagramFrame{}
f.DataLenPresent = typ&0x1 > 0 f.DataLenPresent = typ&0x1 > 0
@ -45,7 +45,7 @@ func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (
return f, nil return f, nil
} }
func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *DatagramFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
typ := uint8(0x30) typ := uint8(0x30)
if f.DataLenPresent { if f.DataLenPresent {
typ ^= 0b1 typ ^= 0b1
@ -59,7 +59,7 @@ func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, erro
} }
// MaxDataLen returns the maximum data length // MaxDataLen returns the maximum data length
func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount {
headerLen := protocol.ByteCount(1) headerLen := protocol.ByteCount(1)
if f.DataLenPresent { if f.DataLenPresent {
// pretend that the data size will be 1 bytes // pretend that the data size will be 1 bytes
@ -77,7 +77,7 @@ func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.
} }
// Length of a written frame // Length of a written frame
func (f *DatagramFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { func (f *DatagramFrame) Length(_ protocol.Version) protocol.ByteCount {
length := 1 + protocol.ByteCount(len(f.Data)) length := 1 + protocol.ByteCount(len(f.Data))
if f.DataLenPresent { if f.DataLenPresent {
length += quicvarint.Len(uint64(len(f.Data))) length += quicvarint.Len(uint64(len(f.Data)))

View File

@ -32,7 +32,7 @@ type ExtendedHeader struct {
parsedLen protocol.ByteCount parsedLen protocol.ByteCount
} }
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) { func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.Version) (bool /* reserved bits valid */, error) {
startLen := b.Len() startLen := b.Len()
// read the (now unencrypted) first byte // read the (now unencrypted) first byte
var err error var err error
@ -51,7 +51,7 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool
return reservedBitsValid, err return reservedBitsValid, err
} }
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.Version) (bool /* reserved bits valid */, error) {
if err := h.readPacketNumber(b); err != nil { if err := h.readPacketNumber(b); err != nil {
return false, err return false, err
} }
@ -95,7 +95,7 @@ func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
} }
// Append appends the Header. // Append appends the Header.
func (h *ExtendedHeader) Append(b []byte, v protocol.VersionNumber) ([]byte, error) { func (h *ExtendedHeader) Append(b []byte, v protocol.Version) ([]byte, error) {
if h.DestConnectionID.Len() > protocol.MaxConnIDLen { if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
} }
@ -162,7 +162,7 @@ func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
} }
// GetLength determines the length of the Header. // GetLength determines the length of the Header.
func (h *ExtendedHeader) GetLength(_ protocol.VersionNumber) protocol.ByteCount { func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount {
length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
if h.Type == protocol.PacketTypeInitial { if h.Type == protocol.PacketTypeInitial {
length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))

View File

@ -36,7 +36,8 @@ const (
handshakeDoneFrameType = 0x1e handshakeDoneFrameType = 0x1e
) )
type frameParser struct { // The FrameParser parses QUIC frames, one by one.
type FrameParser struct {
r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them
ackDelayExponent uint8 ackDelayExponent uint8
@ -47,11 +48,9 @@ type frameParser struct {
ackFrame *AckFrame ackFrame *AckFrame
} }
var _ FrameParser = &frameParser{}
// NewFrameParser creates a new frame parser. // NewFrameParser creates a new frame parser.
func NewFrameParser(supportsDatagrams bool) *frameParser { func NewFrameParser(supportsDatagrams bool) *FrameParser {
return &frameParser{ return &FrameParser{
r: *bytes.NewReader(nil), r: *bytes.NewReader(nil),
supportsDatagrams: supportsDatagrams, supportsDatagrams: supportsDatagrams,
ackFrame: &AckFrame{}, ackFrame: &AckFrame{},
@ -60,7 +59,7 @@ func NewFrameParser(supportsDatagrams bool) *frameParser {
// ParseNext parses the next frame. // ParseNext parses the next frame.
// It skips PADDING frames. // It skips PADDING frames.
func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (int, Frame, error) { func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) {
startLen := len(data) startLen := len(data)
p.r.Reset(data) p.r.Reset(data)
frame, err := p.parseNext(&p.r, encLevel, v) frame, err := p.parseNext(&p.r, encLevel, v)
@ -69,7 +68,7 @@ func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel,
return n, frame, err return n, frame, err
} }
func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) { func (p *FrameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
for r.Len() != 0 { for r.Len() != 0 {
typ, err := quicvarint.Read(r) typ, err := quicvarint.Read(r)
if err != nil { if err != nil {
@ -95,7 +94,7 @@ func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLev
return nil, nil return nil, nil
} }
func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) { func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
var frame Frame var frame Frame
var err error var err error
if typ&0xf8 == 0x8 { if typ&0xf8 == 0x8 {
@ -163,7 +162,7 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
return frame, nil return frame, nil
} }
func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
switch encLevel { switch encLevel {
case protocol.EncryptionInitial, protocol.EncryptionHandshake: case protocol.EncryptionInitial, protocol.EncryptionHandshake:
switch f.(type) { switch f.(type) {
@ -186,6 +185,8 @@ func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionL
} }
} }
func (p *frameParser) SetAckDelayExponent(exp uint8) { // SetAckDelayExponent sets the acknowledgment delay exponent (sent in the transport parameters).
// This value is used to scale the ACK Delay field in the ACK frame.
func (p *FrameParser) SetAckDelayExponent(exp uint8) {
p.ackDelayExponent = exp p.ackDelayExponent = exp
} }

View File

@ -7,11 +7,11 @@ import (
// A HandshakeDoneFrame is a HANDSHAKE_DONE frame // A HandshakeDoneFrame is a HANDSHAKE_DONE frame
type HandshakeDoneFrame struct{} type HandshakeDoneFrame struct{}
func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
return append(b, handshakeDoneFrameType), nil return append(b, handshakeDoneFrameType), nil
} }
// Length of a written frame // Length of a written frame
func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { func (f *HandshakeDoneFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 return 1
} }

View File

@ -85,11 +85,11 @@ func IsLongHeaderPacket(firstByte byte) bool {
// ParseVersion parses the QUIC version. // ParseVersion parses the QUIC version.
// It should only be called for Long Header packets (Short Header packets don't contain a version number). // It should only be called for Long Header packets (Short Header packets don't contain a version number).
func ParseVersion(data []byte) (protocol.VersionNumber, error) { func ParseVersion(data []byte) (protocol.Version, error) {
if len(data) < 5 { if len(data) < 5 {
return 0, io.EOF return 0, io.EOF
} }
return protocol.VersionNumber(binary.BigEndian.Uint32(data[1:5])), nil return protocol.Version(binary.BigEndian.Uint32(data[1:5])), nil
} }
// IsVersionNegotiationPacket says if this is a version negotiation packet // IsVersionNegotiationPacket says if this is a version negotiation packet
@ -109,7 +109,7 @@ func Is0RTTPacket(b []byte) bool {
if !IsLongHeaderPacket(b[0]) { if !IsLongHeaderPacket(b[0]) {
return false return false
} }
version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5])) version := protocol.Version(binary.BigEndian.Uint32(b[1:5]))
//nolint:exhaustive // We only need to test QUIC versions that we support. //nolint:exhaustive // We only need to test QUIC versions that we support.
switch version { switch version {
case protocol.Version1: case protocol.Version1:
@ -128,7 +128,7 @@ type Header struct {
typeByte byte typeByte byte
Type protocol.PacketType Type protocol.PacketType
Version protocol.VersionNumber Version protocol.Version
SrcConnectionID protocol.ConnectionID SrcConnectionID protocol.ConnectionID
DestConnectionID protocol.ConnectionID DestConnectionID protocol.ConnectionID
@ -184,7 +184,7 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error {
if err != nil { if err != nil {
return err return err
} }
h.Version = protocol.VersionNumber(v) h.Version = protocol.Version(v)
if h.Version != 0 && h.typeByte&0x40 == 0 { if h.Version != 0 && h.typeByte&0x40 == 0 {
return errors.New("not a QUIC packet") return errors.New("not a QUIC packet")
} }
@ -278,7 +278,7 @@ func (h *Header) ParsedLen() protocol.ByteCount {
// ParseExtended parses the version dependent part of the header. // ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header. // The Reader has to be set such that it points to the first byte of the header.
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) { func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) {
extHdr := h.toExtendedHeader() extHdr := h.toExtendedHeader()
reservedBitsValid, err := extHdr.parse(b, ver) reservedBitsValid, err := extHdr.parse(b, ver)
if err != nil { if err != nil {

View File

@ -6,12 +6,6 @@ import (
// A Frame in QUIC // A Frame in QUIC
type Frame interface { type Frame interface {
Append(b []byte, version protocol.VersionNumber) ([]byte, error) Append(b []byte, version protocol.Version) ([]byte, error)
Length(version protocol.VersionNumber) protocol.ByteCount Length(version protocol.Version) protocol.ByteCount
}
// A FrameParser parses QUIC frames, one by one.
type FrameParser interface {
ParseNext([]byte, protocol.EncryptionLevel, protocol.VersionNumber) (int, Frame, error)
SetAckDelayExponent(uint8)
} }

View File

@ -63,7 +63,9 @@ func LogFrame(logger utils.Logger, frame Frame, sent bool) {
logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit) logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit)
} }
case *NewConnectionIDFrame: case *NewConnectionIDFrame:
logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.ConnectionID, f.StatelessResetToken) logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, RetirePriorTo: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.RetirePriorTo, f.ConnectionID, f.StatelessResetToken)
case *RetireConnectionIDFrame:
logger.Debugf("\t%s &wire.RetireConnectionIDFrame{SequenceNumber: %d}", dir, f.SequenceNumber)
case *NewTokenFrame: case *NewTokenFrame:
logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token) logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token)
default: default:

View File

@ -13,7 +13,7 @@ type MaxDataFrame struct {
} }
// parseMaxDataFrame parses a MAX_DATA frame // parseMaxDataFrame parses a MAX_DATA frame
func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame, error) { func parseMaxDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxDataFrame, error) {
frame := &MaxDataFrame{} frame := &MaxDataFrame{}
byteOffset, err := quicvarint.Read(r) byteOffset, err := quicvarint.Read(r)
if err != nil { if err != nil {
@ -23,13 +23,13 @@ func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame
return frame, nil return frame, nil
} }
func (f *MaxDataFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, maxDataFrameType) b = append(b, maxDataFrameType)
b = quicvarint.Append(b, uint64(f.MaximumData)) b = quicvarint.Append(b, uint64(f.MaximumData))
return b, nil return b, nil
} }
// Length of a written frame // Length of a written frame
func (f *MaxDataFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { func (f *MaxDataFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.MaximumData)) return 1 + quicvarint.Len(uint64(f.MaximumData))
} }

View File

@ -13,7 +13,7 @@ type MaxStreamDataFrame struct {
MaximumStreamData protocol.ByteCount MaximumStreamData protocol.ByteCount
} }
func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamDataFrame, error) { func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxStreamDataFrame, error) {
sid, err := quicvarint.Read(r) sid, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -29,7 +29,7 @@ func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStr
}, nil }, nil
} }
func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) { func (f *MaxStreamDataFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
b = append(b, maxStreamDataFrameType) b = append(b, maxStreamDataFrameType)
b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.MaximumStreamData)) b = quicvarint.Append(b, uint64(f.MaximumStreamData))
@ -37,6 +37,6 @@ func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([
} }
// Length of a written frame // Length of a written frame
func (f *MaxStreamDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { func (f *MaxStreamDataFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)) return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
} }

View File

@ -14,7 +14,7 @@ type MaxStreamsFrame struct {
MaxStreamNum protocol.StreamNum MaxStreamNum protocol.StreamNum
} }
func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*MaxStreamsFrame, error) { func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*MaxStreamsFrame, error) {
f := &MaxStreamsFrame{} f := &MaxStreamsFrame{}
switch typ { switch typ {
case bidiMaxStreamsFrameType: case bidiMaxStreamsFrameType:
@ -33,7 +33,7 @@ func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber)
return f, nil return f, nil
} }
func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
switch f.Type { switch f.Type {
case protocol.StreamTypeBidi: case protocol.StreamTypeBidi:
b = append(b, bidiMaxStreamsFrameType) b = append(b, bidiMaxStreamsFrameType)
@ -45,6 +45,6 @@ func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, er
} }
// Length of a written frame // Length of a written frame
func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount { func (f *MaxStreamsFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.MaxStreamNum)) return 1 + quicvarint.Len(uint64(f.MaxStreamNum))
} }

View File

@ -18,7 +18,7 @@ type NewConnectionIDFrame struct {
StatelessResetToken protocol.StatelessResetToken StatelessResetToken protocol.StatelessResetToken
} }
func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewConnectionIDFrame, error) { func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*NewConnectionIDFrame, error) {
seq, err := quicvarint.Read(r) seq, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -57,7 +57,7 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC
return frame, nil return frame, nil
} }
func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, newConnectionIDFrameType) b = append(b, newConnectionIDFrameType)
b = quicvarint.Append(b, f.SequenceNumber) b = quicvarint.Append(b, f.SequenceNumber)
b = quicvarint.Append(b, f.RetirePriorTo) b = quicvarint.Append(b, f.RetirePriorTo)
@ -72,6 +72,6 @@ func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byt
} }
// Length of a written frame // Length of a written frame
func (f *NewConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { func (f *NewConnectionIDFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16 return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16
} }

View File

@ -14,7 +14,7 @@ type NewTokenFrame struct {
Token []byte Token []byte
} }
func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFrame, error) { func parseNewTokenFrame(r *bytes.Reader, _ protocol.Version) (*NewTokenFrame, error) {
tokenLen, err := quicvarint.Read(r) tokenLen, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -32,7 +32,7 @@ func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFra
return &NewTokenFrame{Token: token}, nil return &NewTokenFrame{Token: token}, nil
} }
func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, newTokenFrameType) b = append(b, newTokenFrameType)
b = quicvarint.Append(b, uint64(len(f.Token))) b = quicvarint.Append(b, uint64(len(f.Token)))
b = append(b, f.Token...) b = append(b, f.Token...)
@ -40,6 +40,6 @@ func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, erro
} }
// Length of a written frame // Length of a written frame
func (f *NewTokenFrame) Length(protocol.VersionNumber) protocol.ByteCount { func (f *NewTokenFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token)) return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token))
} }

View File

@ -12,7 +12,7 @@ type PathChallengeFrame struct {
Data [8]byte Data [8]byte
} }
func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathChallengeFrame, error) { func parsePathChallengeFrame(r *bytes.Reader, _ protocol.Version) (*PathChallengeFrame, error) {
frame := &PathChallengeFrame{} frame := &PathChallengeFrame{}
if _, err := io.ReadFull(r, frame.Data[:]); err != nil { if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
if err == io.ErrUnexpectedEOF { if err == io.ErrUnexpectedEOF {
@ -23,13 +23,13 @@ func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathCh
return frame, nil return frame, nil
} }
func (f *PathChallengeFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, pathChallengeFrameType) b = append(b, pathChallengeFrameType)
b = append(b, f.Data[:]...) b = append(b, f.Data[:]...)
return b, nil return b, nil
} }
// Length of a written frame // Length of a written frame
func (f *PathChallengeFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { func (f *PathChallengeFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + 8 return 1 + 8
} }

View File

@ -12,7 +12,7 @@ type PathResponseFrame struct {
Data [8]byte Data [8]byte
} }
func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathResponseFrame, error) { func parsePathResponseFrame(r *bytes.Reader, _ protocol.Version) (*PathResponseFrame, error) {
frame := &PathResponseFrame{} frame := &PathResponseFrame{}
if _, err := io.ReadFull(r, frame.Data[:]); err != nil { if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
if err == io.ErrUnexpectedEOF { if err == io.ErrUnexpectedEOF {
@ -23,13 +23,13 @@ func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathRes
return frame, nil return frame, nil
} }
func (f *PathResponseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, pathResponseFrameType) b = append(b, pathResponseFrameType)
b = append(b, f.Data[:]...) b = append(b, f.Data[:]...)
return b, nil return b, nil
} }
// Length of a written frame // Length of a written frame
func (f *PathResponseFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { func (f *PathResponseFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + 8 return 1 + 8
} }

View File

@ -7,11 +7,11 @@ import (
// A PingFrame is a PING frame // A PingFrame is a PING frame
type PingFrame struct{} type PingFrame struct{}
func (f *PingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *PingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
return append(b, pingFrameType), nil return append(b, pingFrameType), nil
} }
// Length of a written frame // Length of a written frame
func (f *PingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { func (f *PingFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 return 1
} }

View File

@ -15,7 +15,7 @@ type ResetStreamFrame struct {
FinalSize protocol.ByteCount FinalSize protocol.ByteCount
} }
func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStreamFrame, error) { func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFrame, error) {
var streamID protocol.StreamID var streamID protocol.StreamID
var byteOffset protocol.ByteCount var byteOffset protocol.ByteCount
sid, err := quicvarint.Read(r) sid, err := quicvarint.Read(r)
@ -40,7 +40,7 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStr
}, nil }, nil
} }
func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, resetStreamFrameType) b = append(b, resetStreamFrameType)
b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.ErrorCode)) b = quicvarint.Append(b, uint64(f.ErrorCode))
@ -49,6 +49,6 @@ func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, e
} }
// Length of a written frame // Length of a written frame
func (f *ResetStreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { func (f *ResetStreamFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize)) return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize))
} }

View File

@ -12,7 +12,7 @@ type RetireConnectionIDFrame struct {
SequenceNumber uint64 SequenceNumber uint64
} }
func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*RetireConnectionIDFrame, error) { func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*RetireConnectionIDFrame, error) {
seq, err := quicvarint.Read(r) seq, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -20,13 +20,13 @@ func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*R
return &RetireConnectionIDFrame{SequenceNumber: seq}, nil return &RetireConnectionIDFrame{SequenceNumber: seq}, nil
} }
func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, retireConnectionIDFrameType) b = append(b, retireConnectionIDFrameType)
b = quicvarint.Append(b, f.SequenceNumber) b = quicvarint.Append(b, f.SequenceNumber)
return b, nil return b, nil
} }
// Length of a written frame // Length of a written frame
func (f *RetireConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { func (f *RetireConnectionIDFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(f.SequenceNumber) return 1 + quicvarint.Len(f.SequenceNumber)
} }

View File

@ -15,7 +15,7 @@ type StopSendingFrame struct {
} }
// parseStopSendingFrame parses a STOP_SENDING frame // parseStopSendingFrame parses a STOP_SENDING frame
func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) { func parseStopSendingFrame(r *bytes.Reader, _ protocol.Version) (*StopSendingFrame, error) {
streamID, err := quicvarint.Read(r) streamID, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -32,11 +32,11 @@ func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSend
} }
// Length of a written frame // Length of a written frame
func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { func (f *StopSendingFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode))
} }
func (f *StopSendingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *StopSendingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, stopSendingFrameType) b = append(b, stopSendingFrameType)
b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.ErrorCode)) b = quicvarint.Append(b, uint64(f.ErrorCode))

View File

@ -13,7 +13,7 @@ type StreamDataBlockedFrame struct {
MaximumStreamData protocol.ByteCount MaximumStreamData protocol.ByteCount
} }
func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamDataBlockedFrame, error) { func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*StreamDataBlockedFrame, error) {
sid, err := quicvarint.Read(r) sid, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -29,7 +29,7 @@ func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*St
}, nil }, nil
} }
func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, 0x15) b = append(b, 0x15)
b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.MaximumStreamData)) b = quicvarint.Append(b, uint64(f.MaximumStreamData))
@ -37,6 +37,6 @@ func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]b
} }
// Length of a written frame // Length of a written frame
func (f *StreamDataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { func (f *StreamDataBlockedFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)) return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
} }

View File

@ -20,7 +20,7 @@ type StreamFrame struct {
fromPool bool fromPool bool
} }
func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamFrame, error) { func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamFrame, error) {
hasOffset := typ&0b100 > 0 hasOffset := typ&0b100 > 0
fin := typ&0b1 > 0 fin := typ&0b1 > 0
hasDataLen := typ&0b10 > 0 hasDataLen := typ&0b10 > 0
@ -79,7 +79,7 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*S
} }
// Write writes a STREAM frame // Write writes a STREAM frame
func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
if len(f.Data) == 0 && !f.Fin { if len(f.Data) == 0 && !f.Fin {
return nil, errors.New("StreamFrame: attempting to write empty frame without FIN") return nil, errors.New("StreamFrame: attempting to write empty frame without FIN")
} }
@ -108,7 +108,7 @@ func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error)
} }
// Length returns the total length of the STREAM frame // Length returns the total length of the STREAM frame
func (f *StreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { func (f *StreamFrame) Length(version protocol.Version) protocol.ByteCount {
length := 1 + quicvarint.Len(uint64(f.StreamID)) length := 1 + quicvarint.Len(uint64(f.StreamID))
if f.Offset != 0 { if f.Offset != 0 {
length += quicvarint.Len(uint64(f.Offset)) length += quicvarint.Len(uint64(f.Offset))
@ -126,7 +126,7 @@ func (f *StreamFrame) DataLen() protocol.ByteCount {
// MaxDataLen returns the maximum data length // MaxDataLen returns the maximum data length
// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data). // If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data).
func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount {
headerLen := 1 + quicvarint.Len(uint64(f.StreamID)) headerLen := 1 + quicvarint.Len(uint64(f.StreamID))
if f.Offset != 0 { if f.Offset != 0 {
headerLen += quicvarint.Len(uint64(f.Offset)) headerLen += quicvarint.Len(uint64(f.Offset))
@ -151,7 +151,7 @@ func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Ve
// The frame might not be split if: // The frame might not be split if:
// * the size is large enough to fit the whole frame // * the size is large enough to fit the whole frame
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. // * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*StreamFrame, bool /* was splitting required */) { func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*StreamFrame, bool /* was splitting required */) {
if maxSize >= f.Length(version) { if maxSize >= f.Length(version) {
return nil, false return nil, false
} }

View File

@ -14,7 +14,7 @@ type StreamsBlockedFrame struct {
StreamLimit protocol.StreamNum StreamLimit protocol.StreamNum
} }
func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) { func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, error) {
f := &StreamsBlockedFrame{} f := &StreamsBlockedFrame{}
switch typ { switch typ {
case bidiStreamBlockedFrameType: case bidiStreamBlockedFrameType:
@ -33,7 +33,7 @@ func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNum
return f, nil return f, nil
} }
func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
switch f.Type { switch f.Type {
case protocol.StreamTypeBidi: case protocol.StreamTypeBidi:
b = append(b, bidiStreamBlockedFrameType) b = append(b, bidiStreamBlockedFrameType)
@ -45,6 +45,6 @@ func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte
} }
// Length of a written frame // Length of a written frame
func (f *StreamsBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { func (f *StreamsBlockedFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamLimit)) return 1 + quicvarint.Len(uint64(f.StreamLimit))
} }

View File

@ -7,7 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net" "net/netip"
"sort" "sort"
"time" "time"
@ -51,10 +51,7 @@ const (
// PreferredAddress is the value encoding in the preferred_address transport parameter // PreferredAddress is the value encoding in the preferred_address transport parameter
type PreferredAddress struct { type PreferredAddress struct {
IPv4 net.IP IPv4, IPv6 netip.AddrPort
IPv4Port uint16
IPv6 net.IP
IPv6Port uint16
ConnectionID protocol.ConnectionID ConnectionID protocol.ConnectionID
StatelessResetToken protocol.StatelessResetToken StatelessResetToken protocol.StatelessResetToken
} }
@ -218,26 +215,24 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error { func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error {
remainingLen := r.Len() remainingLen := r.Len()
pa := &PreferredAddress{} pa := &PreferredAddress{}
ipv4 := make([]byte, 4) var ipv4 [4]byte
if _, err := io.ReadFull(r, ipv4); err != nil { if _, err := io.ReadFull(r, ipv4[:]); err != nil {
return err return err
} }
pa.IPv4 = net.IP(ipv4)
port, err := utils.BigEndian.ReadUint16(r) port, err := utils.BigEndian.ReadUint16(r)
if err != nil { if err != nil {
return err return err
} }
pa.IPv4Port = port pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port)
ipv6 := make([]byte, 16) var ipv6 [16]byte
if _, err := io.ReadFull(r, ipv6); err != nil { if _, err := io.ReadFull(r, ipv6[:]); err != nil {
return err return err
} }
pa.IPv6 = net.IP(ipv6)
port, err = utils.BigEndian.ReadUint16(r) port, err = utils.BigEndian.ReadUint16(r)
if err != nil { if err != nil {
return err return err
} }
pa.IPv6Port = port pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port)
connIDLen, err := r.ReadByte() connIDLen, err := r.ReadByte()
if err != nil { if err != nil {
return err return err
@ -384,13 +379,12 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
if p.PreferredAddress != nil { if p.PreferredAddress != nil {
b = quicvarint.Append(b, uint64(preferredAddressParameterID)) b = quicvarint.Append(b, uint64(preferredAddressParameterID))
b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16)
ipv4 := p.PreferredAddress.IPv4 ip4 := p.PreferredAddress.IPv4.Addr().As4()
b = append(b, ipv4[len(ipv4)-4:]...) b = append(b, ip4[:]...)
b = append(b, []byte{0, 0}...) b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv4.Port())
binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv4Port) ip6 := p.PreferredAddress.IPv6.Addr().As16()
b = append(b, p.PreferredAddress.IPv6...) b = append(b, ip6[:]...)
b = append(b, []byte{0, 0}...) b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv6.Port())
binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv6Port)
b = append(b, uint8(p.PreferredAddress.ConnectionID.Len())) b = append(b, uint8(p.PreferredAddress.ConnectionID.Len()))
b = append(b, p.PreferredAddress.ConnectionID.Bytes()...) b = append(b, p.PreferredAddress.ConnectionID.Bytes()...)
b = append(b, p.PreferredAddress.StatelessResetToken[:]...) b = append(b, p.PreferredAddress.StatelessResetToken[:]...)

View File

@ -1,17 +1,15 @@
package wire package wire
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
) )
// ParseVersionNegotiationPacket parses a Version Negotiation packet. // ParseVersionNegotiationPacket parses a Version Negotiation packet.
func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber, _ error) { func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.Version, _ error) {
n, dest, src, err := ParseArbitraryLenConnectionIDs(b) n, dest, src, err := ParseArbitraryLenConnectionIDs(b)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
@ -25,32 +23,31 @@ func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenCon
//nolint:stylecheck //nolint:stylecheck
return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length")
} }
versions := make([]protocol.VersionNumber, len(b)/4) versions := make([]protocol.Version, len(b)/4)
for i := 0; len(b) > 0; i++ { for i := 0; len(b) > 0; i++ {
versions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(b[:4])) versions[i] = protocol.Version(binary.BigEndian.Uint32(b[:4]))
b = b[4:] b = b[4:]
} }
return dest, src, versions, nil return dest, src, versions, nil
} }
// ComposeVersionNegotiation composes a Version Negotiation // ComposeVersionNegotiation composes a Version Negotiation
func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.VersionNumber) []byte { func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.Version) []byte {
greasedVersions := protocol.GetGreasedVersions(versions) greasedVersions := protocol.GetGreasedVersions(versions)
expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4
buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) buf := make([]byte, 1+4 /* type byte and version field */, expectedLen)
r := make([]byte, 1) _, _ = rand.Read(buf[:1]) // ignore the error here. It is not critical to have perfect random here.
_, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here.
// Setting the "QUIC bit" (0x40) is not required by the RFC, // Setting the "QUIC bit" (0x40) is not required by the RFC,
// but it allows clients to demultiplex QUIC with a long list of other protocols. // but it allows clients to demultiplex QUIC with a long list of other protocols.
// See RFC 9443 and https://mailarchive.ietf.org/arch/msg/quic/oR4kxGKY6mjtPC1CZegY1ED4beg/ for details. // See RFC 9443 and https://mailarchive.ietf.org/arch/msg/quic/oR4kxGKY6mjtPC1CZegY1ED4beg/ for details.
buf.WriteByte(r[0] | 0xc0) buf[0] |= 0xc0
utils.BigEndian.WriteUint32(buf, 0) // version 0 // The next 4 bytes are left at 0 (version number).
buf.WriteByte(uint8(destConnID.Len())) buf = append(buf, uint8(destConnID.Len()))
buf.Write(destConnID.Bytes()) buf = append(buf, destConnID.Bytes()...)
buf.WriteByte(uint8(srcConnID.Len())) buf = append(buf, uint8(srcConnID.Len()))
buf.Write(srcConnID.Bytes()) buf = append(buf, srcConnID.Bytes()...)
for _, v := range greasedVersions { for _, v := range greasedVersions {
utils.BigEndian.WriteUint32(buf, uint32(v)) buf = binary.BigEndian.AppendUint32(buf, uint32(v))
} }
return buf.Bytes() return buf
} }

View File

@ -27,9 +27,9 @@ type ConnectionTracer struct {
UpdatedCongestionState func(CongestionState) UpdatedCongestionState func(CongestionState)
UpdatedPTOCount func(value uint32) UpdatedPTOCount func(value uint32)
UpdatedKeyFromTLS func(EncryptionLevel, Perspective) UpdatedKeyFromTLS func(EncryptionLevel, Perspective)
UpdatedKey func(generation KeyPhase, remote bool) UpdatedKey func(keyPhase KeyPhase, remote bool)
DroppedEncryptionLevel func(EncryptionLevel) DroppedEncryptionLevel func(EncryptionLevel)
DroppedKey func(generation KeyPhase) DroppedKey func(keyPhase KeyPhase)
SetLossTimer func(TimerType, EncryptionLevel, time.Time) SetLossTimer func(TimerType, EncryptionLevel, time.Time)
LossTimerExpired func(TimerType, EncryptionLevel) LossTimerExpired func(TimerType, EncryptionLevel)
LossTimerCanceled func() LossTimerCanceled func()

View File

@ -37,7 +37,7 @@ type (
// The StreamType is the type of the stream (unidirectional or bidirectional). // The StreamType is the type of the stream (unidirectional or bidirectional).
StreamType = protocol.StreamType StreamType = protocol.StreamType
// The VersionNumber is the QUIC version. // The VersionNumber is the QUIC version.
VersionNumber = protocol.VersionNumber VersionNumber = protocol.Version
// The Header is the QUIC packet header, before removing header protection. // The Header is the QUIC packet header, before removing header protection.
Header = wire.Header Header = wire.Header

View File

@ -7,6 +7,8 @@ type Tracer struct {
SentPacket func(net.Addr, *Header, ByteCount, []Frame) SentPacket func(net.Addr, *Header, ByteCount, []Frame)
SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason) DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason)
Debug func(name, msg string)
Close func()
} }
// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. // NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers.
@ -39,5 +41,19 @@ func NewMultiplexedTracer(tracers ...*Tracer) *Tracer {
} }
} }
}, },
Debug: func(name, msg string) {
for _, t := range tracers {
if t.Debug != nil {
t.Debug(name, msg)
}
}
},
Close: func() {
for _, t := range tracers {
if t.Close != nil {
t.Close()
}
}
},
} }
} }

View File

@ -3,12 +3,12 @@
# Install Go manually, since oss-fuzz ships with an outdated Go version. # Install Go manually, since oss-fuzz ships with an outdated Go version.
# See https://github.com/google/oss-fuzz/pull/10643. # See https://github.com/google/oss-fuzz/pull/10643.
export CXX="${CXX} -lresolv" # required by Go 1.20 export CXX="${CXX} -lresolv" # required by Go 1.20
wget https://go.dev/dl/go1.21.5.linux-amd64.tar.gz \ wget https://go.dev/dl/go1.22.0.linux-amd64.tar.gz \
&& mkdir temp-go \ && mkdir temp-go \
&& rm -rf /root/.go/* \ && rm -rf /root/.go/* \
&& tar -C temp-go/ -xzf go1.21.5.linux-amd64.tar.gz \ && tar -C temp-go/ -xzf go1.22.0.linux-amd64.tar.gz \
&& mv temp-go/go/* /root/.go/ \ && mv temp-go/go/* /root/.go/ \
&& rm -rf temp-go go1.21.5.linux-amd64.tar.gz && rm -rf temp-go go1.22.0.linux-amd64.tar.gz
( (
# fuzz qpack # fuzz qpack

View File

@ -129,7 +129,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
return true return true
} }
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool { func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
@ -137,12 +137,8 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false return false
} }
conn, ok := fn() h.handlers[clientDestConnID] = handler
if !ok { h.handlers[newConnID] = handler
return false
}
h.handlers[clientDestConnID] = conn
h.handlers[newConnID] = conn
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
return true return true
} }
@ -168,18 +164,17 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
// Depending on which side closed the connection, we need to: // Depending on which side closed the connection, we need to:
// * remote close: absorb delayed packets // * remote close: absorb delayed packets
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost // * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers protocol.Perspective, connClosePacket []byte) { func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte) {
var handler packetHandler var handler packetHandler
if connClosePacket != nil { if connClosePacket != nil {
handler = newClosedLocalConn( handler = newClosedLocalConn(
func(addr net.Addr, info packetInfo) { func(addr net.Addr, info packetInfo) {
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info}) h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
}, },
pers,
h.logger, h.logger,
) )
} else { } else {
handler = newClosedRemoteConn(pers) handler = newClosedRemoteConn()
} }
h.mutex.Lock() h.mutex.Lock()
@ -191,7 +186,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
time.AfterFunc(h.deleteRetiredConnsAfter, func() { time.AfterFunc(h.deleteRetiredConnsAfter, func() {
h.mutex.Lock() h.mutex.Lock()
handler.shutdown()
for _, id := range ids { for _, id := range ids {
delete(h.handlers, id) delete(h.handlers, id)
} }

View File

@ -18,13 +18,13 @@ import (
var errNothingToPack = errors.New("nothing to pack") var errNothingToPack = errors.New("nothing to pack")
type packer interface { type packer interface {
PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error)
PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error)
MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
SetToken([]byte) SetToken([]byte)
} }
@ -106,8 +106,8 @@ type sealingManager interface {
type frameSource interface { type frameSource interface {
HasData() bool HasData() bool
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)
} }
type ackFrameSource interface { type ackFrameSource interface {
@ -170,7 +170,7 @@ func newPacketPacker(
} }
// PackConnectionClose packs a packet that closes the connection with a transport error. // PackConnectionClose packs a packet that closes the connection with a transport error.
func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
var reason string var reason string
// don't send details of crypto errors // don't send details of crypto errors
if !e.ErrorCode.IsCryptoError() { if !e.ErrorCode.IsCryptoError() {
@ -180,7 +180,7 @@ func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize
} }
// PackApplicationClose packs a packet that closes the connection with an application error. // PackApplicationClose packs a packet that closes the connection with an application error.
func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v) return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v)
} }
@ -190,7 +190,7 @@ func (p *packetPacker) packConnectionClose(
frameType uint64, frameType uint64,
reason string, reason string,
maxPacketSize protocol.ByteCount, maxPacketSize protocol.ByteCount,
v protocol.VersionNumber, v protocol.Version,
) (*coalescedPacket, error) { ) (*coalescedPacket, error) {
var sealers [4]sealer var sealers [4]sealer
var hdrs [3]*wire.ExtendedHeader var hdrs [3]*wire.ExtendedHeader
@ -293,7 +293,7 @@ func (p *packetPacker) packConnectionClose(
// longHeaderPacketLength calculates the length of a serialized long header packet. // longHeaderPacketLength calculates the length of a serialized long header packet.
// It takes into account that packets that have a tiny payload need to be padded, // It takes into account that packets that have a tiny payload need to be padded,
// such that len(payload) + packet number len >= 4 + AEAD overhead // such that len(payload) + packet number len >= 4 + AEAD overhead
func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.VersionNumber) protocol.ByteCount { func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.Version) protocol.ByteCount {
var paddingLen protocol.ByteCount var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(hdr.PacketNumberLen) pnLen := protocol.ByteCount(hdr.PacketNumberLen)
if pl.length < 4-pnLen { if pl.length < 4-pnLen {
@ -328,7 +328,7 @@ func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize,
// PackCoalescedPacket packs a new packet. // PackCoalescedPacket packs a new packet.
// It packs an Initial / Handshake if there is data to send in these packet number spaces. // It packs an Initial / Handshake if there is data to send in these packet number spaces.
// It should only be called before the handshake is confirmed. // It should only be called before the handshake is confirmed.
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
var ( var (
initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader
initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
@ -442,7 +442,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
// PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space. // PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space.
// It should be called after the handshake is confirmed. // It should be called after the handshake is confirmed.
func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
buf := getPacketBuffer() buf := getPacketBuffer()
packet, err := p.appendPacket(buf, true, maxPacketSize, v) packet, err := p.appendPacket(buf, true, maxPacketSize, v)
return packet, buf, err return packet, buf, err
@ -450,11 +450,11 @@ func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v pro
// AppendPacket packs a packet in the application data packet number space. // AppendPacket packs a packet in the application data packet number space.
// It should be called after the handshake is confirmed. // It should be called after the handshake is confirmed.
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) {
return p.appendPacket(buf, false, maxPacketSize, v) return p.appendPacket(buf, false, maxPacketSize, v)
} }
func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) {
sealer, err := p.cryptoSetup.Get1RTTSealer() sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil { if err != nil {
return shortHeaderPacket{}, err return shortHeaderPacket{}, err
@ -471,7 +471,7 @@ func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSi
return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v) return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v)
} }
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) { func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.Version) (*wire.ExtendedHeader, payload) {
if onlyAck { if onlyAck {
if ack := p.acks.GetAckFrame(encLevel, true); ack != nil { if ack := p.acks.GetAckFrame(encLevel, true); ack != nil {
return p.getLongHeader(encLevel, v), payload{ return p.getLongHeader(encLevel, v), payload{
@ -543,7 +543,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
return hdr, pl return hdr, pl
} }
func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) { func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.Version) (*wire.ExtendedHeader, payload) {
if p.perspective != protocol.PerspectiveClient { if p.perspective != protocol.PerspectiveClient {
return nil, payload{} return nil, payload{}
} }
@ -553,12 +553,12 @@ func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize
return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, v) return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, v)
} }
func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead())
return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, v) return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, v)
} }
func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v) pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v)
// check if we have anything to send // check if we have anything to send
@ -581,7 +581,7 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount,
return pl return pl
} }
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
if onlyAck { if onlyAck {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil {
return payload{ack: ack, length: ack.Length(v)} return payload{ack: ack, length: ack.Length(v)}
@ -589,12 +589,11 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
return payload{} return payload{}
} }
pl := payload{streamFrames: make([]ackhandler.StreamFrame, 0, 1)}
hasData := p.framer.HasData() hasData := p.framer.HasData()
hasRetransmission := p.retransmissionQueue.HasAppData() hasRetransmission := p.retransmissionQueue.HasAppData()
var hasAck bool var hasAck bool
var pl payload
if ackAllowed { if ackAllowed {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil {
pl.ack = ack pl.ack = ack
@ -661,7 +660,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
return pl return pl
} }
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
if encLevel == protocol.Encryption1RTT { if encLevel == protocol.Encryption1RTT {
s, err := p.cryptoSetup.Get1RTTSealer() s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil { if err != nil {
@ -727,7 +726,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m
return packet, nil return packet, nil
} }
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
pl := payload{ pl := payload{
frames: []ackhandler.Frame{ping}, frames: []ackhandler.Frame{ping},
length: ping.Frame.Length(v), length: ping.Frame.Length(v),
@ -745,7 +744,7 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
return packet, buffer, err return packet, buffer, err
} }
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader { func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.Version) *wire.ExtendedHeader {
pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
hdr := &wire.ExtendedHeader{ hdr := &wire.ExtendedHeader{
PacketNumber: pn, PacketNumber: pn,
@ -768,7 +767,7 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protoc
return hdr return hdr
} }
func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) { func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.Version) (*longHeaderPacket, error) {
var paddingLen protocol.ByteCount var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(header.PacketNumberLen) pnLen := protocol.ByteCount(header.PacketNumberLen)
if pl.length < 4-pnLen { if pl.length < 4-pnLen {
@ -814,7 +813,7 @@ func (p *packetPacker) appendShortHeaderPacket(
padding, maxPacketSize protocol.ByteCount, padding, maxPacketSize protocol.ByteCount,
sealer sealer, sealer sealer,
isMTUProbePacket bool, isMTUProbePacket bool,
v protocol.VersionNumber, v protocol.Version,
) (shortHeaderPacket, error) { ) (shortHeaderPacket, error) {
var paddingLen protocol.ByteCount var paddingLen protocol.ByteCount
if pl.length < 4-protocol.ByteCount(pnLen) { if pl.length < 4-protocol.ByteCount(pnLen) {
@ -860,7 +859,7 @@ func (p *packetPacker) appendShortHeaderPacket(
// appendPacketPayload serializes the payload of a packet into the raw byte slice. // appendPacketPayload serializes the payload of a packet into the raw byte slice.
// It modifies the order of payload.frames. // It modifies the order of payload.frames.
func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.Version) ([]byte, error) {
payloadOffset := len(raw) payloadOffset := len(raw)
if pl.ack != nil { if pl.ack != nil {
var err error var err error

View File

@ -53,7 +53,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetU
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
// If any other error occurred when parsing the header, the error is of type headerParseError. // If any other error occurred when parsing the header, the error is of type headerParseError.
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD. // If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) { func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) {
var encLevel protocol.EncryptionLevel var encLevel protocol.EncryptionLevel
var extHdr *wire.ExtendedHeader var extHdr *wire.ExtendedHeader
var decrypted []byte var decrypted []byte
@ -125,7 +125,7 @@ func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (prot
return pn, pnLen, kp, decrypted, nil return pn, pnLen, kp, decrypted, nil
} }
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, []byte, error) { func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v) extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
// If the reserved bits are set incorrectly, we still need to continue unpacking. // If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker // This avoids a timing side-channel, which otherwise might allow an attacker
@ -187,7 +187,7 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int
} }
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) { func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
extHdr, err := unpackLongHeader(hd, hdr, data, v) extHdr, err := unpackLongHeader(hd, hdr, data, v)
if err != nil && err != wire.ErrInvalidReservedBits { if err != nil && err != wire.ErrInvalidReservedBits {
return nil, &headerParseError{err: err} return nil, &headerParseError{err: err}
@ -195,7 +195,7 @@ func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header,
return extHdr, err return extHdr, err
} }
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) { func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
r := bytes.NewReader(data) r := bytes.NewReader(data)
hdrLen := hdr.ParsedLen() hdrLen := hdr.ParsedLen()

View File

@ -292,10 +292,6 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame)
return newlyRcvdFinalOffset, nil return newlyRcvdFinalOffset, nil
} }
func (s *receiveStream) CloseRemote(offset protocol.ByteCount) {
s.handleStreamFrame(&wire.StreamFrame{Fin: true, Offset: offset})
}
func (s *receiveStream) SetReadDeadline(t time.Time) error { func (s *receiveStream) SetReadDeadline(t time.Time) error {
s.mutex.Lock() s.mutex.Lock()
s.deadline = t s.deadline = t

View File

@ -74,7 +74,7 @@ func (q *retransmissionQueue) addAppData(f wire.Frame) {
q.appData = append(q.appData, f) q.appData = append(q.appData, f)
} }
func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame { func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
if len(q.initialCryptoData) > 0 { if len(q.initialCryptoData) > 0 {
f := q.initialCryptoData[0] f := q.initialCryptoData[0]
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v) newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
@ -97,7 +97,7 @@ func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v proto
return f return f
} }
func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame { func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
if len(q.handshakeCryptoData) > 0 { if len(q.handshakeCryptoData) > 0 {
f := q.handshakeCryptoData[0] f := q.handshakeCryptoData[0]
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v) newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
@ -120,7 +120,7 @@ func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v pro
return f return f
} }
func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame { func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
if len(q.appData) == 0 { if len(q.appData) == 0 {
return nil return nil
} }

View File

@ -18,7 +18,7 @@ type sendStreamI interface {
SendStream SendStream
handleStopSendingFrame(*wire.StopSendingFrame) handleStopSendingFrame(*wire.StopSendingFrame)
hasData() bool hasData() bool
popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (frame ackhandler.StreamFrame, ok, hasMore bool) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (frame ackhandler.StreamFrame, ok, hasMore bool)
closeForShutdown(error) closeForShutdown(error)
updateSendWindow(protocol.ByteCount) updateSendWindow(protocol.ByteCount)
} }
@ -198,7 +198,7 @@ func (s *sendStream) canBufferStreamFrame() bool {
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
// maxBytes is the maximum length this frame (including frame header) will have. // maxBytes is the maximum length this frame (including frame header) will have.
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (af ackhandler.StreamFrame, ok, hasMore bool) { func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) {
s.mutex.Lock() s.mutex.Lock()
f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v) f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
if f != nil { if f != nil {
@ -215,7 +215,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers
}, true, hasMoreData }, true, hasMoreData
} }
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) { func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more data to send */) {
if s.cancelWriteErr != nil || s.closeForShutdownErr != nil { if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
return nil, false return nil, false
} }
@ -269,7 +269,7 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun
return f, hasMoreData return f, hasMoreData
} }
func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool) { func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) {
if s.nextFrame != nil { if s.nextFrame != nil {
nextFrame := s.nextFrame nextFrame := s.nextFrame
s.nextFrame = nil s.nextFrame = nil
@ -304,7 +304,7 @@ func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount,
return f, hasMoreData return f, hasMoreData
} }
func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) bool { func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool {
maxDataLen := f.MaxDataLen(maxBytes, v) maxDataLen := f.MaxDataLen(maxBytes, v)
if maxDataLen == 0 { // a STREAM frame must have at least one byte of data if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
@ -314,7 +314,7 @@ func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxByte
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
} }
func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more retransmissions */) { func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) {
f := s.retransmissionQueue[0] f := s.retransmissionQueue[0]
newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v) newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v)
if needsSplit { if needsSplit {
@ -404,11 +404,13 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool
} }
func (s *sendStream) updateSendWindow(limit protocol.ByteCount) { func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
updated := s.flowController.UpdateSendWindow(limit)
if !updated { // duplicate or reordered MAX_STREAM_DATA frame
return
}
s.mutex.Lock() s.mutex.Lock()
hasStreamData := s.dataForWriting != nil || s.nextFrame != nil hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
s.mutex.Unlock() s.mutex.Unlock()
s.flowController.UpdateSendWindow(limit)
if hasStreamData { if hasStreamData {
s.sender.onHasStreamData(s.streamID) s.sender.onHasStreamData(s.streamID)
} }

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/handshake"
@ -24,15 +23,14 @@ var ErrServerClosed = errors.New("quic: server closed")
// packetHandler handles packets // packetHandler handles packets
type packetHandler interface { type packetHandler interface {
handlePacket(receivedPacket) handlePacket(receivedPacket)
shutdown()
destroy(error) destroy(error)
getPerspective() protocol.Perspective closeWithTransportError(qerr.TransportErrorCode)
} }
type packetHandlerManager interface { type packetHandlerManager interface {
Get(protocol.ConnectionID) (packetHandler, bool) Get(protocol.ConnectionID) (packetHandler, bool)
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool) GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool
Close(error) Close(error)
connRunner connRunner
} }
@ -41,11 +39,9 @@ type quicConn interface {
EarlyConnection EarlyConnection
earlyConnReady() <-chan struct{} earlyConnReady() <-chan struct{}
handlePacket(receivedPacket) handlePacket(receivedPacket)
GetVersion() protocol.VersionNumber
getPerspective() protocol.Perspective
run() error run() error
destroy(error) destroy(error)
shutdown() closeWithTransportError(TransportErrorCode)
} }
type zeroRTTQueue struct { type zeroRTTQueue struct {
@ -98,10 +94,10 @@ type baseServer struct {
*logging.ConnectionTracer, *logging.ConnectionTracer,
uint64, uint64,
utils.Logger, utils.Logger,
protocol.VersionNumber, protocol.Version,
) quicConn ) quicConn
closeOnce sync.Once closeMx sync.Mutex
errorChan chan struct{} // is closed when the server is closed errorChan chan struct{} // is closed when the server is closed
closeErr error closeErr error
running chan struct{} // closed as soon as run() returns running chan struct{} // closed as soon as run() returns
@ -111,8 +107,9 @@ type baseServer struct {
connectionRefusedQueue chan rejectedPacket connectionRefusedQueue chan rejectedPacket
retryQueue chan rejectedPacket retryQueue chan rejectedPacket
connQueue chan quicConn verifySourceAddress func(net.Addr) bool
connQueueLen int32 // to be used as an atomic
connQueue chan quicConn
tracer *logging.Tracer tracer *logging.Tracer
@ -240,6 +237,7 @@ func newServer(
onClose func(), onClose func(),
tokenGeneratorKey TokenGeneratorKey, tokenGeneratorKey TokenGeneratorKey,
maxTokenAge time.Duration, maxTokenAge time.Duration,
verifySourceAddress func(net.Addr) bool,
disableVersionNegotiation bool, disableVersionNegotiation bool,
acceptEarly bool, acceptEarly bool,
) *baseServer { ) *baseServer {
@ -249,9 +247,10 @@ func newServer(
config: config, config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey), tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
maxTokenAge: maxTokenAge, maxTokenAge: maxTokenAge,
verifySourceAddress: verifySourceAddress,
connIDGenerator: connIDGenerator, connIDGenerator: connIDGenerator,
connHandler: connHandler, connHandler: connHandler,
connQueue: make(chan quicConn), connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
running: make(chan struct{}), running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
@ -322,7 +321,6 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
case conn := <-s.connQueue: case conn := <-s.connQueue:
atomic.AddInt32(&s.connQueueLen, -1)
return conn, nil return conn, nil
case <-s.errorChan: case <-s.errorChan:
return nil, s.closeErr return nil, s.closeErr
@ -335,15 +333,19 @@ func (s *baseServer) Close() error {
} }
func (s *baseServer) close(e error, notifyOnClose bool) { func (s *baseServer) close(e error, notifyOnClose bool) {
s.closeOnce.Do(func() { s.closeMx.Lock()
s.closeErr = e if s.closeErr != nil {
close(s.errorChan) s.closeMx.Unlock()
return
}
s.closeErr = e
close(s.errorChan)
<-s.running
s.closeMx.Unlock()
<-s.running if notifyOnClose {
if notifyOnClose { s.onClose()
s.onClose() }
}
})
} }
// Addr returns the server's network address // Addr returns the server's network address
@ -542,10 +544,10 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error { func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
p.buffer.Release()
if s.tracer != nil && s.tracer.DroppedPacket != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
} }
p.buffer.Release()
return errors.New("too short connection ID") return errors.New("too short connection ID")
} }
@ -558,8 +560,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
} }
var ( var (
token *handshake.Token token *handshake.Token
retrySrcConnID *protocol.ConnectionID retrySrcConnID *protocol.ConnectionID
clientAddrVerified bool
) )
origDestConnID := hdr.DestConnectionID origDestConnID := hdr.DestConnectionID
if len(hdr.Token) > 0 { if len(hdr.Token) > 0 {
@ -572,28 +575,30 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
token = tok token = tok
} }
} }
if token != nil {
clientAddrIsValid := s.validateToken(token, p.remoteAddr) clientAddrVerified = s.validateToken(token, p.remoteAddr)
if token != nil && !clientAddrIsValid { if !clientAddrVerified {
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
// We just ignore them, and act as if there was no token on this packet at all. // We just ignore them, and act as if there was no token on this packet at all.
// This also means we might send a Retry later. // This also means we might send a Retry later.
if !token.IsRetryToken { if !token.IsRetryToken {
token = nil token = nil
} else { } else {
// For Retry tokens, we send an INVALID_ERROR if // For Retry tokens, we send an INVALID_ERROR if
// * the token is too old, or // * the token is too old, or
// * the token is invalid, in case of a retry token. // * the token is invalid, in case of a retry token.
select { select {
case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default: default:
// drop packet if we can't send out the INVALID_TOKEN packets fast enough // drop packet if we can't send out the INVALID_TOKEN packets fast enough
p.buffer.Release() p.buffer.Release()
}
return nil
} }
return nil
} }
} }
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress(p.remoteAddr) {
// Retry invalidates all 0-RTT packets sent. // Retry invalidates all 0-RTT packets sent.
delete(s.zeroRTTQueues, hdr.DestConnectionID) delete(s.zeroRTTQueues, hdr.DestConnectionID)
select { select {
@ -605,121 +610,116 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
return nil return nil
} }
if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize { config := s.config
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) if s.config.GetConfigForClient != nil {
select { conf, err := s.config.GetConfigForClient(&ClientHelloInfo{
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: RemoteAddr: p.remoteAddr,
default: AddrVerified: clientAddrVerified,
// drop packet if we can't send out the CONNECTION_REFUSED fast enough })
p.buffer.Release() if err != nil {
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
return nil
} }
return nil config = populateConfig(conf)
} }
var conn quicConn
tracingID := nextConnTracingID()
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
connID := hdr.DestConnectionID
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
}
connID, err := s.connIDGenerator.GenerateConnectionID() connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil { if err != nil {
return err return err
} }
s.logger.Debugf("Changing connection ID to %s.", connID) s.logger.Debugf("Changing connection ID to %s.", connID)
var conn quicConn conn = s.newConn(
tracingID := nextConnTracingID() newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) { s.connHandler,
config := s.config origDestConnID,
if s.config.GetConfigForClient != nil { retrySrcConnID,
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr}) hdr.DestConnectionID,
if err != nil { hdr.SrcConnectionID,
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback") connID,
return nil, false s.connIDGenerator,
} s.connHandler.GetStatelessResetToken(connID),
config = populateConfig(conf) config,
} s.tlsConf,
var tracer *logging.ConnectionTracer s.tokenGenerator,
if config.Tracer != nil { clientAddrVerified,
// Use the same connection ID that is passed to the client's GetLogWriter callback. tracer,
connID := hdr.DestConnectionID tracingID,
if origDestConnID.Len() > 0 { s.logger,
connID = origDestConnID hdr.Version,
} )
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) conn.handlePacket(p)
} // Adding the connection will fail if the client's chosen Destination Connection ID is already in use.
conn = s.newConn( // This is very unlikely: Even if an attacker chooses a connection ID that's already in use,
newSendConn(s.conn, p.remoteAddr, p.info, s.logger), // under normal circumstances the packet would just be routed to that connection.
s.connHandler, // The only time this collision will occur if we receive the two Initial packets at the same time.
origDestConnID, if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
retrySrcConnID, delete(s.zeroRTTQueues, hdr.DestConnectionID)
hdr.DestConnectionID, conn.closeWithTransportError(qerr.ConnectionRefused)
hdr.SrcConnectionID,
connID,
s.connIDGenerator,
s.connHandler.GetStatelessResetToken(connID),
config,
s.tlsConf,
s.tokenGenerator,
clientAddrIsValid,
tracer,
tracingID,
s.logger,
hdr.Version,
)
conn.handlePacket(p)
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
for _, p := range q.packets {
conn.handlePacket(p)
}
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
return conn, true
}); !added {
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
return nil return nil
} }
// Pass queued 0-RTT to the newly established connection.
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
for _, p := range q.packets {
conn.handlePacket(p)
}
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
go conn.run() go conn.run()
go s.handleNewConn(conn) go func() {
if conn == nil { if completed := s.handleNewConn(conn); !completed {
p.buffer.Release() return
return nil }
}
select {
case s.connQueue <- conn:
default:
conn.closeWithTransportError(ConnectionRefused)
}
}()
return nil return nil
} }
func (s *baseServer) handleNewConn(conn quicConn) { func (s *baseServer) handleNewConn(conn quicConn) bool {
connCtx := conn.Context()
if s.acceptEarlyConns { if s.acceptEarlyConns {
// wait until the early connection is ready, the handshake fails, or the server is closed // wait until the early connection is ready, the handshake fails, or the server is closed
select { select {
case <-s.errorChan: case <-s.errorChan:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}) conn.closeWithTransportError(ConnectionRefused)
return return false
case <-conn.Context().Done():
return false
case <-conn.earlyConnReady(): case <-conn.earlyConnReady():
case <-connCtx.Done(): return true
return
}
} else {
// wait until the handshake is complete (or fails)
select {
case <-s.errorChan:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
return
case <-conn.HandshakeComplete():
case <-connCtx.Done():
return
} }
} }
// wait until the handshake completes, fails, or the server is closed
atomic.AddInt32(&s.connQueueLen, 1)
select { select {
case s.connQueue <- conn: case <-s.errorChan:
// blocks until the connection is accepted conn.closeWithTransportError(ConnectionRefused)
case <-connCtx.Done(): return false
atomic.AddInt32(&s.connQueueLen, -1) case <-conn.Context().Done():
// don't pass connections that were already closed to Accept() return false
case <-conn.HandshakeComplete():
return true
} }
} }

View File

@ -60,7 +60,7 @@ type streamI interface {
// for sending // for sending
hasData() bool hasData() bool
handleStopSendingFrame(*wire.StopSendingFrame) handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool)
updateSendWindow(protocol.ByteCount) updateSendWindow(protocol.ByteCount)
} }

View File

@ -221,7 +221,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
// unsigned int ipi6_ifindex; /* send/recv interface index */ // unsigned int ipi6_ifindex; /* send/recv interface index */
// }; // };
if len(body) == 20 { if len(body) == 20 {
p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16])) p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16])).Unmap()
p.info.ifIndex = binary.LittleEndian.Uint32(body[16:]) p.info.ifIndex = binary.LittleEndian.Uint32(body[16:])
} else { } else {
invalidCmsgOnceV6.Do(func() { invalidCmsgOnceV6.Do(func() {

View File

@ -41,7 +41,8 @@ type Transport struct {
Conn net.PacketConn Conn net.PacketConn
// The length of the connection ID in bytes. // The length of the connection ID in bytes.
// It can be 0, or any value between 4 and 18. // It can be any value between 1 and 20.
// Due to the increased risk of collisions, it is not recommended to use connection IDs shorter than 4 bytes.
// If unset, a 4 byte connection ID will be used. // If unset, a 4 byte connection ID will be used.
ConnectionIDLength int ConnectionIDLength int
@ -77,7 +78,19 @@ type Transport struct {
// It has no effect for clients. // It has no effect for clients.
DisableVersionNegotiationPackets bool DisableVersionNegotiationPackets bool
// VerifySourceAddress decides if a connection attempt originating from unvalidated source
// addresses first needs to go through source address validation using QUIC's Retry mechanism,
// as described in RFC 9000 section 8.1.2.
// Note that the address passed to this callback is unvalidated, and might be spoofed in case
// of an attack.
// Validating the source address adds one additional network roundtrip to the handshake,
// and should therefore only be used if a suspiciously high number of incoming connection is recorded.
// For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable
// implementation of this callback (negating its return value).
VerifySourceAddress func(net.Addr) bool
// A Tracer traces events that don't belong to a single QUIC connection. // A Tracer traces events that don't belong to a single QUIC connection.
// Tracer.Close is called when the transport is closed.
Tracer *logging.Tracer Tracer *logging.Tracer
handlerMap packetHandlerManager handlerMap packetHandlerManager
@ -147,7 +160,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
if t.server != nil { if t.server != nil {
return nil, errListenerAlreadySet return nil, errListenerAlreadySet
} }
conf = populateServerConfig(conf) conf = populateConfig(conf)
if err := t.init(false); err != nil { if err := t.init(false); err != nil {
return nil, err return nil, err
} }
@ -161,6 +174,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
t.closeServer, t.closeServer,
*t.TokenGeneratorKey, *t.TokenGeneratorKey,
t.MaxTokenAge, t.MaxTokenAge,
t.VerifySourceAddress,
t.DisableVersionNegotiationPackets, t.DisableVersionNegotiationPackets,
allow0RTT, allow0RTT,
) )
@ -323,6 +337,9 @@ func (t *Transport) close(e error) {
if t.server != nil { if t.server != nil {
t.server.close(e, false) t.server.close(e, false)
} }
if t.Tracer != nil && t.Tracer.Close != nil {
t.Tracer.Close()
}
t.closed = true t.closed = true
} }
@ -379,13 +396,21 @@ func (t *Transport) handlePacket(p receivedPacket) {
return return
} }
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset { // If there's a connection associated with the connection ID, pass the packet there.
return
}
if handler, ok := t.handlerMap.Get(connID); ok { if handler, ok := t.handlerMap.Get(connID); ok {
handler.handlePacket(p) handler.handlePacket(p)
return return
} }
// RFC 9000 section 10.3.1 requires that the stateless reset detection logic is run for both
// packets that cannot be associated with any connections, and for packets that can't be decrypted.
// We deviate from the RFC and ignore the latter: If a packet's connection ID is associated with an
// existing connection, it is dropped there if if it can't be decrypted.
// Stateless resets use random connection IDs, and at reasonable connection ID lengths collisions are
// exceedingly rare. In the unlikely event that a stateless reset is misrouted to an existing connection,
// it is to be expected that the next stateless reset will be correctly detected.
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
return
}
if !wire.IsLongHeaderPacket(p.data[0]) { if !wire.IsLongHeaderPacket(p.data[0]) {
t.maybeSendStatelessReset(p) t.maybeSendStatelessReset(p)
return return

37
vendor/go.uber.org/mock/CONTRIBUTORS generated vendored
View File

@ -1,37 +0,0 @@
# This is the official list of people who can contribute (and typically
# have contributed) code to the gomock repository.
# The AUTHORS file lists the copyright holders; this file
# lists people. For example, Google employees are listed here
# but not in AUTHORS, because Google holds the copyright.
#
# The submission process automatically checks to make sure
# that people submitting code are listed in this file (by email address).
#
# Names should be added to this file only after verifying that
# the individual or the individual's organization has agreed to
# the appropriate Contributor License Agreement, found here:
#
# http://code.google.com/legal/individual-cla-v1.0.html
# http://code.google.com/legal/corporate-cla-v1.0.html
#
# The agreement for individuals can be filled out on the web.
#
# When adding J Random Contributor's name to this file,
# either J's name or J's organization's name should be
# added to the AUTHORS file, depending on whether the
# individual or corporate CLA was used.
# Names should be added to this file like so:
# Name <email address>
#
# An entry with two email addresses specifies that the
# first address should be used in the submit logs and
# that the second address should be recognized as the
# same person when interacting with Rietveld.
# Please keep the list sorted.
Aaron Jacobs <jacobsa@google.com> <aaronjjacobs@gmail.com>
Alex Reece <awreece@gmail.com>
David Symonds <dsymonds@golang.org>
Ryan Barrett <ryanb@google.com>

View File

@ -11,8 +11,10 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"go/ast" "go/ast"
"go/token"
"strings" "strings"
"go.uber.org/mock/mockgen/model" "go.uber.org/mock/mockgen/model"
@ -98,6 +100,16 @@ func (p *fileParser) parseGenericMethod(field *ast.Field, it *namedInterface, if
case *ast.IndexListExpr: case *ast.IndexListExpr:
indices = v.Indices indices = v.Indices
typ = v.X typ = v.X
case *ast.UnaryExpr:
if v.Op == token.TILDE {
return nil, errConstraintInterface
}
return nil, fmt.Errorf("~T may only appear as constraint for %T", field.Type)
case *ast.BinaryExpr:
if v.Op == token.OR {
return nil, errConstraintInterface
}
return nil, fmt.Errorf("A|B may only appear as constraint for %T", field.Type)
default: default:
return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
} }
@ -114,3 +126,5 @@ func (p *fileParser) parseGenericMethod(field *ast.Field, it *namedInterface, if
return p.parseMethod(nf, it, iface, pkg, tps) return p.parseMethod(nf, it, iface, pkg, tps)
} }
var errConstraintInterface = errors.New("interface contains constraints")

View File

@ -31,6 +31,7 @@ import (
"os/exec" "os/exec"
"path" "path"
"path/filepath" "path/filepath"
"runtime"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -314,8 +315,14 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
} }
g.p("//") g.p("//")
g.p("// Generated by this command:") g.p("// Generated by this command:")
g.p("//")
// only log the name of the executable, not the full path // only log the name of the executable, not the full path
g.p("// %v", strings.Join(append([]string{filepath.Base(os.Args[0])}, os.Args[1:]...), " ")) name := filepath.Base(os.Args[0])
if runtime.GOOS == "windows" {
name = strings.TrimSuffix(name, ".exe")
}
g.p("//\t%v", strings.Join(append([]string{name}, os.Args[1:]...), " "))
g.p("//")
// Get all required imports, and generate unique names for them all. // Get all required imports, and generate unique names for them all.
im := pkg.Imports() im := pkg.Imports()
@ -371,7 +378,7 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
} }
i := 0 i := 0
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() { for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() || pkgName == "any" {
pkgName = base + strconv.Itoa(i) pkgName = base + strconv.Itoa(i)
i++ i++
} }
@ -386,6 +393,10 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
} }
if *writePkgComment { if *writePkgComment {
// Ensure there's an empty line before the package to follow the recommendations:
// https://github.com/golang/go/wiki/CodeReviewComments#package-comments
g.p("")
g.p("// Package %v is a generated GoMock package.", outputPkgName) g.p("// Package %v is a generated GoMock package.", outputPkgName)
} }
g.p("package %v", outputPkgName) g.p("package %v", outputPkgName)
@ -508,7 +519,7 @@ func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface,
g.p("") g.p("")
_ = g.GenerateMockMethod(mockType, m, pkgOverride, shortTp) _ = g.GenerateMockMethod(mockType, m, pkgOverride, shortTp)
g.p("") g.p("")
_ = g.GenerateMockRecorderMethod(intf, mockType, m, shortTp, typed) _ = g.GenerateMockRecorderMethod(intf, m, shortTp, typed)
if typed { if typed {
g.p("") g.p("")
_ = g.GenerateMockReturnCallMethod(intf, m, pkgOverride, longTp, shortTp) _ = g.GenerateMockReturnCallMethod(intf, m, pkgOverride, longTp, shortTp)
@ -596,7 +607,8 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
return nil return nil
} }
func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType string, m *model.Method, shortTp string, typed bool) error { func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, m *model.Method, shortTp string, typed bool) error {
mockType := g.mockName(intf.Name)
argNames := g.getArgNames(m, true) argNames := g.getArgNames(m, true)
var argString string var argString string
@ -621,7 +633,7 @@ func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType s
g.p("// %v indicates an expected call of %v.", m.Name, m.Name) g.p("// %v indicates an expected call of %v.", m.Name, m.Name)
if typed { if typed {
g.p("func (%s *%vMockRecorder%v) %v(%v) *%s%sCall%s {", idRecv, mockType, shortTp, m.Name, argString, intf.Name, m.Name, shortTp) g.p("func (%s *%vMockRecorder%v) %v(%v) *%s%sCall%s {", idRecv, mockType, shortTp, m.Name, argString, mockType, m.Name, shortTp)
} else { } else {
g.p("func (%s *%vMockRecorder%v) %v(%v) *gomock.Call {", idRecv, mockType, shortTp, m.Name, argString) g.p("func (%s *%vMockRecorder%v) %v(%v) *gomock.Call {", idRecv, mockType, shortTp, m.Name, argString)
} }
@ -650,7 +662,7 @@ func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType s
} }
if typed { if typed {
g.p(`call := %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs) g.p(`call := %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
g.p(`return &%s%sCall%s{Call: call}`, intf.Name, m.Name, shortTp) g.p(`return &%s%sCall%s{Call: call}`, mockType, m.Name, shortTp)
} else { } else {
g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs) g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
} }
@ -661,6 +673,7 @@ func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType s
} }
func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model.Method, pkgOverride, longTp, shortTp string) error { func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model.Method, pkgOverride, longTp, shortTp string) error {
mockType := g.mockName(intf.Name)
argNames := g.getArgNames(m, true /* in */) argNames := g.getArgNames(m, true /* in */)
retNames := g.getArgNames(m, false /* out */) retNames := g.getArgNames(m, false /* out */)
argTypes := g.getArgTypes(m, pkgOverride, true /* in */) argTypes := g.getArgTypes(m, pkgOverride, true /* in */)
@ -683,10 +696,10 @@ func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model
ia := newIdentifierAllocator(argNames) ia := newIdentifierAllocator(argNames)
idRecv := ia.allocateIdentifier("c") idRecv := ia.allocateIdentifier("c")
recvStructName := intf.Name + m.Name recvStructName := mockType + m.Name
g.p("// %s%sCall wrap *gomock.Call", intf.Name, m.Name) g.p("// %s%sCall wrap *gomock.Call", mockType, m.Name)
g.p("type %s%sCall%s struct{", intf.Name, m.Name, longTp) g.p("type %s%sCall%s struct{", mockType, m.Name, longTp)
g.in() g.in()
g.p("*gomock.Call") g.p("*gomock.Call")
g.out() g.out()

View File

@ -305,7 +305,7 @@ type PredeclaredType string
func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) } func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) }
func (pt PredeclaredType) addImports(map[string]bool) {} func (pt PredeclaredType) addImports(map[string]bool) {}
// TypeParametersType contains type paramters for a NamedType. // TypeParametersType contains type parameters for a NamedType.
type TypeParametersType struct { type TypeParametersType struct {
TypeParameters []Type TypeParameters []Type
} }

View File

@ -232,6 +232,9 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag
continue continue
} }
i, err := p.parseInterface(ni.name.String(), importPath, ni) i, err := p.parseInterface(ni.name.String(), importPath, ni)
if errors.Is(err, errConstraintInterface) {
continue
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -187,6 +187,7 @@ type reflectData struct {
// gob encoding of a model.Package to standard output. // gob encoding of a model.Package to standard output.
// JSON doesn't work because of the model.Type interface. // JSON doesn't work because of the model.Type interface.
var reflectProgram = template.Must(template.New("program").Parse(` var reflectProgram = template.Must(template.New("program").Parse(`
// Code generated by MockGen. DO NOT EDIT.
package main package main
import ( import (

4
vendor/modules.txt vendored
View File

@ -230,7 +230,7 @@ github.com/prometheus/common/model
github.com/prometheus/procfs github.com/prometheus/procfs
github.com/prometheus/procfs/internal/fs github.com/prometheus/procfs/internal/fs
github.com/prometheus/procfs/internal/util github.com/prometheus/procfs/internal/util
# github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6 # github.com/quic-go/quic-go v0.42.0
## explicit; go 1.21 ## explicit; go 1.21
github.com/quic-go/quic-go github.com/quic-go/quic-go
github.com/quic-go/quic-go/internal/ackhandler github.com/quic-go/quic-go/internal/ackhandler
@ -313,7 +313,7 @@ go.opentelemetry.io/proto/otlp/trace/v1
go.uber.org/automaxprocs/internal/cgroups go.uber.org/automaxprocs/internal/cgroups
go.uber.org/automaxprocs/internal/runtime go.uber.org/automaxprocs/internal/runtime
go.uber.org/automaxprocs/maxprocs go.uber.org/automaxprocs/maxprocs
# go.uber.org/mock v0.3.0 # go.uber.org/mock v0.4.0
## explicit; go 1.20 ## explicit; go 1.20
go.uber.org/mock/mockgen go.uber.org/mock/mockgen
go.uber.org/mock/mockgen/model go.uber.org/mock/mockgen/model