TUN-8371: Bump quic-go to v0.42.0
## Summary We discovered that we were being impacted by a bug in quic-go, that could create deadlocks and not close connections. This commit bumps quic-go to the version that contains the fix to prevent that from happening.
This commit is contained in:
parent
5e5f2f4d8c
commit
84833011ec
4
go.mod
4
go.mod
|
@ -25,7 +25,7 @@ require (
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/prometheus/client_golang v1.13.0
|
github.com/prometheus/client_golang v1.13.0
|
||||||
github.com/prometheus/client_model v0.2.0
|
github.com/prometheus/client_model v0.2.0
|
||||||
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6
|
github.com/quic-go/quic-go v0.42.0
|
||||||
github.com/rs/zerolog v1.20.0
|
github.com/rs/zerolog v1.20.0
|
||||||
github.com/stretchr/testify v1.8.4
|
github.com/stretchr/testify v1.8.4
|
||||||
github.com/urfave/cli/v2 v2.3.0
|
github.com/urfave/cli/v2 v2.3.0
|
||||||
|
@ -84,7 +84,7 @@ require (
|
||||||
github.com/prometheus/procfs v0.8.0 // indirect
|
github.com/prometheus/procfs v0.8.0 // indirect
|
||||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.21.0 // indirect
|
go.opentelemetry.io/otel/metric v1.21.0 // indirect
|
||||||
go.uber.org/mock v0.3.0 // indirect
|
go.uber.org/mock v0.4.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
|
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
|
||||||
golang.org/x/mod v0.11.0 // indirect
|
golang.org/x/mod v0.11.0 // indirect
|
||||||
golang.org/x/oauth2 v0.13.0 // indirect
|
golang.org/x/oauth2 v0.13.0 // indirect
|
||||||
|
|
10
go.sum
10
go.sum
|
@ -322,8 +322,8 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
|
||||||
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
|
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
|
||||||
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
|
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
|
||||||
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
|
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
|
||||||
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6 h1:OI4WiysowCcxLtcZMGBZildo12di3ljcMN4vWdUQpoU=
|
github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
|
||||||
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA=
|
github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
|
||||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||||
|
@ -381,8 +381,8 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
|
||||||
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||||
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
|
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
|
||||||
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
|
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
|
||||||
go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo=
|
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||||
go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
|
@ -552,6 +552,8 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
|
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||||
|
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||||
|
|
|
@ -2,16 +2,6 @@ run:
|
||||||
skip-files:
|
skip-files:
|
||||||
- internal/handshake/cipher_suite.go
|
- internal/handshake/cipher_suite.go
|
||||||
linters-settings:
|
linters-settings:
|
||||||
depguard:
|
|
||||||
rules:
|
|
||||||
qtls:
|
|
||||||
list-mode: lax
|
|
||||||
files:
|
|
||||||
- "!internal/qtls/**"
|
|
||||||
- "$all"
|
|
||||||
deny:
|
|
||||||
- pkg: github.com/quic-go/qtls-go1-20
|
|
||||||
desc: "importing qtls only allowed in internal/qtls"
|
|
||||||
misspell:
|
misspell:
|
||||||
ignore-words:
|
ignore-words:
|
||||||
- ect
|
- ect
|
||||||
|
@ -20,7 +10,6 @@ linters:
|
||||||
disable-all: true
|
disable-all: true
|
||||||
enable:
|
enable:
|
||||||
- asciicheck
|
- asciicheck
|
||||||
- depguard
|
|
||||||
- exhaustive
|
- exhaustive
|
||||||
- exportloopref
|
- exportloopref
|
||||||
- goimports
|
- goimports
|
||||||
|
|
|
@ -183,26 +183,20 @@ quic-go logs a wide range of events defined in [draft-ietf-quic-qlog-quic-events
|
||||||
|
|
||||||
qlog files can be processed by a number of 3rd-party tools. [qviz](https://qvis.quictools.info/) has proven very useful for debugging all kinds of QUIC connection failures.
|
qlog files can be processed by a number of 3rd-party tools. [qviz](https://qvis.quictools.info/) has proven very useful for debugging all kinds of QUIC connection failures.
|
||||||
|
|
||||||
qlog is activated by setting a `Tracer` callback on the `Config`. It is called as soon as quic-go decides to starts the QUIC handshake on a new connection.
|
qlog can be activated by setting the `Tracer` callback on the `Config`. It is called as soon as quic-go decides to start the QUIC handshake on a new connection.
|
||||||
A useful implementation of this callback could look like this:
|
`qlog.DefaultTracer` provides a tracer implementation which writes qlog files to a directory specified by the `QLOGDIR` environment variable, if set.
|
||||||
|
The default qlog tracer can be used like this:
|
||||||
```go
|
```go
|
||||||
quic.Config{
|
quic.Config{
|
||||||
Tracer: func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
Tracer: qlog.DefaultTracer,
|
||||||
role := "server"
|
|
||||||
if p == logging.PerspectiveClient {
|
|
||||||
role = "client"
|
|
||||||
}
|
|
||||||
filename := fmt.Sprintf("./log_%s_%s.qlog", connID, role)
|
|
||||||
f, err := os.Create(filename)
|
|
||||||
// handle the error
|
|
||||||
return qlog.NewConnectionTracer(f, p, connID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
This implementation of the callback creates a new qlog file in the current directory named `log_<client / server>_<QUIC connection ID>.qlog`.
|
This example creates a new qlog file under `<QLOGDIR>/<Original Destination Connection ID>_<Vantage Point>.qlog`, e.g. `qlogs/2e0407da_client.qlog`.
|
||||||
|
|
||||||
|
|
||||||
|
For custom qlog behavior, `qlog.NewConnectionTracer` can be used.
|
||||||
|
|
||||||
## Using HTTP/3
|
## Using HTTP/3
|
||||||
|
|
||||||
### As a server
|
### As a server
|
||||||
|
@ -232,11 +226,13 @@ http.Client{
|
||||||
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) |
|
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) |
|
||||||
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) |
|
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) |
|
||||||
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) |
|
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) |
|
||||||
|
| [frp](https://github.com/fatedier/frp) | A fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet | ![GitHub Repo stars](https://img.shields.io/github/stars/fatedier/frp?style=flat-square) |
|
||||||
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) |
|
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) |
|
||||||
| [gost](https://github.com/go-gost/gost) | A simple security tunnel written in Go | ![GitHub Repo stars](https://img.shields.io/github/stars/go-gost/gost?style=flat-square) |
|
| [gost](https://github.com/go-gost/gost) | A simple security tunnel written in Go | ![GitHub Repo stars](https://img.shields.io/github/stars/go-gost/gost?style=flat-square) |
|
||||||
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) |
|
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) |
|
||||||
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
|
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
|
||||||
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
|
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
|
||||||
|
| [RoadRunner](https://github.com/roadrunner-server/roadrunner) | High-performance PHP application server, process manager written in Go and powered with plugins | ![GitHub Repo stars](https://img.shields.io/github/stars/roadrunner-server/roadrunner?style=flat-square) |
|
||||||
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) |
|
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) |
|
||||||
| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) |
|
| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) |
|
||||||
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) |
|
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) |
|
||||||
|
|
|
@ -28,7 +28,7 @@ type client struct {
|
||||||
|
|
||||||
initialPacketNumber protocol.PacketNumber
|
initialPacketNumber protocol.PacketNumber
|
||||||
hasNegotiatedVersion bool
|
hasNegotiatedVersion bool
|
||||||
version protocol.VersionNumber
|
version protocol.Version
|
||||||
|
|
||||||
handshakeChan chan struct{}
|
handshakeChan chan struct{}
|
||||||
|
|
||||||
|
@ -232,7 +232,7 @@ func (c *client) dial(ctx context.Context) error {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
c.conn.shutdown()
|
c.conn.destroy(nil)
|
||||||
return context.Cause(ctx)
|
return context.Cause(ctx)
|
||||||
case err := <-errorChan:
|
case err := <-errorChan:
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"math/bits"
|
"math/bits"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
"github.com/quic-go/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,9 +11,8 @@ import (
|
||||||
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
|
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
|
||||||
// with an exponential backoff.
|
// with an exponential backoff.
|
||||||
type closedLocalConn struct {
|
type closedLocalConn struct {
|
||||||
counter uint32
|
counter uint32
|
||||||
perspective protocol.Perspective
|
logger utils.Logger
|
||||||
logger utils.Logger
|
|
||||||
|
|
||||||
sendPacket func(net.Addr, packetInfo)
|
sendPacket func(net.Addr, packetInfo)
|
||||||
}
|
}
|
||||||
|
@ -22,11 +20,10 @@ type closedLocalConn struct {
|
||||||
var _ packetHandler = &closedLocalConn{}
|
var _ packetHandler = &closedLocalConn{}
|
||||||
|
|
||||||
// newClosedLocalConn creates a new closedLocalConn and runs it.
|
// newClosedLocalConn creates a new closedLocalConn and runs it.
|
||||||
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
|
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logger) packetHandler {
|
||||||
return &closedLocalConn{
|
return &closedLocalConn{
|
||||||
sendPacket: sendPacket,
|
sendPacket: sendPacket,
|
||||||
perspective: pers,
|
logger: logger,
|
||||||
logger: logger,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,24 +38,20 @@ func (c *closedLocalConn) handlePacket(p receivedPacket) {
|
||||||
c.sendPacket(p.remoteAddr, p.info)
|
c.sendPacket(p.remoteAddr, p.info)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *closedLocalConn) shutdown() {}
|
func (c *closedLocalConn) destroy(error) {}
|
||||||
func (c *closedLocalConn) destroy(error) {}
|
func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {}
|
||||||
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
|
|
||||||
|
|
||||||
// A closedRemoteConn is a connection that was closed remotely.
|
// A closedRemoteConn is a connection that was closed remotely.
|
||||||
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
|
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
|
||||||
// We can just ignore those packets.
|
// We can just ignore those packets.
|
||||||
type closedRemoteConn struct {
|
type closedRemoteConn struct{}
|
||||||
perspective protocol.Perspective
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ packetHandler = &closedRemoteConn{}
|
var _ packetHandler = &closedRemoteConn{}
|
||||||
|
|
||||||
func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
|
func newClosedRemoteConn() packetHandler {
|
||||||
return &closedRemoteConn{perspective: pers}
|
return &closedRemoteConn{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *closedRemoteConn) handlePacket(receivedPacket) {}
|
func (c *closedRemoteConn) handlePacket(receivedPacket) {}
|
||||||
func (s *closedRemoteConn) shutdown() {}
|
func (c *closedRemoteConn) destroy(error) {}
|
||||||
func (s *closedRemoteConn) destroy(error) {}
|
func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {}
|
||||||
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }
|
|
||||||
|
|
|
@ -5,6 +5,8 @@ coverage:
|
||||||
- interop/
|
- interop/
|
||||||
- internal/handshake/cipher_suite.go
|
- internal/handshake/cipher_suite.go
|
||||||
- internal/utils/linkedlist/linkedlist.go
|
- internal/utils/linkedlist/linkedlist.go
|
||||||
|
- internal/testdata
|
||||||
|
- testutils/
|
||||||
- fuzzing/
|
- fuzzing/
|
||||||
- metrics/
|
- metrics/
|
||||||
status:
|
status:
|
||||||
|
|
|
@ -2,7 +2,6 @@ package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
|
@ -49,16 +48,6 @@ func validateConfig(config *Config) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
|
|
||||||
// it may be called with nil
|
|
||||||
func populateServerConfig(config *Config) *Config {
|
|
||||||
config = populateConfig(config)
|
|
||||||
if config.RequireAddressValidation == nil {
|
|
||||||
config.RequireAddressValidation = func(net.Addr) bool { return false }
|
|
||||||
}
|
|
||||||
return config
|
|
||||||
}
|
|
||||||
|
|
||||||
// populateConfig populates fields in the quic.Config with their default values, if none are set
|
// populateConfig populates fields in the quic.Config with their default values, if none are set
|
||||||
// it may be called with nil
|
// it may be called with nil
|
||||||
func populateConfig(config *Config) *Config {
|
func populateConfig(config *Config) *Config {
|
||||||
|
@ -111,7 +100,6 @@ func populateConfig(config *Config) *Config {
|
||||||
Versions: versions,
|
Versions: versions,
|
||||||
HandshakeIdleTimeout: handshakeIdleTimeout,
|
HandshakeIdleTimeout: handshakeIdleTimeout,
|
||||||
MaxIdleTimeout: idleTimeout,
|
MaxIdleTimeout: idleTimeout,
|
||||||
RequireAddressValidation: config.RequireAddressValidation,
|
|
||||||
KeepAlivePeriod: config.KeepAlivePeriod,
|
KeepAlivePeriod: config.KeepAlivePeriod,
|
||||||
InitialStreamReceiveWindow: initialStreamReceiveWindow,
|
InitialStreamReceiveWindow: initialStreamReceiveWindow,
|
||||||
MaxStreamReceiveWindow: maxStreamReceiveWindow,
|
MaxStreamReceiveWindow: maxStreamReceiveWindow,
|
||||||
|
|
|
@ -19,7 +19,7 @@ type connIDGenerator struct {
|
||||||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
|
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
|
||||||
removeConnectionID func(protocol.ConnectionID)
|
removeConnectionID func(protocol.ConnectionID)
|
||||||
retireConnectionID func(protocol.ConnectionID)
|
retireConnectionID func(protocol.ConnectionID)
|
||||||
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte)
|
replaceWithClosed func([]protocol.ConnectionID, []byte)
|
||||||
queueControlFrame func(wire.Frame)
|
queueControlFrame func(wire.Frame)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ func newConnIDGenerator(
|
||||||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
|
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
|
||||||
removeConnectionID func(protocol.ConnectionID),
|
removeConnectionID func(protocol.ConnectionID),
|
||||||
retireConnectionID func(protocol.ConnectionID),
|
retireConnectionID func(protocol.ConnectionID),
|
||||||
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
|
replaceWithClosed func([]protocol.ConnectionID, []byte),
|
||||||
queueControlFrame func(wire.Frame),
|
queueControlFrame func(wire.Frame),
|
||||||
generator ConnectionIDGenerator,
|
generator ConnectionIDGenerator,
|
||||||
) *connIDGenerator {
|
) *connIDGenerator {
|
||||||
|
@ -126,7 +126,7 @@ func (m *connIDGenerator) RemoveAll() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) {
|
func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
|
||||||
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
|
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
|
||||||
if m.initialClientDestConnID != nil {
|
if m.initialClientDestConnID != nil {
|
||||||
connIDs = append(connIDs, *m.initialClientDestConnID)
|
connIDs = append(connIDs, *m.initialClientDestConnID)
|
||||||
|
@ -134,5 +134,5 @@ func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose
|
||||||
for _, connID := range m.activeSrcConnIDs {
|
for _, connID := range m.activeSrcConnIDs {
|
||||||
connIDs = append(connIDs, connID)
|
connIDs = append(connIDs, connID)
|
||||||
}
|
}
|
||||||
m.replaceWithClosed(connIDs, pers, connClose)
|
m.replaceWithClosed(connIDs, connClose)
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type unpacker interface {
|
type unpacker interface {
|
||||||
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error)
|
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error)
|
||||||
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
|
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ type connRunner interface {
|
||||||
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
|
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
|
||||||
Retire(protocol.ConnectionID)
|
Retire(protocol.ConnectionID)
|
||||||
Remove(protocol.ConnectionID)
|
Remove(protocol.ConnectionID)
|
||||||
ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte)
|
ReplaceWithClosed([]protocol.ConnectionID, []byte)
|
||||||
AddResetToken(protocol.StatelessResetToken, packetHandler)
|
AddResetToken(protocol.StatelessResetToken, packetHandler)
|
||||||
RemoveResetToken(protocol.StatelessResetToken)
|
RemoveResetToken(protocol.StatelessResetToken)
|
||||||
}
|
}
|
||||||
|
@ -106,7 +106,7 @@ type closeError struct {
|
||||||
|
|
||||||
type errCloseForRecreating struct {
|
type errCloseForRecreating struct {
|
||||||
nextPacketNumber protocol.PacketNumber
|
nextPacketNumber protocol.PacketNumber
|
||||||
nextVersion protocol.VersionNumber
|
nextVersion protocol.Version
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *errCloseForRecreating) Error() string {
|
func (e *errCloseForRecreating) Error() string {
|
||||||
|
@ -128,7 +128,7 @@ type connection struct {
|
||||||
srcConnIDLen int
|
srcConnIDLen int
|
||||||
|
|
||||||
perspective protocol.Perspective
|
perspective protocol.Perspective
|
||||||
version protocol.VersionNumber
|
version protocol.Version
|
||||||
config *Config
|
config *Config
|
||||||
|
|
||||||
conn sendConn
|
conn sendConn
|
||||||
|
@ -177,6 +177,7 @@ type connection struct {
|
||||||
|
|
||||||
earlyConnReadyChan chan struct{}
|
earlyConnReadyChan chan struct{}
|
||||||
sentFirstPacket bool
|
sentFirstPacket bool
|
||||||
|
droppedInitialKeys bool
|
||||||
handshakeComplete bool
|
handshakeComplete bool
|
||||||
handshakeConfirmed bool
|
handshakeConfirmed bool
|
||||||
|
|
||||||
|
@ -235,7 +236,7 @@ var newConnection = func(
|
||||||
tracer *logging.ConnectionTracer,
|
tracer *logging.ConnectionTracer,
|
||||||
tracingID uint64,
|
tracingID uint64,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
v protocol.VersionNumber,
|
v protocol.Version,
|
||||||
) quicConn {
|
) quicConn {
|
||||||
s := &connection{
|
s := &connection{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
@ -348,7 +349,7 @@ var newClientConnection = func(
|
||||||
tracer *logging.ConnectionTracer,
|
tracer *logging.ConnectionTracer,
|
||||||
tracingID uint64,
|
tracingID uint64,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
v protocol.VersionNumber,
|
v protocol.Version,
|
||||||
) quicConn {
|
) quicConn {
|
||||||
s := &connection{
|
s := &connection{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
@ -453,7 +454,7 @@ func (s *connection) preSetup() {
|
||||||
s.handshakeStream = newCryptoStream()
|
s.handshakeStream = newCryptoStream()
|
||||||
s.sendQueue = newSendQueue(s.conn)
|
s.sendQueue = newSendQueue(s.conn)
|
||||||
s.retransmissionQueue = newRetransmissionQueue()
|
s.retransmissionQueue = newRetransmissionQueue()
|
||||||
s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams)
|
s.frameParser = *wire.NewFrameParser(s.config.EnableDatagrams)
|
||||||
s.rttStats = &utils.RTTStats{}
|
s.rttStats = &utils.RTTStats{}
|
||||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||||
protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
|
protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
|
||||||
|
@ -520,6 +521,9 @@ func (s *connection) run() error {
|
||||||
|
|
||||||
runLoop:
|
runLoop:
|
||||||
for {
|
for {
|
||||||
|
if s.framer.QueuedTooManyControlFrames() {
|
||||||
|
s.closeLocal(&qerr.TransportError{ErrorCode: InternalError})
|
||||||
|
}
|
||||||
// Close immediately if requested
|
// Close immediately if requested
|
||||||
select {
|
select {
|
||||||
case closeErr = <-s.closeChan:
|
case closeErr = <-s.closeChan:
|
||||||
|
@ -1148,7 +1152,7 @@ func (s *connection) handleUnpackedLongHeaderPacket(
|
||||||
if !s.receivedFirstPacket {
|
if !s.receivedFirstPacket {
|
||||||
s.receivedFirstPacket = true
|
s.receivedFirstPacket = true
|
||||||
if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil {
|
if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil {
|
||||||
var clientVersions, serverVersions []protocol.VersionNumber
|
var clientVersions, serverVersions []protocol.Version
|
||||||
switch s.perspective {
|
switch s.perspective {
|
||||||
case protocol.PerspectiveClient:
|
case protocol.PerspectiveClient:
|
||||||
clientVersions = s.config.Versions
|
clientVersions = s.config.Versions
|
||||||
|
@ -1185,7 +1189,8 @@ func (s *connection) handleUnpackedLongHeaderPacket(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake {
|
if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake &&
|
||||||
|
!s.droppedInitialKeys {
|
||||||
// On the server side, Initial keys are dropped as soon as the first Handshake packet is received.
|
// On the server side, Initial keys are dropped as soon as the first Handshake packet is received.
|
||||||
// See Section 4.9.1 of RFC 9001.
|
// See Section 4.9.1 of RFC 9001.
|
||||||
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
|
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
|
||||||
|
@ -1572,13 +1577,6 @@ func (s *connection) closeRemote(e error) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the connection. It sends a NO_ERROR application error.
|
|
||||||
// It waits until the run loop has stopped before returning
|
|
||||||
func (s *connection) shutdown() {
|
|
||||||
s.closeLocal(nil)
|
|
||||||
<-s.ctx.Done()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error {
|
func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error {
|
||||||
s.closeLocal(&qerr.ApplicationError{
|
s.closeLocal(&qerr.ApplicationError{
|
||||||
ErrorCode: code,
|
ErrorCode: code,
|
||||||
|
@ -1588,6 +1586,11 @@ func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) erro
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *connection) closeWithTransportError(code TransportErrorCode) {
|
||||||
|
s.closeLocal(&qerr.TransportError{ErrorCode: code})
|
||||||
|
<-s.ctx.Done()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *connection) handleCloseError(closeErr *closeError) {
|
func (s *connection) handleCloseError(closeErr *closeError) {
|
||||||
e := closeErr.err
|
e := closeErr.err
|
||||||
if e == nil {
|
if e == nil {
|
||||||
|
@ -1632,7 +1635,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
|
||||||
|
|
||||||
// If this is a remote close we're done here
|
// If this is a remote close we're done here
|
||||||
if closeErr.remote {
|
if closeErr.remote {
|
||||||
s.connIDGenerator.ReplaceWithClosed(s.perspective, nil)
|
s.connIDGenerator.ReplaceWithClosed(nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if closeErr.immediate {
|
if closeErr.immediate {
|
||||||
|
@ -1649,7 +1652,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err)
|
s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err)
|
||||||
}
|
}
|
||||||
s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket)
|
s.connIDGenerator.ReplaceWithClosed(connClosePacket)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error {
|
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error {
|
||||||
|
@ -1661,6 +1664,7 @@ func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) erro
|
||||||
//nolint:exhaustive // only Initial and 0-RTT need special treatment
|
//nolint:exhaustive // only Initial and 0-RTT need special treatment
|
||||||
switch encLevel {
|
switch encLevel {
|
||||||
case protocol.EncryptionInitial:
|
case protocol.EncryptionInitial:
|
||||||
|
s.droppedInitialKeys = true
|
||||||
s.cryptoStreamHandler.DiscardInitialKeys()
|
s.cryptoStreamHandler.DiscardInitialKeys()
|
||||||
case protocol.Encryption0RTT:
|
case protocol.Encryption0RTT:
|
||||||
s.streamsMap.ResetFor0RTT()
|
s.streamsMap.ResetFor0RTT()
|
||||||
|
@ -2077,7 +2081,8 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn prot
|
||||||
largestAcked = p.ack.LargestAcked()
|
largestAcked = p.ack.LargestAcked()
|
||||||
}
|
}
|
||||||
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false)
|
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false)
|
||||||
if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake {
|
if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake &&
|
||||||
|
!s.droppedInitialKeys {
|
||||||
// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
|
// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
|
||||||
// See Section 4.9.1 of RFC 9001.
|
// See Section 4.9.1 of RFC 9001.
|
||||||
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
|
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
|
||||||
|
@ -2377,11 +2382,7 @@ func (s *connection) RemoteAddr() net.Addr {
|
||||||
return s.conn.RemoteAddr()
|
return s.conn.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *connection) getPerspective() protocol.Perspective {
|
func (s *connection) GetVersion() protocol.Version {
|
||||||
return s.perspective
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *connection) GetVersion() protocol.VersionNumber {
|
|
||||||
return s.version
|
return s.version
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,15 +15,25 @@ type framer interface {
|
||||||
HasData() bool
|
HasData() bool
|
||||||
|
|
||||||
QueueControlFrame(wire.Frame)
|
QueueControlFrame(wire.Frame)
|
||||||
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
|
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)
|
||||||
|
|
||||||
AddActiveStream(protocol.StreamID)
|
AddActiveStream(protocol.StreamID)
|
||||||
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
|
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)
|
||||||
|
|
||||||
Handle0RTTRejection() error
|
Handle0RTTRejection() error
|
||||||
|
|
||||||
|
// QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length.
|
||||||
|
// This is a hack.
|
||||||
|
// It is easier to implement than propagating an error return value in QueueControlFrame.
|
||||||
|
// The correct solution would be to queue frames with their respective structs.
|
||||||
|
// See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames.
|
||||||
|
QueuedTooManyControlFrames() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxPathResponses = 256
|
const (
|
||||||
|
maxPathResponses = 256
|
||||||
|
maxControlFrames = 16 << 10
|
||||||
|
)
|
||||||
|
|
||||||
type framerI struct {
|
type framerI struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
@ -33,9 +43,10 @@ type framerI struct {
|
||||||
activeStreams map[protocol.StreamID]struct{}
|
activeStreams map[protocol.StreamID]struct{}
|
||||||
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
|
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
|
||||||
|
|
||||||
controlFrameMutex sync.Mutex
|
controlFrameMutex sync.Mutex
|
||||||
controlFrames []wire.Frame
|
controlFrames []wire.Frame
|
||||||
pathResponses []*wire.PathResponseFrame
|
pathResponses []*wire.PathResponseFrame
|
||||||
|
queuedTooManyControlFrames bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ framer = &framerI{}
|
var _ framer = &framerI{}
|
||||||
|
@ -73,10 +84,15 @@ func (f *framerI) QueueControlFrame(frame wire.Frame) {
|
||||||
f.pathResponses = append(f.pathResponses, pr)
|
f.pathResponses = append(f.pathResponses, pr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// This is a hack.
|
||||||
|
if len(f.controlFrames) >= maxControlFrames {
|
||||||
|
f.queuedTooManyControlFrames = true
|
||||||
|
return
|
||||||
|
}
|
||||||
f.controlFrames = append(f.controlFrames, frame)
|
f.controlFrames = append(f.controlFrames, frame)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) {
|
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) {
|
||||||
f.controlFrameMutex.Lock()
|
f.controlFrameMutex.Lock()
|
||||||
defer f.controlFrameMutex.Unlock()
|
defer f.controlFrameMutex.Unlock()
|
||||||
|
|
||||||
|
@ -105,6 +121,10 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol
|
||||||
return frames, length
|
return frames, length
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *framerI) QueuedTooManyControlFrames() bool {
|
||||||
|
return f.queuedTooManyControlFrames
|
||||||
|
}
|
||||||
|
|
||||||
func (f *framerI) AddActiveStream(id protocol.StreamID) {
|
func (f *framerI) AddActiveStream(id protocol.StreamID) {
|
||||||
f.mutex.Lock()
|
f.mutex.Lock()
|
||||||
if _, ok := f.activeStreams[id]; !ok {
|
if _, ok := f.activeStreams[id]; !ok {
|
||||||
|
@ -114,7 +134,7 @@ func (f *framerI) AddActiveStream(id protocol.StreamID) {
|
||||||
f.mutex.Unlock()
|
f.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) {
|
func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) {
|
||||||
startLen := len(frames)
|
startLen := len(frames)
|
||||||
var length protocol.ByteCount
|
var length protocol.ByteCount
|
||||||
f.mutex.Lock()
|
f.mutex.Lock()
|
||||||
|
|
|
@ -16,8 +16,12 @@ import (
|
||||||
// The StreamID is the ID of a QUIC stream.
|
// The StreamID is the ID of a QUIC stream.
|
||||||
type StreamID = protocol.StreamID
|
type StreamID = protocol.StreamID
|
||||||
|
|
||||||
|
// A Version is a QUIC version number.
|
||||||
|
type Version = protocol.Version
|
||||||
|
|
||||||
// A VersionNumber is a QUIC version number.
|
// A VersionNumber is a QUIC version number.
|
||||||
type VersionNumber = protocol.VersionNumber
|
// Deprecated: VersionNumber was renamed to Version.
|
||||||
|
type VersionNumber = Version
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Version1 is RFC 9000
|
// Version1 is RFC 9000
|
||||||
|
@ -159,6 +163,9 @@ type Connection interface {
|
||||||
OpenStream() (Stream, error)
|
OpenStream() (Stream, error)
|
||||||
// OpenStreamSync opens a new bidirectional QUIC stream.
|
// OpenStreamSync opens a new bidirectional QUIC stream.
|
||||||
// It blocks until a new stream can be opened.
|
// It blocks until a new stream can be opened.
|
||||||
|
// There is no signaling to the peer about new streams:
|
||||||
|
// The peer can only accept the stream after data has been sent on the stream,
|
||||||
|
// or the stream has been reset or closed.
|
||||||
// If the error is non-nil, it satisfies the net.Error interface.
|
// If the error is non-nil, it satisfies the net.Error interface.
|
||||||
// If the connection was closed due to a timeout, Timeout() will be true.
|
// If the connection was closed due to a timeout, Timeout() will be true.
|
||||||
OpenStreamSync(context.Context) (Stream, error)
|
OpenStreamSync(context.Context) (Stream, error)
|
||||||
|
@ -255,7 +262,7 @@ type Config struct {
|
||||||
GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
|
GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
|
||||||
// The QUIC versions that can be negotiated.
|
// The QUIC versions that can be negotiated.
|
||||||
// If not set, it uses all versions available.
|
// If not set, it uses all versions available.
|
||||||
Versions []VersionNumber
|
Versions []Version
|
||||||
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
|
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
|
||||||
// If we don't receive any packet from the peer within this time, the connection attempt is aborted.
|
// If we don't receive any packet from the peer within this time, the connection attempt is aborted.
|
||||||
// Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted.
|
// Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted.
|
||||||
|
@ -267,11 +274,6 @@ type Config struct {
|
||||||
// If the timeout is exceeded, the connection is closed.
|
// If the timeout is exceeded, the connection is closed.
|
||||||
// If this value is zero, the timeout is set to 30 seconds.
|
// If this value is zero, the timeout is set to 30 seconds.
|
||||||
MaxIdleTimeout time.Duration
|
MaxIdleTimeout time.Duration
|
||||||
// RequireAddressValidation determines if a QUIC Retry packet is sent.
|
|
||||||
// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
|
|
||||||
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
|
|
||||||
// If not set, every client is forced to prove its remote address.
|
|
||||||
RequireAddressValidation func(net.Addr) bool
|
|
||||||
// The TokenStore stores tokens received from the server.
|
// The TokenStore stores tokens received from the server.
|
||||||
// Tokens are used to skip address validation on future connection attempts.
|
// Tokens are used to skip address validation on future connection attempts.
|
||||||
// The key used to store tokens is the ServerName from the tls.Config, if set
|
// The key used to store tokens is the ServerName from the tls.Config, if set
|
||||||
|
@ -331,8 +333,15 @@ type Config struct {
|
||||||
Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer
|
Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClientHelloInfo contains information about an incoming connection attempt.
|
||||||
type ClientHelloInfo struct {
|
type ClientHelloInfo struct {
|
||||||
|
// RemoteAddr is the remote address on the Initial packet.
|
||||||
|
// Unless AddrVerified is set, the address is not yet verified, and could be a spoofed IP address.
|
||||||
RemoteAddr net.Addr
|
RemoteAddr net.Addr
|
||||||
|
// AddrVerified says if the remote address was verified using QUIC's Retry mechanism.
|
||||||
|
// Note that the Retry mechanism costs one network roundtrip,
|
||||||
|
// and is not performed unless Transport.MaxUnvalidatedHandshakes is surpassed.
|
||||||
|
AddrVerified bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnectionState records basic details about a QUIC connection
|
// ConnectionState records basic details about a QUIC connection
|
||||||
|
@ -347,7 +356,7 @@ type ConnectionState struct {
|
||||||
// Used0RTT says if 0-RTT resumption was used.
|
// Used0RTT says if 0-RTT resumption was used.
|
||||||
Used0RTT bool
|
Used0RTT bool
|
||||||
// Version is the QUIC version of the QUIC connection.
|
// Version is the QUIC version of the QUIC connection.
|
||||||
Version VersionNumber
|
Version Version
|
||||||
// GSO says if generic segmentation offload is used
|
// GSO says if generic segmentation offload is used
|
||||||
GSO bool
|
GSO bool
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,5 +20,5 @@ func NewAckHandler(
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
) (SentPacketHandler, ReceivedPacketHandler) {
|
) (SentPacketHandler, ReceivedPacketHandler) {
|
||||||
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger)
|
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger)
|
||||||
return sph, newReceivedPacketHandler(sph, rttStats, logger)
|
return sph, newReceivedPacketHandler(sph, logger)
|
||||||
}
|
}
|
||||||
|
|
39
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
39
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
|
@ -14,23 +14,19 @@ type receivedPacketHandler struct {
|
||||||
|
|
||||||
initialPackets *receivedPacketTracker
|
initialPackets *receivedPacketTracker
|
||||||
handshakePackets *receivedPacketTracker
|
handshakePackets *receivedPacketTracker
|
||||||
appDataPackets *receivedPacketTracker
|
appDataPackets appDataReceivedPacketTracker
|
||||||
|
|
||||||
lowest1RTTPacket protocol.PacketNumber
|
lowest1RTTPacket protocol.PacketNumber
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ReceivedPacketHandler = &receivedPacketHandler{}
|
var _ ReceivedPacketHandler = &receivedPacketHandler{}
|
||||||
|
|
||||||
func newReceivedPacketHandler(
|
func newReceivedPacketHandler(sentPackets sentPacketTracker, logger utils.Logger) ReceivedPacketHandler {
|
||||||
sentPackets sentPacketTracker,
|
|
||||||
rttStats *utils.RTTStats,
|
|
||||||
logger utils.Logger,
|
|
||||||
) ReceivedPacketHandler {
|
|
||||||
return &receivedPacketHandler{
|
return &receivedPacketHandler{
|
||||||
sentPackets: sentPackets,
|
sentPackets: sentPackets,
|
||||||
initialPackets: newReceivedPacketTracker(rttStats, logger),
|
initialPackets: newReceivedPacketTracker(),
|
||||||
handshakePackets: newReceivedPacketTracker(rttStats, logger),
|
handshakePackets: newReceivedPacketTracker(),
|
||||||
appDataPackets: newReceivedPacketTracker(rttStats, logger),
|
appDataPackets: *newAppDataReceivedPacketTracker(logger),
|
||||||
lowest1RTTPacket: protocol.InvalidPacketNumber,
|
lowest1RTTPacket: protocol.InvalidPacketNumber,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -88,41 +84,28 @@ func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
|
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
|
||||||
var initialAlarm, handshakeAlarm time.Time
|
return h.appDataPackets.GetAlarmTimeout()
|
||||||
if h.initialPackets != nil {
|
|
||||||
initialAlarm = h.initialPackets.GetAlarmTimeout()
|
|
||||||
}
|
|
||||||
if h.handshakePackets != nil {
|
|
||||||
handshakeAlarm = h.handshakePackets.GetAlarmTimeout()
|
|
||||||
}
|
|
||||||
oneRTTAlarm := h.appDataPackets.GetAlarmTimeout()
|
|
||||||
return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame {
|
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame {
|
||||||
var ack *wire.AckFrame
|
|
||||||
//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
|
//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
|
||||||
switch encLevel {
|
switch encLevel {
|
||||||
case protocol.EncryptionInitial:
|
case protocol.EncryptionInitial:
|
||||||
if h.initialPackets != nil {
|
if h.initialPackets != nil {
|
||||||
ack = h.initialPackets.GetAckFrame(onlyIfQueued)
|
return h.initialPackets.GetAckFrame()
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
case protocol.EncryptionHandshake:
|
case protocol.EncryptionHandshake:
|
||||||
if h.handshakePackets != nil {
|
if h.handshakePackets != nil {
|
||||||
ack = h.handshakePackets.GetAckFrame(onlyIfQueued)
|
return h.handshakePackets.GetAckFrame()
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
case protocol.Encryption1RTT:
|
case protocol.Encryption1RTT:
|
||||||
// 0-RTT packets can't contain ACK frames
|
|
||||||
return h.appDataPackets.GetAckFrame(onlyIfQueued)
|
return h.appDataPackets.GetAckFrame(onlyIfQueued)
|
||||||
default:
|
default:
|
||||||
|
// 0-RTT packets can't contain ACK frames
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// For Initial and Handshake ACKs, the delay time is ignored by the receiver.
|
|
||||||
// Set it to 0 in order to save bytes.
|
|
||||||
if ack != nil {
|
|
||||||
ack.DelayTime = 0
|
|
||||||
}
|
|
||||||
return ack
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {
|
func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {
|
||||||
|
|
208
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go
generated
vendored
208
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go
generated
vendored
|
@ -9,40 +9,19 @@ import (
|
||||||
"github.com/quic-go/quic-go/internal/wire"
|
"github.com/quic-go/quic-go/internal/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
// number of ack-eliciting packets received before sending an ack.
|
// The receivedPacketTracker tracks packets for the Initial and Handshake packet number space.
|
||||||
const packetsBeforeAck = 2
|
// Every received packet is acknowledged immediately.
|
||||||
|
|
||||||
type receivedPacketTracker struct {
|
type receivedPacketTracker struct {
|
||||||
largestObserved protocol.PacketNumber
|
ect0, ect1, ecnce uint64
|
||||||
ignoreBelow protocol.PacketNumber
|
|
||||||
largestObservedRcvdTime time.Time
|
|
||||||
ect0, ect1, ecnce uint64
|
|
||||||
|
|
||||||
packetHistory *receivedPacketHistory
|
packetHistory receivedPacketHistory
|
||||||
|
|
||||||
maxAckDelay time.Duration
|
|
||||||
rttStats *utils.RTTStats
|
|
||||||
|
|
||||||
|
lastAck *wire.AckFrame
|
||||||
hasNewAck bool // true as soon as we received an ack-eliciting new packet
|
hasNewAck bool // true as soon as we received an ack-eliciting new packet
|
||||||
ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets
|
|
||||||
|
|
||||||
ackElicitingPacketsReceivedSinceLastAck int
|
|
||||||
ackAlarm time.Time
|
|
||||||
lastAck *wire.AckFrame
|
|
||||||
|
|
||||||
logger utils.Logger
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newReceivedPacketTracker(
|
func newReceivedPacketTracker() *receivedPacketTracker {
|
||||||
rttStats *utils.RTTStats,
|
return &receivedPacketTracker{packetHistory: *newReceivedPacketHistory()}
|
||||||
logger utils.Logger,
|
|
||||||
) *receivedPacketTracker {
|
|
||||||
return &receivedPacketTracker{
|
|
||||||
packetHistory: newReceivedPacketHistory(),
|
|
||||||
maxAckDelay: protocol.MaxAckDelay,
|
|
||||||
rttStats: rttStats,
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
|
func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
|
||||||
|
@ -50,16 +29,6 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
|
||||||
return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn)
|
return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn)
|
||||||
}
|
}
|
||||||
|
|
||||||
isMissing := h.isMissing(pn)
|
|
||||||
if pn >= h.largestObserved {
|
|
||||||
h.largestObserved = pn
|
|
||||||
h.largestObservedRcvdTime = rcvTime
|
|
||||||
}
|
|
||||||
|
|
||||||
if ackEliciting {
|
|
||||||
h.hasNewAck = true
|
|
||||||
h.maybeQueueACK(pn, rcvTime, ecn, isMissing)
|
|
||||||
}
|
|
||||||
//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE.
|
//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE.
|
||||||
switch ecn {
|
switch ecn {
|
||||||
case protocol.ECT0:
|
case protocol.ECT0:
|
||||||
|
@ -69,12 +38,99 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
|
||||||
case protocol.ECNCE:
|
case protocol.ECNCE:
|
||||||
h.ecnce++
|
h.ecnce++
|
||||||
}
|
}
|
||||||
|
if !ackEliciting {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h.hasNewAck = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame {
|
||||||
|
if !h.hasNewAck {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function always returns the same ACK frame struct, filled with the most recent values.
|
||||||
|
ack := h.lastAck
|
||||||
|
if ack == nil {
|
||||||
|
ack = &wire.AckFrame{}
|
||||||
|
}
|
||||||
|
ack.Reset()
|
||||||
|
ack.ECT0 = h.ect0
|
||||||
|
ack.ECT1 = h.ect1
|
||||||
|
ack.ECNCE = h.ecnce
|
||||||
|
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
|
||||||
|
|
||||||
|
h.lastAck = ack
|
||||||
|
h.hasNewAck = false
|
||||||
|
return ack
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
|
||||||
|
return h.packetHistory.IsPotentiallyDuplicate(pn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// number of ack-eliciting packets received before sending an ACK
|
||||||
|
const packetsBeforeAck = 2
|
||||||
|
|
||||||
|
// The appDataReceivedPacketTracker tracks packets received in the Application Data packet number space.
|
||||||
|
// It waits until at least 2 packets were received before queueing an ACK, or until the max_ack_delay was reached.
|
||||||
|
type appDataReceivedPacketTracker struct {
|
||||||
|
receivedPacketTracker
|
||||||
|
|
||||||
|
largestObservedRcvdTime time.Time
|
||||||
|
|
||||||
|
largestObserved protocol.PacketNumber
|
||||||
|
ignoreBelow protocol.PacketNumber
|
||||||
|
|
||||||
|
maxAckDelay time.Duration
|
||||||
|
ackQueued bool // true if we need send a new ACK
|
||||||
|
|
||||||
|
ackElicitingPacketsReceivedSinceLastAck int
|
||||||
|
ackAlarm time.Time
|
||||||
|
|
||||||
|
logger utils.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppDataReceivedPacketTracker(logger utils.Logger) *appDataReceivedPacketTracker {
|
||||||
|
h := &appDataReceivedPacketTracker{
|
||||||
|
receivedPacketTracker: *newReceivedPacketTracker(),
|
||||||
|
maxAckDelay: protocol.MaxAckDelay,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *appDataReceivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
|
||||||
|
if err := h.receivedPacketTracker.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if pn >= h.largestObserved {
|
||||||
|
h.largestObserved = pn
|
||||||
|
h.largestObservedRcvdTime = rcvTime
|
||||||
|
}
|
||||||
|
if !ackEliciting {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h.ackElicitingPacketsReceivedSinceLastAck++
|
||||||
|
isMissing := h.isMissing(pn)
|
||||||
|
if !h.ackQueued && h.shouldQueueACK(pn, ecn, isMissing) {
|
||||||
|
h.ackQueued = true
|
||||||
|
h.ackAlarm = time.Time{} // cancel the ack alarm
|
||||||
|
}
|
||||||
|
if !h.ackQueued {
|
||||||
|
// No ACK queued, but we'll need to acknowledge the packet after max_ack_delay.
|
||||||
|
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IgnoreBelow sets a lower limit for acknowledging packets.
|
// IgnoreBelow sets a lower limit for acknowledging packets.
|
||||||
// Packets with packet numbers smaller than p will not be acked.
|
// Packets with packet numbers smaller than p will not be acked.
|
||||||
func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
|
func (h *appDataReceivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
|
||||||
if pn <= h.ignoreBelow {
|
if pn <= h.ignoreBelow {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -86,14 +142,14 @@ func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// isMissing says if a packet was reported missing in the last ACK.
|
// isMissing says if a packet was reported missing in the last ACK.
|
||||||
func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
|
func (h *appDataReceivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
|
||||||
if h.lastAck == nil || p < h.ignoreBelow {
|
if h.lastAck == nil || p < h.ignoreBelow {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
|
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *receivedPacketTracker) hasNewMissingPackets() bool {
|
func (h *appDataReceivedPacketTracker) hasNewMissingPackets() bool {
|
||||||
if h.lastAck == nil {
|
if h.lastAck == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -101,31 +157,21 @@ func (h *receivedPacketTracker) hasNewMissingPackets() bool {
|
||||||
return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1
|
return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// maybeQueueACK queues an ACK, if necessary.
|
func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool {
|
||||||
func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime time.Time, ecn protocol.ECN, wasMissing bool) {
|
|
||||||
// always acknowledge the first packet
|
// always acknowledge the first packet
|
||||||
if h.lastAck == nil {
|
if h.lastAck == nil {
|
||||||
if !h.ackQueued {
|
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
|
||||||
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
|
return true
|
||||||
}
|
|
||||||
h.ackQueued = true
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.ackQueued {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.ackElicitingPacketsReceivedSinceLastAck++
|
|
||||||
|
|
||||||
// Send an ACK if this packet was reported missing in an ACK sent before.
|
// Send an ACK if this packet was reported missing in an ACK sent before.
|
||||||
// Ack decimation with reordering relies on the timer to send an ACK, but if
|
// Ack decimation with reordering relies on the timer to send an ACK, but if
|
||||||
// missing packets we reported in the previous ack, send an ACK immediately.
|
// missing packets we reported in the previous ACK, send an ACK immediately.
|
||||||
if wasMissing {
|
if wasMissing {
|
||||||
if h.logger.Debug() {
|
if h.logger.Debug() {
|
||||||
h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn)
|
h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn)
|
||||||
}
|
}
|
||||||
h.ackQueued = true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// send an ACK every 2 ack-eliciting packets
|
// send an ACK every 2 ack-eliciting packets
|
||||||
|
@ -133,68 +179,42 @@ func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime
|
||||||
if h.logger.Debug() {
|
if h.logger.Debug() {
|
||||||
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck)
|
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck)
|
||||||
}
|
}
|
||||||
h.ackQueued = true
|
return true
|
||||||
} else if h.ackAlarm.IsZero() {
|
|
||||||
if h.logger.Debug() {
|
|
||||||
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
|
|
||||||
}
|
|
||||||
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// queue an ACK if there are new missing packets to report
|
// queue an ACK if there are new missing packets to report
|
||||||
if h.hasNewMissingPackets() {
|
if h.hasNewMissingPackets() {
|
||||||
h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.")
|
h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.")
|
||||||
h.ackQueued = true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// queue an ACK if the packet was ECN-CE marked
|
// queue an ACK if the packet was ECN-CE marked
|
||||||
if ecn == protocol.ECNCE {
|
if ecn == protocol.ECNCE {
|
||||||
h.logger.Debugf("\tQueuing ACK because the packet was ECN-CE marked.")
|
h.logger.Debugf("\tQueuing ACK because the packet was ECN-CE marked.")
|
||||||
h.ackQueued = true
|
return true
|
||||||
}
|
|
||||||
|
|
||||||
if h.ackQueued {
|
|
||||||
// cancel the ack alarm
|
|
||||||
h.ackAlarm = time.Time{}
|
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
|
func (h *appDataReceivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
|
||||||
if !h.hasNewAck {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
if onlyIfQueued {
|
if onlyIfQueued && !h.ackQueued {
|
||||||
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
|
if h.ackAlarm.IsZero() || h.ackAlarm.After(now) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
|
if h.logger.Debug() && !h.ackAlarm.IsZero() {
|
||||||
h.logger.Debugf("Sending ACK because the ACK timer expired.")
|
h.logger.Debugf("Sending ACK because the ACK timer expired.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ack := h.receivedPacketTracker.GetAckFrame()
|
||||||
// This function always returns the same ACK frame struct, filled with the most recent values.
|
|
||||||
ack := h.lastAck
|
|
||||||
if ack == nil {
|
if ack == nil {
|
||||||
ack = &wire.AckFrame{}
|
return nil
|
||||||
}
|
}
|
||||||
ack.Reset()
|
|
||||||
ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime))
|
ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime))
|
||||||
ack.ECT0 = h.ect0
|
|
||||||
ack.ECT1 = h.ect1
|
|
||||||
ack.ECNCE = h.ecnce
|
|
||||||
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
|
|
||||||
|
|
||||||
h.lastAck = ack
|
|
||||||
h.ackAlarm = time.Time{}
|
|
||||||
h.ackQueued = false
|
h.ackQueued = false
|
||||||
h.hasNewAck = false
|
h.ackAlarm = time.Time{}
|
||||||
h.ackElicitingPacketsReceivedSinceLastAck = 0
|
h.ackElicitingPacketsReceivedSinceLastAck = 0
|
||||||
return ack
|
return ack
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }
|
func (h *appDataReceivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }
|
||||||
|
|
||||||
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
|
|
||||||
return h.packetHistory.IsPotentiallyDuplicate(pn)
|
|
||||||
}
|
|
||||||
|
|
|
@ -48,10 +48,12 @@ func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame.
|
// UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame.
|
||||||
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
|
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) (updated bool) {
|
||||||
if offset > c.sendWindow {
|
if offset > c.sendWindow {
|
||||||
c.sendWindow = offset
|
c.sendWindow = offset
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
|
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
|
||||||
|
|
|
@ -5,7 +5,7 @@ import "github.com/quic-go/quic-go/internal/protocol"
|
||||||
type flowController interface {
|
type flowController interface {
|
||||||
// for sending
|
// for sending
|
||||||
SendWindowSize() protocol.ByteCount
|
SendWindowSize() protocol.ByteCount
|
||||||
UpdateSendWindow(protocol.ByteCount)
|
UpdateSendWindow(protocol.ByteCount) (updated bool)
|
||||||
AddBytesSent(protocol.ByteCount)
|
AddBytesSent(protocol.ByteCount)
|
||||||
// for receiving
|
// for receiving
|
||||||
AddBytesRead(protocol.ByteCount)
|
AddBytesRead(protocol.ByteCount)
|
||||||
|
@ -16,12 +16,11 @@ type flowController interface {
|
||||||
// A StreamFlowController is a flow controller for a QUIC stream.
|
// A StreamFlowController is a flow controller for a QUIC stream.
|
||||||
type StreamFlowController interface {
|
type StreamFlowController interface {
|
||||||
flowController
|
flowController
|
||||||
// for receiving
|
// UpdateHighestReceived is called when a new highest offset is received
|
||||||
// UpdateHighestReceived should be called when a new highest offset is received
|
|
||||||
// final has to be to true if this is the final offset of the stream,
|
// final has to be to true if this is the final offset of the stream,
|
||||||
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
|
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
|
||||||
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
|
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
|
||||||
// Abandon should be called when reading from the stream is aborted early,
|
// Abandon is called when reading from the stream is aborted early,
|
||||||
// and there won't be any further calls to AddBytesRead.
|
// and there won't be any further calls to AddBytesRead.
|
||||||
Abandon()
|
Abandon()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,12 @@
|
||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/cipher"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
|
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.Version) *xorNonceAEAD {
|
||||||
keyLabel := hkdfLabelKeyV1
|
keyLabel := hkdfLabelKeyV1
|
||||||
ivLabel := hkdfLabelIVV1
|
ivLabel := hkdfLabelIVV1
|
||||||
if v == protocol.Version2 {
|
if v == protocol.Version2 {
|
||||||
|
@ -20,28 +19,26 @@ func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumb
|
||||||
}
|
}
|
||||||
|
|
||||||
type longHeaderSealer struct {
|
type longHeaderSealer struct {
|
||||||
aead cipher.AEAD
|
aead *xorNonceAEAD
|
||||||
headerProtector headerProtector
|
headerProtector headerProtector
|
||||||
|
nonceBuf [8]byte
|
||||||
// use a single slice to avoid allocations
|
|
||||||
nonceBuf []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ LongHeaderSealer = &longHeaderSealer{}
|
var _ LongHeaderSealer = &longHeaderSealer{}
|
||||||
|
|
||||||
func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
|
func newLongHeaderSealer(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderSealer {
|
||||||
|
if aead.NonceSize() != 8 {
|
||||||
|
panic("unexpected nonce size")
|
||||||
|
}
|
||||||
return &longHeaderSealer{
|
return &longHeaderSealer{
|
||||||
aead: aead,
|
aead: aead,
|
||||||
headerProtector: headerProtector,
|
headerProtector: headerProtector,
|
||||||
nonceBuf: make([]byte, aead.NonceSize()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
||||||
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
|
binary.BigEndian.PutUint64(s.nonceBuf[:], uint64(pn))
|
||||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
return s.aead.Seal(dst, s.nonceBuf[:], src, ad)
|
||||||
// It uses the nonce provided here and XOR it with the IV.
|
|
||||||
return s.aead.Seal(dst, s.nonceBuf, src, ad)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||||
|
@ -53,21 +50,23 @@ func (s *longHeaderSealer) Overhead() int {
|
||||||
}
|
}
|
||||||
|
|
||||||
type longHeaderOpener struct {
|
type longHeaderOpener struct {
|
||||||
aead cipher.AEAD
|
aead *xorNonceAEAD
|
||||||
headerProtector headerProtector
|
headerProtector headerProtector
|
||||||
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
|
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
|
||||||
|
|
||||||
// use a single slice to avoid allocations
|
// use a single array to avoid allocations
|
||||||
nonceBuf []byte
|
nonceBuf [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ LongHeaderOpener = &longHeaderOpener{}
|
var _ LongHeaderOpener = &longHeaderOpener{}
|
||||||
|
|
||||||
func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
|
func newLongHeaderOpener(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderOpener {
|
||||||
|
if aead.NonceSize() != 8 {
|
||||||
|
panic("unexpected nonce size")
|
||||||
|
}
|
||||||
return &longHeaderOpener{
|
return &longHeaderOpener{
|
||||||
aead: aead,
|
aead: aead,
|
||||||
headerProtector: headerProtector,
|
headerProtector: headerProtector,
|
||||||
nonceBuf: make([]byte, aead.NonceSize()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,10 +75,8 @@ func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wire
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||||
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
|
binary.BigEndian.PutUint64(o.nonceBuf[:], uint64(pn))
|
||||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
dec, err := o.aead.Open(dst, o.nonceBuf[:], src, ad)
|
||||||
// It uses the nonce provided here and XOR it with the IV.
|
|
||||||
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
o.highestRcvdPN = max(o.highestRcvdPN, pn)
|
o.highestRcvdPN = max(o.highestRcvdPN, pn)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -18,7 +18,7 @@ type cipherSuite struct {
|
||||||
ID uint16
|
ID uint16
|
||||||
Hash crypto.Hash
|
Hash crypto.Hash
|
||||||
KeyLen int
|
KeyLen int
|
||||||
AEAD func(key, nonceMask []byte) cipher.AEAD
|
AEAD func(key, nonceMask []byte) *xorNonceAEAD
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s cipherSuite) IVLen() int { return aeadNonceLength }
|
func (s cipherSuite) IVLen() int { return aeadNonceLength }
|
||||||
|
@ -36,7 +36,7 @@ func getCipherSuite(id uint16) *cipherSuite {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
|
func aeadAESGCMTLS13(key, nonceMask []byte) *xorNonceAEAD {
|
||||||
if len(nonceMask) != aeadNonceLength {
|
if len(nonceMask) != aeadNonceLength {
|
||||||
panic("tls: internal error: wrong nonce length")
|
panic("tls: internal error: wrong nonce length")
|
||||||
}
|
}
|
||||||
|
@ -54,7 +54,7 @@ func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD {
|
func aeadChaCha20Poly1305(key, nonceMask []byte) *xorNonceAEAD {
|
||||||
if len(nonceMask) != aeadNonceLength {
|
if len(nonceMask) != aeadNonceLength {
|
||||||
panic("tls: internal error: wrong nonce length")
|
panic("tls: internal error: wrong nonce length")
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -33,7 +32,7 @@ type cryptoSetup struct {
|
||||||
|
|
||||||
events []Event
|
events []Event
|
||||||
|
|
||||||
version protocol.VersionNumber
|
version protocol.Version
|
||||||
|
|
||||||
ourParams *wire.TransportParameters
|
ourParams *wire.TransportParameters
|
||||||
peerParams *wire.TransportParameters
|
peerParams *wire.TransportParameters
|
||||||
|
@ -48,8 +47,6 @@ type cryptoSetup struct {
|
||||||
|
|
||||||
perspective protocol.Perspective
|
perspective protocol.Perspective
|
||||||
|
|
||||||
mutex sync.Mutex // protects all members below
|
|
||||||
|
|
||||||
handshakeCompleteTime time.Time
|
handshakeCompleteTime time.Time
|
||||||
|
|
||||||
zeroRTTOpener LongHeaderOpener // only set for the server
|
zeroRTTOpener LongHeaderOpener // only set for the server
|
||||||
|
@ -79,7 +76,7 @@ func NewCryptoSetupClient(
|
||||||
rttStats *utils.RTTStats,
|
rttStats *utils.RTTStats,
|
||||||
tracer *logging.ConnectionTracer,
|
tracer *logging.ConnectionTracer,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
version protocol.VersionNumber,
|
version protocol.Version,
|
||||||
) CryptoSetup {
|
) CryptoSetup {
|
||||||
cs := newCryptoSetup(
|
cs := newCryptoSetup(
|
||||||
connID,
|
connID,
|
||||||
|
@ -114,7 +111,7 @@ func NewCryptoSetupServer(
|
||||||
rttStats *utils.RTTStats,
|
rttStats *utils.RTTStats,
|
||||||
tracer *logging.ConnectionTracer,
|
tracer *logging.ConnectionTracer,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
version protocol.VersionNumber,
|
version protocol.Version,
|
||||||
) CryptoSetup {
|
) CryptoSetup {
|
||||||
cs := newCryptoSetup(
|
cs := newCryptoSetup(
|
||||||
connID,
|
connID,
|
||||||
|
@ -172,7 +169,7 @@ func newCryptoSetup(
|
||||||
tracer *logging.ConnectionTracer,
|
tracer *logging.ConnectionTracer,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
perspective protocol.Perspective,
|
perspective protocol.Perspective,
|
||||||
version protocol.VersionNumber,
|
version protocol.Version,
|
||||||
) *cryptoSetup {
|
) *cryptoSetup {
|
||||||
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
|
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
|
||||||
if tracer != nil && tracer.UpdatedKeyFromTLS != nil {
|
if tracer != nil && tracer.UpdatedKeyFromTLS != nil {
|
||||||
|
@ -269,10 +266,10 @@ func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) {
|
||||||
case tls.QUICNoEvent:
|
case tls.QUICNoEvent:
|
||||||
return true, nil
|
return true, nil
|
||||||
case tls.QUICSetReadSecret:
|
case tls.QUICSetReadSecret:
|
||||||
h.SetReadKey(ev.Level, ev.Suite, ev.Data)
|
h.setReadKey(ev.Level, ev.Suite, ev.Data)
|
||||||
return false, nil
|
return false, nil
|
||||||
case tls.QUICSetWriteSecret:
|
case tls.QUICSetWriteSecret:
|
||||||
h.SetWriteKey(ev.Level, ev.Suite, ev.Data)
|
h.setWriteKey(ev.Level, ev.Suite, ev.Data)
|
||||||
return false, nil
|
return false, nil
|
||||||
case tls.QUICTransportParameters:
|
case tls.QUICTransportParameters:
|
||||||
return false, h.handleTransportParameters(ev.Data)
|
return false, h.handleTransportParameters(ev.Data)
|
||||||
|
@ -434,19 +431,16 @@ func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bo
|
||||||
func (h *cryptoSetup) rejected0RTT() {
|
func (h *cryptoSetup) rejected0RTT() {
|
||||||
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
|
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
|
||||||
|
|
||||||
h.mutex.Lock()
|
|
||||||
had0RTTKeys := h.zeroRTTSealer != nil
|
had0RTTKeys := h.zeroRTTSealer != nil
|
||||||
h.zeroRTTSealer = nil
|
h.zeroRTTSealer = nil
|
||||||
h.mutex.Unlock()
|
|
||||||
|
|
||||||
if had0RTTKeys {
|
if had0RTTKeys {
|
||||||
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
|
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||||
suite := getCipherSuite(suiteID)
|
suite := getCipherSuite(suiteID)
|
||||||
h.mutex.Lock()
|
|
||||||
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
||||||
switch el {
|
switch el {
|
||||||
case tls.QUICEncryptionLevelEarly:
|
case tls.QUICEncryptionLevelEarly:
|
||||||
|
@ -478,16 +472,14 @@ func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra
|
||||||
default:
|
default:
|
||||||
panic("unexpected read encryption level")
|
panic("unexpected read encryption level")
|
||||||
}
|
}
|
||||||
h.mutex.Unlock()
|
|
||||||
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
|
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
|
||||||
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
|
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
|
||||||
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
|
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||||
suite := getCipherSuite(suiteID)
|
suite := getCipherSuite(suiteID)
|
||||||
h.mutex.Lock()
|
|
||||||
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
||||||
switch el {
|
switch el {
|
||||||
case tls.QUICEncryptionLevelEarly:
|
case tls.QUICEncryptionLevelEarly:
|
||||||
|
@ -498,7 +490,6 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
|
||||||
createAEAD(suite, trafficSecret, h.version),
|
createAEAD(suite, trafficSecret, h.version),
|
||||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||||
)
|
)
|
||||||
h.mutex.Unlock()
|
|
||||||
if h.logger.Debug() {
|
if h.logger.Debug() {
|
||||||
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||||
}
|
}
|
||||||
|
@ -533,7 +524,6 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
|
||||||
default:
|
default:
|
||||||
panic("unexpected write encryption level")
|
panic("unexpected write encryption level")
|
||||||
}
|
}
|
||||||
h.mutex.Unlock()
|
|
||||||
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
|
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
|
||||||
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
|
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
|
||||||
}
|
}
|
||||||
|
@ -555,11 +545,9 @@ func (h *cryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) DiscardInitialKeys() {
|
func (h *cryptoSetup) DiscardInitialKeys() {
|
||||||
h.mutex.Lock()
|
|
||||||
dropped := h.initialOpener != nil
|
dropped := h.initialOpener != nil
|
||||||
h.initialOpener = nil
|
h.initialOpener = nil
|
||||||
h.initialSealer = nil
|
h.initialSealer = nil
|
||||||
h.mutex.Unlock()
|
|
||||||
if dropped {
|
if dropped {
|
||||||
h.logger.Debugf("Dropping Initial keys.")
|
h.logger.Debugf("Dropping Initial keys.")
|
||||||
}
|
}
|
||||||
|
@ -574,22 +562,17 @@ func (h *cryptoSetup) SetHandshakeConfirmed() {
|
||||||
h.aead.SetHandshakeConfirmed()
|
h.aead.SetHandshakeConfirmed()
|
||||||
// drop Handshake keys
|
// drop Handshake keys
|
||||||
var dropped bool
|
var dropped bool
|
||||||
h.mutex.Lock()
|
|
||||||
if h.handshakeOpener != nil {
|
if h.handshakeOpener != nil {
|
||||||
h.handshakeOpener = nil
|
h.handshakeOpener = nil
|
||||||
h.handshakeSealer = nil
|
h.handshakeSealer = nil
|
||||||
dropped = true
|
dropped = true
|
||||||
}
|
}
|
||||||
h.mutex.Unlock()
|
|
||||||
if dropped {
|
if dropped {
|
||||||
h.logger.Debugf("Dropping Handshake keys.")
|
h.logger.Debugf("Dropping Handshake keys.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
|
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if h.initialSealer == nil {
|
if h.initialSealer == nil {
|
||||||
return nil, ErrKeysDropped
|
return nil, ErrKeysDropped
|
||||||
}
|
}
|
||||||
|
@ -597,9 +580,6 @@ func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
|
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if h.zeroRTTSealer == nil {
|
if h.zeroRTTSealer == nil {
|
||||||
return nil, ErrKeysDropped
|
return nil, ErrKeysDropped
|
||||||
}
|
}
|
||||||
|
@ -607,9 +587,6 @@ func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
|
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if h.handshakeSealer == nil {
|
if h.handshakeSealer == nil {
|
||||||
if h.initialSealer == nil {
|
if h.initialSealer == nil {
|
||||||
return nil, ErrKeysDropped
|
return nil, ErrKeysDropped
|
||||||
|
@ -620,9 +597,6 @@ func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
|
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if !h.has1RTTSealer {
|
if !h.has1RTTSealer {
|
||||||
return nil, ErrKeysNotYetAvailable
|
return nil, ErrKeysNotYetAvailable
|
||||||
}
|
}
|
||||||
|
@ -630,9 +604,6 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
|
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if h.initialOpener == nil {
|
if h.initialOpener == nil {
|
||||||
return nil, ErrKeysDropped
|
return nil, ErrKeysDropped
|
||||||
}
|
}
|
||||||
|
@ -640,9 +611,6 @@ func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
|
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if h.zeroRTTOpener == nil {
|
if h.zeroRTTOpener == nil {
|
||||||
if h.initialOpener != nil {
|
if h.initialOpener != nil {
|
||||||
return nil, ErrKeysNotYetAvailable
|
return nil, ErrKeysNotYetAvailable
|
||||||
|
@ -654,9 +622,6 @@ func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
|
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if h.handshakeOpener == nil {
|
if h.handshakeOpener == nil {
|
||||||
if h.initialOpener != nil {
|
if h.initialOpener != nil {
|
||||||
return nil, ErrKeysNotYetAvailable
|
return nil, ErrKeysNotYetAvailable
|
||||||
|
@ -668,9 +633,6 @@ func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
|
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
|
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
|
||||||
h.zeroRTTOpener = nil
|
h.zeroRTTOpener = nil
|
||||||
h.logger.Debugf("Dropping 0-RTT keys.")
|
h.logger.Debugf("Dropping 0-RTT keys.")
|
||||||
|
|
|
@ -17,14 +17,14 @@ type headerProtector interface {
|
||||||
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string {
|
func hkdfHeaderProtectionLabel(v protocol.Version) string {
|
||||||
if v == protocol.Version2 {
|
if v == protocol.Version2 {
|
||||||
return "quicv2 hp"
|
return "quicv2 hp"
|
||||||
}
|
}
|
||||||
return "quic hp"
|
return "quic hp"
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
|
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.Version) headerProtector {
|
||||||
hkdfLabel := hkdfHeaderProtectionLabel(v)
|
hkdfLabel := hkdfHeaderProtectionLabel(v)
|
||||||
switch suite.ID {
|
switch suite.ID {
|
||||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
|
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
|
||||||
|
@ -37,7 +37,7 @@ func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader b
|
||||||
}
|
}
|
||||||
|
|
||||||
type aesHeaderProtector struct {
|
type aesHeaderProtector struct {
|
||||||
mask []byte
|
mask [16]byte // AES always has a 16 byte block size
|
||||||
block cipher.Block
|
block cipher.Block
|
||||||
isLongHeader bool
|
isLongHeader bool
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,6 @@ func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeade
|
||||||
}
|
}
|
||||||
return &aesHeaderProtector{
|
return &aesHeaderProtector{
|
||||||
block: block,
|
block: block,
|
||||||
mask: make([]byte, block.BlockSize()),
|
|
||||||
isLongHeader: isLongHeader,
|
isLongHeader: isLongHeader,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -69,7 +68,7 @@ func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []by
|
||||||
if len(sample) != len(p.mask) {
|
if len(sample) != len(p.mask) {
|
||||||
panic("invalid sample size")
|
panic("invalid sample size")
|
||||||
}
|
}
|
||||||
p.block.Encrypt(p.mask, sample)
|
p.block.Encrypt(p.mask[:], sample)
|
||||||
if p.isLongHeader {
|
if p.isLongHeader {
|
||||||
*firstByte ^= p.mask[0] & 0xf
|
*firstByte ^= p.mask[0] & 0xf
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"golang.org/x/crypto/hkdf"
|
"golang.org/x/crypto/hkdf"
|
||||||
)
|
)
|
||||||
|
|
||||||
// hkdfExpandLabel HKDF expands a label.
|
// hkdfExpandLabel HKDF expands a label as defined in RFC 8446, section 7.1.
|
||||||
// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
|
// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
|
||||||
// hkdfExpandLabel in the standard library.
|
// hkdfExpandLabel in the standard library.
|
||||||
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
|
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
|
||||||
|
|
|
@ -21,7 +21,7 @@ const (
|
||||||
hkdfLabelIVV2 = "quicv2 iv"
|
hkdfLabelIVV2 = "quicv2 iv"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getSalt(v protocol.VersionNumber) []byte {
|
func getSalt(v protocol.Version) []byte {
|
||||||
if v == protocol.Version2 {
|
if v == protocol.Version2 {
|
||||||
return quicSaltV2
|
return quicSaltV2
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,7 @@ func getSalt(v protocol.VersionNumber) []byte {
|
||||||
var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256)
|
var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256)
|
||||||
|
|
||||||
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
|
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
|
||||||
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) {
|
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.Version) (LongHeaderSealer, LongHeaderOpener) {
|
||||||
clientSecret, serverSecret := computeSecrets(connID, v)
|
clientSecret, serverSecret := computeSecrets(connID, v)
|
||||||
var mySecret, otherSecret []byte
|
var mySecret, otherSecret []byte
|
||||||
if pers == protocol.PerspectiveClient {
|
if pers == protocol.PerspectiveClient {
|
||||||
|
@ -51,14 +51,14 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p
|
||||||
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))
|
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) {
|
func computeSecrets(connID protocol.ConnectionID, v protocol.Version) (clientSecret, serverSecret []byte) {
|
||||||
initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v))
|
initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v))
|
||||||
clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
|
clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
|
||||||
serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
|
serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) {
|
func computeInitialKeyAndIV(secret []byte, v protocol.Version) (key, iv []byte) {
|
||||||
keyLabel := hkdfLabelKeyV1
|
keyLabel := hkdfLabelKeyV1
|
||||||
ivLabel := hkdfLabelIVV1
|
ivLabel := hkdfLabelIVV1
|
||||||
if v == protocol.Version2 {
|
if v == protocol.Version2 {
|
||||||
|
|
|
@ -40,7 +40,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetRetryIntegrityTag calculates the integrity tag on a Retry packet
|
// GetRetryIntegrityTag calculates the integrity tag on a Retry packet
|
||||||
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte {
|
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.Version) *[16]byte {
|
||||||
retryMutex.Lock()
|
retryMutex.Lock()
|
||||||
defer retryMutex.Unlock()
|
defer retryMutex.Unlock()
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,7 @@ type updatableAEAD struct {
|
||||||
|
|
||||||
tracer *logging.ConnectionTracer
|
tracer *logging.ConnectionTracer
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
version protocol.VersionNumber
|
version protocol.Version
|
||||||
|
|
||||||
// use a single slice to avoid allocations
|
// use a single slice to avoid allocations
|
||||||
nonceBuf []byte
|
nonceBuf []byte
|
||||||
|
@ -70,7 +70,7 @@ var (
|
||||||
_ ShortHeaderSealer = &updatableAEAD{}
|
_ ShortHeaderSealer = &updatableAEAD{}
|
||||||
)
|
)
|
||||||
|
|
||||||
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD {
|
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.Version) *updatableAEAD {
|
||||||
return &updatableAEAD{
|
return &updatableAEAD{
|
||||||
firstPacketNumber: protocol.InvalidPacketNumber,
|
firstPacketNumber: protocol.InvalidPacketNumber,
|
||||||
largestAcked: protocol.InvalidPacketNumber,
|
largestAcked: protocol.InvalidPacketNumber,
|
||||||
|
@ -133,7 +133,7 @@ func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) {
|
||||||
|
|
||||||
// SetWriteKey sets the write key.
|
// SetWriteKey sets the write key.
|
||||||
// For the client, this function is called after SetReadKey.
|
// For the client, this function is called after SetReadKey.
|
||||||
// For the server, this function is called before SetWriteKey.
|
// For the server, this function is called before SetReadKey.
|
||||||
func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
|
func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
|
||||||
a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
|
a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
|
||||||
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
|
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
|
||||||
|
|
|
@ -17,9 +17,9 @@ func (p Perspective) Opposite() Perspective {
|
||||||
func (p Perspective) String() string {
|
func (p Perspective) String() string {
|
||||||
switch p {
|
switch p {
|
||||||
case PerspectiveServer:
|
case PerspectiveServer:
|
||||||
return "Server"
|
return "server"
|
||||||
case PerspectiveClient:
|
case PerspectiveClient:
|
||||||
return "Client"
|
return "client"
|
||||||
default:
|
default:
|
||||||
return "invalid perspective"
|
return "invalid perspective"
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,14 +1,17 @@
|
||||||
package protocol
|
package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/exp/rand"
|
||||||
)
|
)
|
||||||
|
|
||||||
// VersionNumber is a version number as int
|
// Version is a version number as int
|
||||||
type VersionNumber uint32
|
type Version uint32
|
||||||
|
|
||||||
// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
|
// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
|
||||||
const (
|
const (
|
||||||
|
@ -18,22 +21,22 @@ const (
|
||||||
|
|
||||||
// The version numbers, making grepping easier
|
// The version numbers, making grepping easier
|
||||||
const (
|
const (
|
||||||
VersionUnknown VersionNumber = math.MaxUint32
|
VersionUnknown Version = math.MaxUint32
|
||||||
versionDraft29 VersionNumber = 0xff00001d // draft-29 used to be a widely deployed version
|
versionDraft29 Version = 0xff00001d // draft-29 used to be a widely deployed version
|
||||||
Version1 VersionNumber = 0x1
|
Version1 Version = 0x1
|
||||||
Version2 VersionNumber = 0x6b3343cf
|
Version2 Version = 0x6b3343cf
|
||||||
)
|
)
|
||||||
|
|
||||||
// SupportedVersions lists the versions that the server supports
|
// SupportedVersions lists the versions that the server supports
|
||||||
// must be in sorted descending order
|
// must be in sorted descending order
|
||||||
var SupportedVersions = []VersionNumber{Version1, Version2}
|
var SupportedVersions = []Version{Version1, Version2}
|
||||||
|
|
||||||
// IsValidVersion says if the version is known to quic-go
|
// IsValidVersion says if the version is known to quic-go
|
||||||
func IsValidVersion(v VersionNumber) bool {
|
func IsValidVersion(v Version) bool {
|
||||||
return v == Version1 || IsSupportedVersion(SupportedVersions, v)
|
return v == Version1 || IsSupportedVersion(SupportedVersions, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vn VersionNumber) String() string {
|
func (vn Version) String() string {
|
||||||
//nolint:exhaustive
|
//nolint:exhaustive
|
||||||
switch vn {
|
switch vn {
|
||||||
case VersionUnknown:
|
case VersionUnknown:
|
||||||
|
@ -52,16 +55,16 @@ func (vn VersionNumber) String() string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vn VersionNumber) isGQUIC() bool {
|
func (vn Version) isGQUIC() bool {
|
||||||
return vn > gquicVersion0 && vn <= maxGquicVersion
|
return vn > gquicVersion0 && vn <= maxGquicVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vn VersionNumber) toGQUICVersion() int {
|
func (vn Version) toGQUICVersion() int {
|
||||||
return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
|
return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSupportedVersion returns true if the server supports this version
|
// IsSupportedVersion returns true if the server supports this version
|
||||||
func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
|
func IsSupportedVersion(supported []Version, v Version) bool {
|
||||||
for _, t := range supported {
|
for _, t := range supported {
|
||||||
if t == v {
|
if t == v {
|
||||||
return true
|
return true
|
||||||
|
@ -74,7 +77,7 @@ func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
|
||||||
// ours is a slice of versions that we support, sorted by our preference (descending)
|
// ours is a slice of versions that we support, sorted by our preference (descending)
|
||||||
// theirs is a slice of versions offered by the peer. The order does not matter.
|
// theirs is a slice of versions offered by the peer. The order does not matter.
|
||||||
// The bool returned indicates if a matching version was found.
|
// The bool returned indicates if a matching version was found.
|
||||||
func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) {
|
func ChooseSupportedVersion(ours, theirs []Version) (Version, bool) {
|
||||||
for _, ourVer := range ours {
|
for _, ourVer := range ours {
|
||||||
for _, theirVer := range theirs {
|
for _, theirVer := range theirs {
|
||||||
if ourVer == theirVer {
|
if ourVer == theirVer {
|
||||||
|
@ -85,19 +88,25 @@ func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool)
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a)
|
var (
|
||||||
func generateReservedVersion() VersionNumber {
|
versionNegotiationMx sync.Mutex
|
||||||
b := make([]byte, 4)
|
versionNegotiationRand = rand.New(rand.NewSource(uint64(time.Now().UnixNano())))
|
||||||
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
|
)
|
||||||
return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa)
|
|
||||||
|
// generateReservedVersion generates a reserved version (v & 0x0f0f0f0f == 0x0a0a0a0a)
|
||||||
|
func generateReservedVersion() Version {
|
||||||
|
var b [4]byte
|
||||||
|
_, _ = versionNegotiationRand.Read(b[:]) // ignore the error here. Failure to read random data doesn't break anything
|
||||||
|
return Version((binary.BigEndian.Uint32(b[:]) | 0x0a0a0a0a) & 0xfafafafa)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position
|
// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position.
|
||||||
func GetGreasedVersions(supported []VersionNumber) []VersionNumber {
|
// It doesn't modify the supported slice.
|
||||||
b := make([]byte, 1)
|
func GetGreasedVersions(supported []Version) []Version {
|
||||||
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
|
versionNegotiationMx.Lock()
|
||||||
randPos := int(b[0]) % (len(supported) + 1)
|
defer versionNegotiationMx.Unlock()
|
||||||
greased := make([]VersionNumber, len(supported)+1)
|
randPos := rand.Intn(len(supported) + 1)
|
||||||
|
greased := make([]Version, len(supported)+1)
|
||||||
copy(greased, supported[:randPos])
|
copy(greased, supported[:randPos])
|
||||||
greased[randPos] = generateReservedVersion()
|
greased[randPos] = generateReservedVersion()
|
||||||
copy(greased[randPos+1:], supported[randPos:])
|
copy(greased[randPos+1:], supported[randPos:])
|
||||||
|
|
|
@ -101,8 +101,8 @@ func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.Err
|
||||||
|
|
||||||
// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version.
|
// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version.
|
||||||
type VersionNegotiationError struct {
|
type VersionNegotiationError struct {
|
||||||
Ours []protocol.VersionNumber
|
Ours []protocol.Version
|
||||||
Theirs []protocol.VersionNumber
|
Theirs []protocol.Version
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *VersionNegotiationError) Error() string {
|
func (e *VersionNegotiationError) Error() string {
|
||||||
|
|
|
@ -1,23 +1,11 @@
|
||||||
package qtls
|
package qtls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto"
|
|
||||||
"crypto/cipher"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
type cipherSuiteTLS13 struct {
|
|
||||||
ID uint16
|
|
||||||
KeyLen int
|
|
||||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
|
||||||
Hash crypto.Hash
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID
|
|
||||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
|
|
||||||
|
|
||||||
//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13
|
//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13
|
||||||
var cipherSuitesTLS13 []unsafe.Pointer
|
var cipherSuitesTLS13 []unsafe.Pointer
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
//go:build go1.21
|
|
||||||
|
|
||||||
package qtls
|
package qtls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type clientSessionCache struct {
|
type clientSessionCache struct {
|
||||||
|
mx sync.Mutex
|
||||||
getData func(earlyData bool) []byte
|
getData func(earlyData bool) []byte
|
||||||
setData func(data []byte, earlyData bool) (allowEarlyData bool)
|
setData func(data []byte, earlyData bool) (allowEarlyData bool)
|
||||||
wrapped tls.ClientSessionCache
|
wrapped tls.ClientSessionCache
|
||||||
|
@ -14,7 +14,10 @@ type clientSessionCache struct {
|
||||||
|
|
||||||
var _ tls.ClientSessionCache = &clientSessionCache{}
|
var _ tls.ClientSessionCache = &clientSessionCache{}
|
||||||
|
|
||||||
func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
|
func (c *clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
|
||||||
|
c.mx.Lock()
|
||||||
|
defer c.mx.Unlock()
|
||||||
|
|
||||||
if cs == nil {
|
if cs == nil {
|
||||||
c.wrapped.Put(key, nil)
|
c.wrapped.Put(key, nil)
|
||||||
return
|
return
|
||||||
|
@ -34,7 +37,10 @@ func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
|
||||||
c.wrapped.Put(key, newCS)
|
c.wrapped.Put(key, newCS)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
|
func (c *clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
|
||||||
|
c.mx.Lock()
|
||||||
|
defer c.mx.Unlock()
|
||||||
|
|
||||||
cs, ok := c.wrapped.Get(key)
|
cs, ok := c.wrapped.Get(key)
|
||||||
if !ok || cs == nil {
|
if !ok || cs == nil {
|
||||||
return cs, ok
|
return cs, ok
|
||||||
|
|
|
@ -27,18 +27,6 @@ func MinTime(a, b time.Time) time.Time {
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
// MinNonZeroTime returns the earlist time that is not time.Time{}
|
|
||||||
// If both a and b are time.Time{}, it returns time.Time{}
|
|
||||||
func MinNonZeroTime(a, b time.Time) time.Time {
|
|
||||||
if a.IsZero() {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
if b.IsZero() {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return MinTime(a, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MaxTime returns the later time
|
// MaxTime returns the later time
|
||||||
func MaxTime(a, b time.Time) time.Time {
|
func MaxTime(a, b time.Time) time.Time {
|
||||||
if a.After(b) {
|
if a.After(b) {
|
||||||
|
|
|
@ -22,7 +22,7 @@ type AckFrame struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseAckFrame reads an ACK frame
|
// parseAckFrame reads an ACK frame
|
||||||
func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) error {
|
func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.Version) error {
|
||||||
ecn := typ == ackECNFrameType
|
ecn := typ == ackECNFrameType
|
||||||
|
|
||||||
la, err := quicvarint.Read(r)
|
la, err := quicvarint.Read(r)
|
||||||
|
@ -110,7 +110,7 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append appends an ACK frame.
|
// Append appends an ACK frame.
|
||||||
func (f *AckFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *AckFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
|
hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
|
||||||
if hasECN {
|
if hasECN {
|
||||||
b = append(b, ackECNFrameType)
|
b = append(b, ackECNFrameType)
|
||||||
|
@ -143,7 +143,7 @@ func (f *AckFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *AckFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
func (f *AckFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||||
largestAcked := f.AckRanges[0].Largest
|
largestAcked := f.AckRanges[0].Largest
|
||||||
numRanges := f.numEncodableAckRanges()
|
numRanges := f.numEncodableAckRanges()
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ type ConnectionCloseFrame struct {
|
||||||
ReasonPhrase string
|
ReasonPhrase string
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*ConnectionCloseFrame, error) {
|
func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, error) {
|
||||||
f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType}
|
f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType}
|
||||||
ec, err := quicvarint.Read(r)
|
ec, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -53,7 +53,7 @@ func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNu
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount {
|
func (f *ConnectionCloseFrame) Length(protocol.Version) protocol.ByteCount {
|
||||||
length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
|
length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
|
||||||
if !f.IsApplicationError {
|
if !f.IsApplicationError {
|
||||||
length += quicvarint.Len(f.FrameType) // for the frame type
|
length += quicvarint.Len(f.FrameType) // for the frame type
|
||||||
|
@ -61,7 +61,7 @@ func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount
|
||||||
return length
|
return length
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
if f.IsApplicationError {
|
if f.IsApplicationError {
|
||||||
b = append(b, applicationCloseFrameType)
|
b = append(b, applicationCloseFrameType)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -14,7 +14,7 @@ type CryptoFrame struct {
|
||||||
Data []byte
|
Data []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) {
|
func parseCryptoFrame(r *bytes.Reader, _ protocol.Version) (*CryptoFrame, error) {
|
||||||
frame := &CryptoFrame{}
|
frame := &CryptoFrame{}
|
||||||
offset, err := quicvarint.Read(r)
|
offset, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -38,7 +38,7 @@ func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame,
|
||||||
return frame, nil
|
return frame, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
b = append(b, cryptoFrameType)
|
b = append(b, cryptoFrameType)
|
||||||
b = quicvarint.Append(b, uint64(f.Offset))
|
b = quicvarint.Append(b, uint64(f.Offset))
|
||||||
b = quicvarint.Append(b, uint64(len(f.Data)))
|
b = quicvarint.Append(b, uint64(len(f.Data)))
|
||||||
|
@ -47,7 +47,7 @@ func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
func (f *CryptoFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||||
return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data))
|
return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount
|
||||||
// The frame might not be split if:
|
// The frame might not be split if:
|
||||||
// * the size is large enough to fit the whole frame
|
// * the size is large enough to fit the whole frame
|
||||||
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
|
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
|
||||||
func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*CryptoFrame, bool /* was splitting required */) {
|
func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*CryptoFrame, bool /* was splitting required */) {
|
||||||
if f.Length(version) <= maxSize {
|
if f.Length(version) <= maxSize {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ type DataBlockedFrame struct {
|
||||||
MaximumData protocol.ByteCount
|
MaximumData protocol.ByteCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) {
|
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*DataBlockedFrame, error) {
|
||||||
offset, err := quicvarint.Read(r)
|
offset, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -20,12 +20,12 @@ func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBloc
|
||||||
return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil
|
return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DataBlockedFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) {
|
func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
|
||||||
b = append(b, dataBlockedFrameType)
|
b = append(b, dataBlockedFrameType)
|
||||||
return quicvarint.Append(b, uint64(f.MaximumData)), nil
|
return quicvarint.Append(b, uint64(f.MaximumData)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
func (f *DataBlockedFrame) Length(version protocol.Version) protocol.ByteCount {
|
||||||
return 1 + quicvarint.Len(uint64(f.MaximumData))
|
return 1 + quicvarint.Len(uint64(f.MaximumData))
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,7 @@ type DatagramFrame struct {
|
||||||
Data []byte
|
Data []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*DatagramFrame, error) {
|
func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*DatagramFrame, error) {
|
||||||
f := &DatagramFrame{}
|
f := &DatagramFrame{}
|
||||||
f.DataLenPresent = typ&0x1 > 0
|
f.DataLenPresent = typ&0x1 > 0
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (
|
||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *DatagramFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
typ := uint8(0x30)
|
typ := uint8(0x30)
|
||||||
if f.DataLenPresent {
|
if f.DataLenPresent {
|
||||||
typ ^= 0b1
|
typ ^= 0b1
|
||||||
|
@ -59,7 +59,7 @@ func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, erro
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaxDataLen returns the maximum data length
|
// MaxDataLen returns the maximum data length
|
||||||
func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount {
|
func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount {
|
||||||
headerLen := protocol.ByteCount(1)
|
headerLen := protocol.ByteCount(1)
|
||||||
if f.DataLenPresent {
|
if f.DataLenPresent {
|
||||||
// pretend that the data size will be 1 bytes
|
// pretend that the data size will be 1 bytes
|
||||||
|
@ -77,7 +77,7 @@ func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *DatagramFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
func (f *DatagramFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||||
length := 1 + protocol.ByteCount(len(f.Data))
|
length := 1 + protocol.ByteCount(len(f.Data))
|
||||||
if f.DataLenPresent {
|
if f.DataLenPresent {
|
||||||
length += quicvarint.Len(uint64(len(f.Data)))
|
length += quicvarint.Len(uint64(len(f.Data)))
|
||||||
|
|
|
@ -32,7 +32,7 @@ type ExtendedHeader struct {
|
||||||
parsedLen protocol.ByteCount
|
parsedLen protocol.ByteCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) {
|
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.Version) (bool /* reserved bits valid */, error) {
|
||||||
startLen := b.Len()
|
startLen := b.Len()
|
||||||
// read the (now unencrypted) first byte
|
// read the (now unencrypted) first byte
|
||||||
var err error
|
var err error
|
||||||
|
@ -51,7 +51,7 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool
|
||||||
return reservedBitsValid, err
|
return reservedBitsValid, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
|
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.Version) (bool /* reserved bits valid */, error) {
|
||||||
if err := h.readPacketNumber(b); err != nil {
|
if err := h.readPacketNumber(b); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -95,7 +95,7 @@ func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append appends the Header.
|
// Append appends the Header.
|
||||||
func (h *ExtendedHeader) Append(b []byte, v protocol.VersionNumber) ([]byte, error) {
|
func (h *ExtendedHeader) Append(b []byte, v protocol.Version) ([]byte, error) {
|
||||||
if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
|
if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
|
||||||
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
|
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
|
||||||
}
|
}
|
||||||
|
@ -162,7 +162,7 @@ func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLength determines the length of the Header.
|
// GetLength determines the length of the Header.
|
||||||
func (h *ExtendedHeader) GetLength(_ protocol.VersionNumber) protocol.ByteCount {
|
func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount {
|
||||||
length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
|
length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
|
||||||
if h.Type == protocol.PacketTypeInitial {
|
if h.Type == protocol.PacketTypeInitial {
|
||||||
length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
|
length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
|
||||||
|
|
|
@ -36,7 +36,8 @@ const (
|
||||||
handshakeDoneFrameType = 0x1e
|
handshakeDoneFrameType = 0x1e
|
||||||
)
|
)
|
||||||
|
|
||||||
type frameParser struct {
|
// The FrameParser parses QUIC frames, one by one.
|
||||||
|
type FrameParser struct {
|
||||||
r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them
|
r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them
|
||||||
|
|
||||||
ackDelayExponent uint8
|
ackDelayExponent uint8
|
||||||
|
@ -47,11 +48,9 @@ type frameParser struct {
|
||||||
ackFrame *AckFrame
|
ackFrame *AckFrame
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ FrameParser = &frameParser{}
|
|
||||||
|
|
||||||
// NewFrameParser creates a new frame parser.
|
// NewFrameParser creates a new frame parser.
|
||||||
func NewFrameParser(supportsDatagrams bool) *frameParser {
|
func NewFrameParser(supportsDatagrams bool) *FrameParser {
|
||||||
return &frameParser{
|
return &FrameParser{
|
||||||
r: *bytes.NewReader(nil),
|
r: *bytes.NewReader(nil),
|
||||||
supportsDatagrams: supportsDatagrams,
|
supportsDatagrams: supportsDatagrams,
|
||||||
ackFrame: &AckFrame{},
|
ackFrame: &AckFrame{},
|
||||||
|
@ -60,7 +59,7 @@ func NewFrameParser(supportsDatagrams bool) *frameParser {
|
||||||
|
|
||||||
// ParseNext parses the next frame.
|
// ParseNext parses the next frame.
|
||||||
// It skips PADDING frames.
|
// It skips PADDING frames.
|
||||||
func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (int, Frame, error) {
|
func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) {
|
||||||
startLen := len(data)
|
startLen := len(data)
|
||||||
p.r.Reset(data)
|
p.r.Reset(data)
|
||||||
frame, err := p.parseNext(&p.r, encLevel, v)
|
frame, err := p.parseNext(&p.r, encLevel, v)
|
||||||
|
@ -69,7 +68,7 @@ func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel,
|
||||||
return n, frame, err
|
return n, frame, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) {
|
func (p *FrameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
|
||||||
for r.Len() != 0 {
|
for r.Len() != 0 {
|
||||||
typ, err := quicvarint.Read(r)
|
typ, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -95,7 +94,7 @@ func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLev
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) {
|
func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
|
||||||
var frame Frame
|
var frame Frame
|
||||||
var err error
|
var err error
|
||||||
if typ&0xf8 == 0x8 {
|
if typ&0xf8 == 0x8 {
|
||||||
|
@ -163,7 +162,7 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
|
||||||
return frame, nil
|
return frame, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
|
func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
|
||||||
switch encLevel {
|
switch encLevel {
|
||||||
case protocol.EncryptionInitial, protocol.EncryptionHandshake:
|
case protocol.EncryptionInitial, protocol.EncryptionHandshake:
|
||||||
switch f.(type) {
|
switch f.(type) {
|
||||||
|
@ -186,6 +185,8 @@ func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionL
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *frameParser) SetAckDelayExponent(exp uint8) {
|
// SetAckDelayExponent sets the acknowledgment delay exponent (sent in the transport parameters).
|
||||||
|
// This value is used to scale the ACK Delay field in the ACK frame.
|
||||||
|
func (p *FrameParser) SetAckDelayExponent(exp uint8) {
|
||||||
p.ackDelayExponent = exp
|
p.ackDelayExponent = exp
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,11 +7,11 @@ import (
|
||||||
// A HandshakeDoneFrame is a HANDSHAKE_DONE frame
|
// A HandshakeDoneFrame is a HANDSHAKE_DONE frame
|
||||||
type HandshakeDoneFrame struct{}
|
type HandshakeDoneFrame struct{}
|
||||||
|
|
||||||
func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
return append(b, handshakeDoneFrameType), nil
|
return append(b, handshakeDoneFrameType), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
func (f *HandshakeDoneFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
|
@ -85,11 +85,11 @@ func IsLongHeaderPacket(firstByte byte) bool {
|
||||||
|
|
||||||
// ParseVersion parses the QUIC version.
|
// ParseVersion parses the QUIC version.
|
||||||
// It should only be called for Long Header packets (Short Header packets don't contain a version number).
|
// It should only be called for Long Header packets (Short Header packets don't contain a version number).
|
||||||
func ParseVersion(data []byte) (protocol.VersionNumber, error) {
|
func ParseVersion(data []byte) (protocol.Version, error) {
|
||||||
if len(data) < 5 {
|
if len(data) < 5 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
return protocol.VersionNumber(binary.BigEndian.Uint32(data[1:5])), nil
|
return protocol.Version(binary.BigEndian.Uint32(data[1:5])), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsVersionNegotiationPacket says if this is a version negotiation packet
|
// IsVersionNegotiationPacket says if this is a version negotiation packet
|
||||||
|
@ -109,7 +109,7 @@ func Is0RTTPacket(b []byte) bool {
|
||||||
if !IsLongHeaderPacket(b[0]) {
|
if !IsLongHeaderPacket(b[0]) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5]))
|
version := protocol.Version(binary.BigEndian.Uint32(b[1:5]))
|
||||||
//nolint:exhaustive // We only need to test QUIC versions that we support.
|
//nolint:exhaustive // We only need to test QUIC versions that we support.
|
||||||
switch version {
|
switch version {
|
||||||
case protocol.Version1:
|
case protocol.Version1:
|
||||||
|
@ -128,7 +128,7 @@ type Header struct {
|
||||||
typeByte byte
|
typeByte byte
|
||||||
Type protocol.PacketType
|
Type protocol.PacketType
|
||||||
|
|
||||||
Version protocol.VersionNumber
|
Version protocol.Version
|
||||||
SrcConnectionID protocol.ConnectionID
|
SrcConnectionID protocol.ConnectionID
|
||||||
DestConnectionID protocol.ConnectionID
|
DestConnectionID protocol.ConnectionID
|
||||||
|
|
||||||
|
@ -184,7 +184,7 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
h.Version = protocol.VersionNumber(v)
|
h.Version = protocol.Version(v)
|
||||||
if h.Version != 0 && h.typeByte&0x40 == 0 {
|
if h.Version != 0 && h.typeByte&0x40 == 0 {
|
||||||
return errors.New("not a QUIC packet")
|
return errors.New("not a QUIC packet")
|
||||||
}
|
}
|
||||||
|
@ -278,7 +278,7 @@ func (h *Header) ParsedLen() protocol.ByteCount {
|
||||||
|
|
||||||
// ParseExtended parses the version dependent part of the header.
|
// ParseExtended parses the version dependent part of the header.
|
||||||
// The Reader has to be set such that it points to the first byte of the header.
|
// The Reader has to be set such that it points to the first byte of the header.
|
||||||
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
|
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) {
|
||||||
extHdr := h.toExtendedHeader()
|
extHdr := h.toExtendedHeader()
|
||||||
reservedBitsValid, err := extHdr.parse(b, ver)
|
reservedBitsValid, err := extHdr.parse(b, ver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -6,12 +6,6 @@ import (
|
||||||
|
|
||||||
// A Frame in QUIC
|
// A Frame in QUIC
|
||||||
type Frame interface {
|
type Frame interface {
|
||||||
Append(b []byte, version protocol.VersionNumber) ([]byte, error)
|
Append(b []byte, version protocol.Version) ([]byte, error)
|
||||||
Length(version protocol.VersionNumber) protocol.ByteCount
|
Length(version protocol.Version) protocol.ByteCount
|
||||||
}
|
|
||||||
|
|
||||||
// A FrameParser parses QUIC frames, one by one.
|
|
||||||
type FrameParser interface {
|
|
||||||
ParseNext([]byte, protocol.EncryptionLevel, protocol.VersionNumber) (int, Frame, error)
|
|
||||||
SetAckDelayExponent(uint8)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,7 +63,9 @@ func LogFrame(logger utils.Logger, frame Frame, sent bool) {
|
||||||
logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit)
|
logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit)
|
||||||
}
|
}
|
||||||
case *NewConnectionIDFrame:
|
case *NewConnectionIDFrame:
|
||||||
logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.ConnectionID, f.StatelessResetToken)
|
logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, RetirePriorTo: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.RetirePriorTo, f.ConnectionID, f.StatelessResetToken)
|
||||||
|
case *RetireConnectionIDFrame:
|
||||||
|
logger.Debugf("\t%s &wire.RetireConnectionIDFrame{SequenceNumber: %d}", dir, f.SequenceNumber)
|
||||||
case *NewTokenFrame:
|
case *NewTokenFrame:
|
||||||
logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token)
|
logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token)
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -13,7 +13,7 @@ type MaxDataFrame struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseMaxDataFrame parses a MAX_DATA frame
|
// parseMaxDataFrame parses a MAX_DATA frame
|
||||||
func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame, error) {
|
func parseMaxDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxDataFrame, error) {
|
||||||
frame := &MaxDataFrame{}
|
frame := &MaxDataFrame{}
|
||||||
byteOffset, err := quicvarint.Read(r)
|
byteOffset, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -23,13 +23,13 @@ func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame
|
||||||
return frame, nil
|
return frame, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *MaxDataFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
b = append(b, maxDataFrameType)
|
b = append(b, maxDataFrameType)
|
||||||
b = quicvarint.Append(b, uint64(f.MaximumData))
|
b = quicvarint.Append(b, uint64(f.MaximumData))
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *MaxDataFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
func (f *MaxDataFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||||
return 1 + quicvarint.Len(uint64(f.MaximumData))
|
return 1 + quicvarint.Len(uint64(f.MaximumData))
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ type MaxStreamDataFrame struct {
|
||||||
MaximumStreamData protocol.ByteCount
|
MaximumStreamData protocol.ByteCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamDataFrame, error) {
|
func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxStreamDataFrame, error) {
|
||||||
sid, err := quicvarint.Read(r)
|
sid, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -29,7 +29,7 @@ func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStr
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) {
|
func (f *MaxStreamDataFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
|
||||||
b = append(b, maxStreamDataFrameType)
|
b = append(b, maxStreamDataFrameType)
|
||||||
b = quicvarint.Append(b, uint64(f.StreamID))
|
b = quicvarint.Append(b, uint64(f.StreamID))
|
||||||
b = quicvarint.Append(b, uint64(f.MaximumStreamData))
|
b = quicvarint.Append(b, uint64(f.MaximumStreamData))
|
||||||
|
@ -37,6 +37,6 @@ func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *MaxStreamDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
func (f *MaxStreamDataFrame) Length(version protocol.Version) protocol.ByteCount {
|
||||||
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
|
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ type MaxStreamsFrame struct {
|
||||||
MaxStreamNum protocol.StreamNum
|
MaxStreamNum protocol.StreamNum
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*MaxStreamsFrame, error) {
|
func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*MaxStreamsFrame, error) {
|
||||||
f := &MaxStreamsFrame{}
|
f := &MaxStreamsFrame{}
|
||||||
switch typ {
|
switch typ {
|
||||||
case bidiMaxStreamsFrameType:
|
case bidiMaxStreamsFrameType:
|
||||||
|
@ -33,7 +33,7 @@ func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber)
|
||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
switch f.Type {
|
switch f.Type {
|
||||||
case protocol.StreamTypeBidi:
|
case protocol.StreamTypeBidi:
|
||||||
b = append(b, bidiMaxStreamsFrameType)
|
b = append(b, bidiMaxStreamsFrameType)
|
||||||
|
@ -45,6 +45,6 @@ func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, er
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount {
|
func (f *MaxStreamsFrame) Length(protocol.Version) protocol.ByteCount {
|
||||||
return 1 + quicvarint.Len(uint64(f.MaxStreamNum))
|
return 1 + quicvarint.Len(uint64(f.MaxStreamNum))
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ type NewConnectionIDFrame struct {
|
||||||
StatelessResetToken protocol.StatelessResetToken
|
StatelessResetToken protocol.StatelessResetToken
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewConnectionIDFrame, error) {
|
func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*NewConnectionIDFrame, error) {
|
||||||
seq, err := quicvarint.Read(r)
|
seq, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -57,7 +57,7 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC
|
||||||
return frame, nil
|
return frame, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
b = append(b, newConnectionIDFrameType)
|
b = append(b, newConnectionIDFrameType)
|
||||||
b = quicvarint.Append(b, f.SequenceNumber)
|
b = quicvarint.Append(b, f.SequenceNumber)
|
||||||
b = quicvarint.Append(b, f.RetirePriorTo)
|
b = quicvarint.Append(b, f.RetirePriorTo)
|
||||||
|
@ -72,6 +72,6 @@ func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byt
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *NewConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount {
|
func (f *NewConnectionIDFrame) Length(protocol.Version) protocol.ByteCount {
|
||||||
return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16
|
return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ type NewTokenFrame struct {
|
||||||
Token []byte
|
Token []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFrame, error) {
|
func parseNewTokenFrame(r *bytes.Reader, _ protocol.Version) (*NewTokenFrame, error) {
|
||||||
tokenLen, err := quicvarint.Read(r)
|
tokenLen, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -32,7 +32,7 @@ func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFra
|
||||||
return &NewTokenFrame{Token: token}, nil
|
return &NewTokenFrame{Token: token}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
b = append(b, newTokenFrameType)
|
b = append(b, newTokenFrameType)
|
||||||
b = quicvarint.Append(b, uint64(len(f.Token)))
|
b = quicvarint.Append(b, uint64(len(f.Token)))
|
||||||
b = append(b, f.Token...)
|
b = append(b, f.Token...)
|
||||||
|
@ -40,6 +40,6 @@ func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, erro
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *NewTokenFrame) Length(protocol.VersionNumber) protocol.ByteCount {
|
func (f *NewTokenFrame) Length(protocol.Version) protocol.ByteCount {
|
||||||
return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token))
|
return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token))
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ type PathChallengeFrame struct {
|
||||||
Data [8]byte
|
Data [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathChallengeFrame, error) {
|
func parsePathChallengeFrame(r *bytes.Reader, _ protocol.Version) (*PathChallengeFrame, error) {
|
||||||
frame := &PathChallengeFrame{}
|
frame := &PathChallengeFrame{}
|
||||||
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
|
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
|
||||||
if err == io.ErrUnexpectedEOF {
|
if err == io.ErrUnexpectedEOF {
|
||||||
|
@ -23,13 +23,13 @@ func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathCh
|
||||||
return frame, nil
|
return frame, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *PathChallengeFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
b = append(b, pathChallengeFrameType)
|
b = append(b, pathChallengeFrameType)
|
||||||
b = append(b, f.Data[:]...)
|
b = append(b, f.Data[:]...)
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *PathChallengeFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
func (f *PathChallengeFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||||
return 1 + 8
|
return 1 + 8
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ type PathResponseFrame struct {
|
||||||
Data [8]byte
|
Data [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathResponseFrame, error) {
|
func parsePathResponseFrame(r *bytes.Reader, _ protocol.Version) (*PathResponseFrame, error) {
|
||||||
frame := &PathResponseFrame{}
|
frame := &PathResponseFrame{}
|
||||||
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
|
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
|
||||||
if err == io.ErrUnexpectedEOF {
|
if err == io.ErrUnexpectedEOF {
|
||||||
|
@ -23,13 +23,13 @@ func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathRes
|
||||||
return frame, nil
|
return frame, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *PathResponseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
b = append(b, pathResponseFrameType)
|
b = append(b, pathResponseFrameType)
|
||||||
b = append(b, f.Data[:]...)
|
b = append(b, f.Data[:]...)
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *PathResponseFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
func (f *PathResponseFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||||
return 1 + 8
|
return 1 + 8
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,11 +7,11 @@ import (
|
||||||
// A PingFrame is a PING frame
|
// A PingFrame is a PING frame
|
||||||
type PingFrame struct{}
|
type PingFrame struct{}
|
||||||
|
|
||||||
func (f *PingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *PingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
return append(b, pingFrameType), nil
|
return append(b, pingFrameType), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *PingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
func (f *PingFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ type ResetStreamFrame struct {
|
||||||
FinalSize protocol.ByteCount
|
FinalSize protocol.ByteCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStreamFrame, error) {
|
func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFrame, error) {
|
||||||
var streamID protocol.StreamID
|
var streamID protocol.StreamID
|
||||||
var byteOffset protocol.ByteCount
|
var byteOffset protocol.ByteCount
|
||||||
sid, err := quicvarint.Read(r)
|
sid, err := quicvarint.Read(r)
|
||||||
|
@ -40,7 +40,7 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStr
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
b = append(b, resetStreamFrameType)
|
b = append(b, resetStreamFrameType)
|
||||||
b = quicvarint.Append(b, uint64(f.StreamID))
|
b = quicvarint.Append(b, uint64(f.StreamID))
|
||||||
b = quicvarint.Append(b, uint64(f.ErrorCode))
|
b = quicvarint.Append(b, uint64(f.ErrorCode))
|
||||||
|
@ -49,6 +49,6 @@ func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, e
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *ResetStreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
func (f *ResetStreamFrame) Length(version protocol.Version) protocol.ByteCount {
|
||||||
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize))
|
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize))
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ type RetireConnectionIDFrame struct {
|
||||||
SequenceNumber uint64
|
SequenceNumber uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*RetireConnectionIDFrame, error) {
|
func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*RetireConnectionIDFrame, error) {
|
||||||
seq, err := quicvarint.Read(r)
|
seq, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -20,13 +20,13 @@ func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*R
|
||||||
return &RetireConnectionIDFrame{SequenceNumber: seq}, nil
|
return &RetireConnectionIDFrame{SequenceNumber: seq}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
b = append(b, retireConnectionIDFrameType)
|
b = append(b, retireConnectionIDFrameType)
|
||||||
b = quicvarint.Append(b, f.SequenceNumber)
|
b = quicvarint.Append(b, f.SequenceNumber)
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *RetireConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount {
|
func (f *RetireConnectionIDFrame) Length(protocol.Version) protocol.ByteCount {
|
||||||
return 1 + quicvarint.Len(f.SequenceNumber)
|
return 1 + quicvarint.Len(f.SequenceNumber)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ type StopSendingFrame struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseStopSendingFrame parses a STOP_SENDING frame
|
// parseStopSendingFrame parses a STOP_SENDING frame
|
||||||
func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) {
|
func parseStopSendingFrame(r *bytes.Reader, _ protocol.Version) (*StopSendingFrame, error) {
|
||||||
streamID, err := quicvarint.Read(r)
|
streamID, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -32,11 +32,11 @@ func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSend
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
func (f *StopSendingFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||||
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode))
|
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *StopSendingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *StopSendingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
b = append(b, stopSendingFrameType)
|
b = append(b, stopSendingFrameType)
|
||||||
b = quicvarint.Append(b, uint64(f.StreamID))
|
b = quicvarint.Append(b, uint64(f.StreamID))
|
||||||
b = quicvarint.Append(b, uint64(f.ErrorCode))
|
b = quicvarint.Append(b, uint64(f.ErrorCode))
|
||||||
|
|
|
@ -13,7 +13,7 @@ type StreamDataBlockedFrame struct {
|
||||||
MaximumStreamData protocol.ByteCount
|
MaximumStreamData protocol.ByteCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamDataBlockedFrame, error) {
|
func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*StreamDataBlockedFrame, error) {
|
||||||
sid, err := quicvarint.Read(r)
|
sid, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -29,7 +29,7 @@ func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*St
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
b = append(b, 0x15)
|
b = append(b, 0x15)
|
||||||
b = quicvarint.Append(b, uint64(f.StreamID))
|
b = quicvarint.Append(b, uint64(f.StreamID))
|
||||||
b = quicvarint.Append(b, uint64(f.MaximumStreamData))
|
b = quicvarint.Append(b, uint64(f.MaximumStreamData))
|
||||||
|
@ -37,6 +37,6 @@ func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]b
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *StreamDataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
func (f *StreamDataBlockedFrame) Length(version protocol.Version) protocol.ByteCount {
|
||||||
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
|
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,7 @@ type StreamFrame struct {
|
||||||
fromPool bool
|
fromPool bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamFrame, error) {
|
func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamFrame, error) {
|
||||||
hasOffset := typ&0b100 > 0
|
hasOffset := typ&0b100 > 0
|
||||||
fin := typ&0b1 > 0
|
fin := typ&0b1 > 0
|
||||||
hasDataLen := typ&0b10 > 0
|
hasDataLen := typ&0b10 > 0
|
||||||
|
@ -79,7 +79,7 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*S
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write writes a STREAM frame
|
// Write writes a STREAM frame
|
||||||
func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
if len(f.Data) == 0 && !f.Fin {
|
if len(f.Data) == 0 && !f.Fin {
|
||||||
return nil, errors.New("StreamFrame: attempting to write empty frame without FIN")
|
return nil, errors.New("StreamFrame: attempting to write empty frame without FIN")
|
||||||
}
|
}
|
||||||
|
@ -108,7 +108,7 @@ func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length returns the total length of the STREAM frame
|
// Length returns the total length of the STREAM frame
|
||||||
func (f *StreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
func (f *StreamFrame) Length(version protocol.Version) protocol.ByteCount {
|
||||||
length := 1 + quicvarint.Len(uint64(f.StreamID))
|
length := 1 + quicvarint.Len(uint64(f.StreamID))
|
||||||
if f.Offset != 0 {
|
if f.Offset != 0 {
|
||||||
length += quicvarint.Len(uint64(f.Offset))
|
length += quicvarint.Len(uint64(f.Offset))
|
||||||
|
@ -126,7 +126,7 @@ func (f *StreamFrame) DataLen() protocol.ByteCount {
|
||||||
|
|
||||||
// MaxDataLen returns the maximum data length
|
// MaxDataLen returns the maximum data length
|
||||||
// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data).
|
// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data).
|
||||||
func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount {
|
func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount {
|
||||||
headerLen := 1 + quicvarint.Len(uint64(f.StreamID))
|
headerLen := 1 + quicvarint.Len(uint64(f.StreamID))
|
||||||
if f.Offset != 0 {
|
if f.Offset != 0 {
|
||||||
headerLen += quicvarint.Len(uint64(f.Offset))
|
headerLen += quicvarint.Len(uint64(f.Offset))
|
||||||
|
@ -151,7 +151,7 @@ func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Ve
|
||||||
// The frame might not be split if:
|
// The frame might not be split if:
|
||||||
// * the size is large enough to fit the whole frame
|
// * the size is large enough to fit the whole frame
|
||||||
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
|
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
|
||||||
func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*StreamFrame, bool /* was splitting required */) {
|
func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*StreamFrame, bool /* was splitting required */) {
|
||||||
if maxSize >= f.Length(version) {
|
if maxSize >= f.Length(version) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ type StreamsBlockedFrame struct {
|
||||||
StreamLimit protocol.StreamNum
|
StreamLimit protocol.StreamNum
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) {
|
func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, error) {
|
||||||
f := &StreamsBlockedFrame{}
|
f := &StreamsBlockedFrame{}
|
||||||
switch typ {
|
switch typ {
|
||||||
case bidiStreamBlockedFrameType:
|
case bidiStreamBlockedFrameType:
|
||||||
|
@ -33,7 +33,7 @@ func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNum
|
||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||||
switch f.Type {
|
switch f.Type {
|
||||||
case protocol.StreamTypeBidi:
|
case protocol.StreamTypeBidi:
|
||||||
b = append(b, bidiStreamBlockedFrameType)
|
b = append(b, bidiStreamBlockedFrameType)
|
||||||
|
@ -45,6 +45,6 @@ func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *StreamsBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
func (f *StreamsBlockedFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||||
return 1 + quicvarint.Len(uint64(f.StreamLimit))
|
return 1 + quicvarint.Len(uint64(f.StreamLimit))
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -51,10 +51,7 @@ const (
|
||||||
|
|
||||||
// PreferredAddress is the value encoding in the preferred_address transport parameter
|
// PreferredAddress is the value encoding in the preferred_address transport parameter
|
||||||
type PreferredAddress struct {
|
type PreferredAddress struct {
|
||||||
IPv4 net.IP
|
IPv4, IPv6 netip.AddrPort
|
||||||
IPv4Port uint16
|
|
||||||
IPv6 net.IP
|
|
||||||
IPv6Port uint16
|
|
||||||
ConnectionID protocol.ConnectionID
|
ConnectionID protocol.ConnectionID
|
||||||
StatelessResetToken protocol.StatelessResetToken
|
StatelessResetToken protocol.StatelessResetToken
|
||||||
}
|
}
|
||||||
|
@ -218,26 +215,24 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
|
||||||
func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error {
|
func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error {
|
||||||
remainingLen := r.Len()
|
remainingLen := r.Len()
|
||||||
pa := &PreferredAddress{}
|
pa := &PreferredAddress{}
|
||||||
ipv4 := make([]byte, 4)
|
var ipv4 [4]byte
|
||||||
if _, err := io.ReadFull(r, ipv4); err != nil {
|
if _, err := io.ReadFull(r, ipv4[:]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
pa.IPv4 = net.IP(ipv4)
|
|
||||||
port, err := utils.BigEndian.ReadUint16(r)
|
port, err := utils.BigEndian.ReadUint16(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
pa.IPv4Port = port
|
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port)
|
||||||
ipv6 := make([]byte, 16)
|
var ipv6 [16]byte
|
||||||
if _, err := io.ReadFull(r, ipv6); err != nil {
|
if _, err := io.ReadFull(r, ipv6[:]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
pa.IPv6 = net.IP(ipv6)
|
|
||||||
port, err = utils.BigEndian.ReadUint16(r)
|
port, err = utils.BigEndian.ReadUint16(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
pa.IPv6Port = port
|
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port)
|
||||||
connIDLen, err := r.ReadByte()
|
connIDLen, err := r.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -384,13 +379,12 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
|
||||||
if p.PreferredAddress != nil {
|
if p.PreferredAddress != nil {
|
||||||
b = quicvarint.Append(b, uint64(preferredAddressParameterID))
|
b = quicvarint.Append(b, uint64(preferredAddressParameterID))
|
||||||
b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16)
|
b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16)
|
||||||
ipv4 := p.PreferredAddress.IPv4
|
ip4 := p.PreferredAddress.IPv4.Addr().As4()
|
||||||
b = append(b, ipv4[len(ipv4)-4:]...)
|
b = append(b, ip4[:]...)
|
||||||
b = append(b, []byte{0, 0}...)
|
b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv4.Port())
|
||||||
binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv4Port)
|
ip6 := p.PreferredAddress.IPv6.Addr().As16()
|
||||||
b = append(b, p.PreferredAddress.IPv6...)
|
b = append(b, ip6[:]...)
|
||||||
b = append(b, []byte{0, 0}...)
|
b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv6.Port())
|
||||||
binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv6Port)
|
|
||||||
b = append(b, uint8(p.PreferredAddress.ConnectionID.Len()))
|
b = append(b, uint8(p.PreferredAddress.ConnectionID.Len()))
|
||||||
b = append(b, p.PreferredAddress.ConnectionID.Bytes()...)
|
b = append(b, p.PreferredAddress.ConnectionID.Bytes()...)
|
||||||
b = append(b, p.PreferredAddress.StatelessResetToken[:]...)
|
b = append(b, p.PreferredAddress.StatelessResetToken[:]...)
|
||||||
|
|
|
@ -1,17 +1,15 @@
|
||||||
package wire
|
package wire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParseVersionNegotiationPacket parses a Version Negotiation packet.
|
// ParseVersionNegotiationPacket parses a Version Negotiation packet.
|
||||||
func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber, _ error) {
|
func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.Version, _ error) {
|
||||||
n, dest, src, err := ParseArbitraryLenConnectionIDs(b)
|
n, dest, src, err := ParseArbitraryLenConnectionIDs(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
|
@ -25,32 +23,31 @@ func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenCon
|
||||||
//nolint:stylecheck
|
//nolint:stylecheck
|
||||||
return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length")
|
return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length")
|
||||||
}
|
}
|
||||||
versions := make([]protocol.VersionNumber, len(b)/4)
|
versions := make([]protocol.Version, len(b)/4)
|
||||||
for i := 0; len(b) > 0; i++ {
|
for i := 0; len(b) > 0; i++ {
|
||||||
versions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(b[:4]))
|
versions[i] = protocol.Version(binary.BigEndian.Uint32(b[:4]))
|
||||||
b = b[4:]
|
b = b[4:]
|
||||||
}
|
}
|
||||||
return dest, src, versions, nil
|
return dest, src, versions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ComposeVersionNegotiation composes a Version Negotiation
|
// ComposeVersionNegotiation composes a Version Negotiation
|
||||||
func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.VersionNumber) []byte {
|
func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.Version) []byte {
|
||||||
greasedVersions := protocol.GetGreasedVersions(versions)
|
greasedVersions := protocol.GetGreasedVersions(versions)
|
||||||
expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4
|
expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4
|
||||||
buf := bytes.NewBuffer(make([]byte, 0, expectedLen))
|
buf := make([]byte, 1+4 /* type byte and version field */, expectedLen)
|
||||||
r := make([]byte, 1)
|
_, _ = rand.Read(buf[:1]) // ignore the error here. It is not critical to have perfect random here.
|
||||||
_, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here.
|
|
||||||
// Setting the "QUIC bit" (0x40) is not required by the RFC,
|
// Setting the "QUIC bit" (0x40) is not required by the RFC,
|
||||||
// but it allows clients to demultiplex QUIC with a long list of other protocols.
|
// but it allows clients to demultiplex QUIC with a long list of other protocols.
|
||||||
// See RFC 9443 and https://mailarchive.ietf.org/arch/msg/quic/oR4kxGKY6mjtPC1CZegY1ED4beg/ for details.
|
// See RFC 9443 and https://mailarchive.ietf.org/arch/msg/quic/oR4kxGKY6mjtPC1CZegY1ED4beg/ for details.
|
||||||
buf.WriteByte(r[0] | 0xc0)
|
buf[0] |= 0xc0
|
||||||
utils.BigEndian.WriteUint32(buf, 0) // version 0
|
// The next 4 bytes are left at 0 (version number).
|
||||||
buf.WriteByte(uint8(destConnID.Len()))
|
buf = append(buf, uint8(destConnID.Len()))
|
||||||
buf.Write(destConnID.Bytes())
|
buf = append(buf, destConnID.Bytes()...)
|
||||||
buf.WriteByte(uint8(srcConnID.Len()))
|
buf = append(buf, uint8(srcConnID.Len()))
|
||||||
buf.Write(srcConnID.Bytes())
|
buf = append(buf, srcConnID.Bytes()...)
|
||||||
for _, v := range greasedVersions {
|
for _, v := range greasedVersions {
|
||||||
utils.BigEndian.WriteUint32(buf, uint32(v))
|
buf = binary.BigEndian.AppendUint32(buf, uint32(v))
|
||||||
}
|
}
|
||||||
return buf.Bytes()
|
return buf
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,9 +27,9 @@ type ConnectionTracer struct {
|
||||||
UpdatedCongestionState func(CongestionState)
|
UpdatedCongestionState func(CongestionState)
|
||||||
UpdatedPTOCount func(value uint32)
|
UpdatedPTOCount func(value uint32)
|
||||||
UpdatedKeyFromTLS func(EncryptionLevel, Perspective)
|
UpdatedKeyFromTLS func(EncryptionLevel, Perspective)
|
||||||
UpdatedKey func(generation KeyPhase, remote bool)
|
UpdatedKey func(keyPhase KeyPhase, remote bool)
|
||||||
DroppedEncryptionLevel func(EncryptionLevel)
|
DroppedEncryptionLevel func(EncryptionLevel)
|
||||||
DroppedKey func(generation KeyPhase)
|
DroppedKey func(keyPhase KeyPhase)
|
||||||
SetLossTimer func(TimerType, EncryptionLevel, time.Time)
|
SetLossTimer func(TimerType, EncryptionLevel, time.Time)
|
||||||
LossTimerExpired func(TimerType, EncryptionLevel)
|
LossTimerExpired func(TimerType, EncryptionLevel)
|
||||||
LossTimerCanceled func()
|
LossTimerCanceled func()
|
||||||
|
|
|
@ -37,7 +37,7 @@ type (
|
||||||
// The StreamType is the type of the stream (unidirectional or bidirectional).
|
// The StreamType is the type of the stream (unidirectional or bidirectional).
|
||||||
StreamType = protocol.StreamType
|
StreamType = protocol.StreamType
|
||||||
// The VersionNumber is the QUIC version.
|
// The VersionNumber is the QUIC version.
|
||||||
VersionNumber = protocol.VersionNumber
|
VersionNumber = protocol.Version
|
||||||
|
|
||||||
// The Header is the QUIC packet header, before removing header protection.
|
// The Header is the QUIC packet header, before removing header protection.
|
||||||
Header = wire.Header
|
Header = wire.Header
|
||||||
|
|
|
@ -7,6 +7,8 @@ type Tracer struct {
|
||||||
SentPacket func(net.Addr, *Header, ByteCount, []Frame)
|
SentPacket func(net.Addr, *Header, ByteCount, []Frame)
|
||||||
SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
|
SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
|
||||||
DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason)
|
DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason)
|
||||||
|
Debug func(name, msg string)
|
||||||
|
Close func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers.
|
// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers.
|
||||||
|
@ -39,5 +41,19 @@ func NewMultiplexedTracer(tracers ...*Tracer) *Tracer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
Debug: func(name, msg string) {
|
||||||
|
for _, t := range tracers {
|
||||||
|
if t.Debug != nil {
|
||||||
|
t.Debug(name, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Close: func() {
|
||||||
|
for _, t := range tracers {
|
||||||
|
if t.Close != nil {
|
||||||
|
t.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,12 +3,12 @@
|
||||||
# Install Go manually, since oss-fuzz ships with an outdated Go version.
|
# Install Go manually, since oss-fuzz ships with an outdated Go version.
|
||||||
# See https://github.com/google/oss-fuzz/pull/10643.
|
# See https://github.com/google/oss-fuzz/pull/10643.
|
||||||
export CXX="${CXX} -lresolv" # required by Go 1.20
|
export CXX="${CXX} -lresolv" # required by Go 1.20
|
||||||
wget https://go.dev/dl/go1.21.5.linux-amd64.tar.gz \
|
wget https://go.dev/dl/go1.22.0.linux-amd64.tar.gz \
|
||||||
&& mkdir temp-go \
|
&& mkdir temp-go \
|
||||||
&& rm -rf /root/.go/* \
|
&& rm -rf /root/.go/* \
|
||||||
&& tar -C temp-go/ -xzf go1.21.5.linux-amd64.tar.gz \
|
&& tar -C temp-go/ -xzf go1.22.0.linux-amd64.tar.gz \
|
||||||
&& mv temp-go/go/* /root/.go/ \
|
&& mv temp-go/go/* /root/.go/ \
|
||||||
&& rm -rf temp-go go1.21.5.linux-amd64.tar.gz
|
&& rm -rf temp-go go1.22.0.linux-amd64.tar.gz
|
||||||
|
|
||||||
(
|
(
|
||||||
# fuzz qpack
|
# fuzz qpack
|
||||||
|
|
|
@ -129,7 +129,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
defer h.mutex.Unlock()
|
defer h.mutex.Unlock()
|
||||||
|
|
||||||
|
@ -137,12 +137,8 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
|
||||||
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
|
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
conn, ok := fn()
|
h.handlers[clientDestConnID] = handler
|
||||||
if !ok {
|
h.handlers[newConnID] = handler
|
||||||
return false
|
|
||||||
}
|
|
||||||
h.handlers[clientDestConnID] = conn
|
|
||||||
h.handlers[newConnID] = conn
|
|
||||||
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
|
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -168,18 +164,17 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
|
||||||
// Depending on which side closed the connection, we need to:
|
// Depending on which side closed the connection, we need to:
|
||||||
// * remote close: absorb delayed packets
|
// * remote close: absorb delayed packets
|
||||||
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
|
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
|
||||||
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers protocol.Perspective, connClosePacket []byte) {
|
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte) {
|
||||||
var handler packetHandler
|
var handler packetHandler
|
||||||
if connClosePacket != nil {
|
if connClosePacket != nil {
|
||||||
handler = newClosedLocalConn(
|
handler = newClosedLocalConn(
|
||||||
func(addr net.Addr, info packetInfo) {
|
func(addr net.Addr, info packetInfo) {
|
||||||
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
|
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
|
||||||
},
|
},
|
||||||
pers,
|
|
||||||
h.logger,
|
h.logger,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
handler = newClosedRemoteConn(pers)
|
handler = newClosedRemoteConn()
|
||||||
}
|
}
|
||||||
|
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
|
@ -191,7 +186,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
|
||||||
|
|
||||||
time.AfterFunc(h.deleteRetiredConnsAfter, func() {
|
time.AfterFunc(h.deleteRetiredConnsAfter, func() {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
handler.shutdown()
|
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
delete(h.handlers, id)
|
delete(h.handlers, id)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,13 +18,13 @@ import (
|
||||||
var errNothingToPack = errors.New("nothing to pack")
|
var errNothingToPack = errors.New("nothing to pack")
|
||||||
|
|
||||||
type packer interface {
|
type packer interface {
|
||||||
PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error)
|
PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error)
|
||||||
PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
|
PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
|
||||||
AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error)
|
AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error)
|
||||||
MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
|
MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
|
||||||
PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
|
PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
|
||||||
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
|
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
|
||||||
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
|
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
|
||||||
|
|
||||||
SetToken([]byte)
|
SetToken([]byte)
|
||||||
}
|
}
|
||||||
|
@ -106,8 +106,8 @@ type sealingManager interface {
|
||||||
|
|
||||||
type frameSource interface {
|
type frameSource interface {
|
||||||
HasData() bool
|
HasData() bool
|
||||||
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
|
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)
|
||||||
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
|
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ackFrameSource interface {
|
type ackFrameSource interface {
|
||||||
|
@ -170,7 +170,7 @@ func newPacketPacker(
|
||||||
}
|
}
|
||||||
|
|
||||||
// PackConnectionClose packs a packet that closes the connection with a transport error.
|
// PackConnectionClose packs a packet that closes the connection with a transport error.
|
||||||
func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
|
func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
|
||||||
var reason string
|
var reason string
|
||||||
// don't send details of crypto errors
|
// don't send details of crypto errors
|
||||||
if !e.ErrorCode.IsCryptoError() {
|
if !e.ErrorCode.IsCryptoError() {
|
||||||
|
@ -180,7 +180,7 @@ func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize
|
||||||
}
|
}
|
||||||
|
|
||||||
// PackApplicationClose packs a packet that closes the connection with an application error.
|
// PackApplicationClose packs a packet that closes the connection with an application error.
|
||||||
func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
|
func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
|
||||||
return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v)
|
return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -190,7 +190,7 @@ func (p *packetPacker) packConnectionClose(
|
||||||
frameType uint64,
|
frameType uint64,
|
||||||
reason string,
|
reason string,
|
||||||
maxPacketSize protocol.ByteCount,
|
maxPacketSize protocol.ByteCount,
|
||||||
v protocol.VersionNumber,
|
v protocol.Version,
|
||||||
) (*coalescedPacket, error) {
|
) (*coalescedPacket, error) {
|
||||||
var sealers [4]sealer
|
var sealers [4]sealer
|
||||||
var hdrs [3]*wire.ExtendedHeader
|
var hdrs [3]*wire.ExtendedHeader
|
||||||
|
@ -293,7 +293,7 @@ func (p *packetPacker) packConnectionClose(
|
||||||
// longHeaderPacketLength calculates the length of a serialized long header packet.
|
// longHeaderPacketLength calculates the length of a serialized long header packet.
|
||||||
// It takes into account that packets that have a tiny payload need to be padded,
|
// It takes into account that packets that have a tiny payload need to be padded,
|
||||||
// such that len(payload) + packet number len >= 4 + AEAD overhead
|
// such that len(payload) + packet number len >= 4 + AEAD overhead
|
||||||
func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.VersionNumber) protocol.ByteCount {
|
func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.Version) protocol.ByteCount {
|
||||||
var paddingLen protocol.ByteCount
|
var paddingLen protocol.ByteCount
|
||||||
pnLen := protocol.ByteCount(hdr.PacketNumberLen)
|
pnLen := protocol.ByteCount(hdr.PacketNumberLen)
|
||||||
if pl.length < 4-pnLen {
|
if pl.length < 4-pnLen {
|
||||||
|
@ -328,7 +328,7 @@ func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize,
|
||||||
// PackCoalescedPacket packs a new packet.
|
// PackCoalescedPacket packs a new packet.
|
||||||
// It packs an Initial / Handshake if there is data to send in these packet number spaces.
|
// It packs an Initial / Handshake if there is data to send in these packet number spaces.
|
||||||
// It should only be called before the handshake is confirmed.
|
// It should only be called before the handshake is confirmed.
|
||||||
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
|
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
|
||||||
var (
|
var (
|
||||||
initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader
|
initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader
|
||||||
initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
|
initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
|
||||||
|
@ -442,7 +442,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
|
||||||
|
|
||||||
// PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space.
|
// PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space.
|
||||||
// It should be called after the handshake is confirmed.
|
// It should be called after the handshake is confirmed.
|
||||||
func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) {
|
func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||||
buf := getPacketBuffer()
|
buf := getPacketBuffer()
|
||||||
packet, err := p.appendPacket(buf, true, maxPacketSize, v)
|
packet, err := p.appendPacket(buf, true, maxPacketSize, v)
|
||||||
return packet, buf, err
|
return packet, buf, err
|
||||||
|
@ -450,11 +450,11 @@ func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v pro
|
||||||
|
|
||||||
// AppendPacket packs a packet in the application data packet number space.
|
// AppendPacket packs a packet in the application data packet number space.
|
||||||
// It should be called after the handshake is confirmed.
|
// It should be called after the handshake is confirmed.
|
||||||
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) {
|
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) {
|
||||||
return p.appendPacket(buf, false, maxPacketSize, v)
|
return p.appendPacket(buf, false, maxPacketSize, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) {
|
func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) {
|
||||||
sealer, err := p.cryptoSetup.Get1RTTSealer()
|
sealer, err := p.cryptoSetup.Get1RTTSealer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return shortHeaderPacket{}, err
|
return shortHeaderPacket{}, err
|
||||||
|
@ -471,7 +471,7 @@ func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSi
|
||||||
return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v)
|
return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) {
|
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.Version) (*wire.ExtendedHeader, payload) {
|
||||||
if onlyAck {
|
if onlyAck {
|
||||||
if ack := p.acks.GetAckFrame(encLevel, true); ack != nil {
|
if ack := p.acks.GetAckFrame(encLevel, true); ack != nil {
|
||||||
return p.getLongHeader(encLevel, v), payload{
|
return p.getLongHeader(encLevel, v), payload{
|
||||||
|
@ -543,7 +543,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
|
||||||
return hdr, pl
|
return hdr, pl
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) {
|
func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.Version) (*wire.ExtendedHeader, payload) {
|
||||||
if p.perspective != protocol.PerspectiveClient {
|
if p.perspective != protocol.PerspectiveClient {
|
||||||
return nil, payload{}
|
return nil, payload{}
|
||||||
}
|
}
|
||||||
|
@ -553,12 +553,12 @@ func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize
|
||||||
return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, v)
|
return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
|
func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
|
||||||
maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead())
|
maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead())
|
||||||
return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, v)
|
return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
|
func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
|
||||||
pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v)
|
pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v)
|
||||||
|
|
||||||
// check if we have anything to send
|
// check if we have anything to send
|
||||||
|
@ -581,7 +581,7 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount,
|
||||||
return pl
|
return pl
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
|
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
|
||||||
if onlyAck {
|
if onlyAck {
|
||||||
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil {
|
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil {
|
||||||
return payload{ack: ack, length: ack.Length(v)}
|
return payload{ack: ack, length: ack.Length(v)}
|
||||||
|
@ -589,12 +589,11 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
|
||||||
return payload{}
|
return payload{}
|
||||||
}
|
}
|
||||||
|
|
||||||
pl := payload{streamFrames: make([]ackhandler.StreamFrame, 0, 1)}
|
|
||||||
|
|
||||||
hasData := p.framer.HasData()
|
hasData := p.framer.HasData()
|
||||||
hasRetransmission := p.retransmissionQueue.HasAppData()
|
hasRetransmission := p.retransmissionQueue.HasAppData()
|
||||||
|
|
||||||
var hasAck bool
|
var hasAck bool
|
||||||
|
var pl payload
|
||||||
if ackAllowed {
|
if ackAllowed {
|
||||||
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil {
|
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil {
|
||||||
pl.ack = ack
|
pl.ack = ack
|
||||||
|
@ -661,7 +660,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
|
||||||
return pl
|
return pl
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
|
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
|
||||||
if encLevel == protocol.Encryption1RTT {
|
if encLevel == protocol.Encryption1RTT {
|
||||||
s, err := p.cryptoSetup.Get1RTTSealer()
|
s, err := p.cryptoSetup.Get1RTTSealer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -727,7 +726,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m
|
||||||
return packet, nil
|
return packet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) {
|
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||||
pl := payload{
|
pl := payload{
|
||||||
frames: []ackhandler.Frame{ping},
|
frames: []ackhandler.Frame{ping},
|
||||||
length: ping.Frame.Length(v),
|
length: ping.Frame.Length(v),
|
||||||
|
@ -745,7 +744,7 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
|
||||||
return packet, buffer, err
|
return packet, buffer, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader {
|
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.Version) *wire.ExtendedHeader {
|
||||||
pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
|
pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
|
||||||
hdr := &wire.ExtendedHeader{
|
hdr := &wire.ExtendedHeader{
|
||||||
PacketNumber: pn,
|
PacketNumber: pn,
|
||||||
|
@ -768,7 +767,7 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protoc
|
||||||
return hdr
|
return hdr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) {
|
func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.Version) (*longHeaderPacket, error) {
|
||||||
var paddingLen protocol.ByteCount
|
var paddingLen protocol.ByteCount
|
||||||
pnLen := protocol.ByteCount(header.PacketNumberLen)
|
pnLen := protocol.ByteCount(header.PacketNumberLen)
|
||||||
if pl.length < 4-pnLen {
|
if pl.length < 4-pnLen {
|
||||||
|
@ -814,7 +813,7 @@ func (p *packetPacker) appendShortHeaderPacket(
|
||||||
padding, maxPacketSize protocol.ByteCount,
|
padding, maxPacketSize protocol.ByteCount,
|
||||||
sealer sealer,
|
sealer sealer,
|
||||||
isMTUProbePacket bool,
|
isMTUProbePacket bool,
|
||||||
v protocol.VersionNumber,
|
v protocol.Version,
|
||||||
) (shortHeaderPacket, error) {
|
) (shortHeaderPacket, error) {
|
||||||
var paddingLen protocol.ByteCount
|
var paddingLen protocol.ByteCount
|
||||||
if pl.length < 4-protocol.ByteCount(pnLen) {
|
if pl.length < 4-protocol.ByteCount(pnLen) {
|
||||||
|
@ -860,7 +859,7 @@ func (p *packetPacker) appendShortHeaderPacket(
|
||||||
|
|
||||||
// appendPacketPayload serializes the payload of a packet into the raw byte slice.
|
// appendPacketPayload serializes the payload of a packet into the raw byte slice.
|
||||||
// It modifies the order of payload.frames.
|
// It modifies the order of payload.frames.
|
||||||
func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) {
|
func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.Version) ([]byte, error) {
|
||||||
payloadOffset := len(raw)
|
payloadOffset := len(raw)
|
||||||
if pl.ack != nil {
|
if pl.ack != nil {
|
||||||
var err error
|
var err error
|
||||||
|
|
|
@ -53,7 +53,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetU
|
||||||
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
|
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
|
||||||
// If any other error occurred when parsing the header, the error is of type headerParseError.
|
// If any other error occurred when parsing the header, the error is of type headerParseError.
|
||||||
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
|
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
|
||||||
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) {
|
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) {
|
||||||
var encLevel protocol.EncryptionLevel
|
var encLevel protocol.EncryptionLevel
|
||||||
var extHdr *wire.ExtendedHeader
|
var extHdr *wire.ExtendedHeader
|
||||||
var decrypted []byte
|
var decrypted []byte
|
||||||
|
@ -125,7 +125,7 @@ func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (prot
|
||||||
return pn, pnLen, kp, decrypted, nil
|
return pn, pnLen, kp, decrypted, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, []byte, error) {
|
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, []byte, error) {
|
||||||
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
|
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
|
||||||
// If the reserved bits are set incorrectly, we still need to continue unpacking.
|
// If the reserved bits are set incorrectly, we still need to continue unpacking.
|
||||||
// This avoids a timing side-channel, which otherwise might allow an attacker
|
// This avoids a timing side-channel, which otherwise might allow an attacker
|
||||||
|
@ -187,7 +187,7 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int
|
||||||
}
|
}
|
||||||
|
|
||||||
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
|
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
|
||||||
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) {
|
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
|
||||||
extHdr, err := unpackLongHeader(hd, hdr, data, v)
|
extHdr, err := unpackLongHeader(hd, hdr, data, v)
|
||||||
if err != nil && err != wire.ErrInvalidReservedBits {
|
if err != nil && err != wire.ErrInvalidReservedBits {
|
||||||
return nil, &headerParseError{err: err}
|
return nil, &headerParseError{err: err}
|
||||||
|
@ -195,7 +195,7 @@ func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header,
|
||||||
return extHdr, err
|
return extHdr, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) {
|
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
|
||||||
r := bytes.NewReader(data)
|
r := bytes.NewReader(data)
|
||||||
|
|
||||||
hdrLen := hdr.ParsedLen()
|
hdrLen := hdr.ParsedLen()
|
||||||
|
|
|
@ -292,10 +292,6 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame)
|
||||||
return newlyRcvdFinalOffset, nil
|
return newlyRcvdFinalOffset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *receiveStream) CloseRemote(offset protocol.ByteCount) {
|
|
||||||
s.handleStreamFrame(&wire.StreamFrame{Fin: true, Offset: offset})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *receiveStream) SetReadDeadline(t time.Time) error {
|
func (s *receiveStream) SetReadDeadline(t time.Time) error {
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
s.deadline = t
|
s.deadline = t
|
||||||
|
|
|
@ -74,7 +74,7 @@ func (q *retransmissionQueue) addAppData(f wire.Frame) {
|
||||||
q.appData = append(q.appData, f)
|
q.appData = append(q.appData, f)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame {
|
func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
|
||||||
if len(q.initialCryptoData) > 0 {
|
if len(q.initialCryptoData) > 0 {
|
||||||
f := q.initialCryptoData[0]
|
f := q.initialCryptoData[0]
|
||||||
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
|
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
|
||||||
|
@ -97,7 +97,7 @@ func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v proto
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame {
|
func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
|
||||||
if len(q.handshakeCryptoData) > 0 {
|
if len(q.handshakeCryptoData) > 0 {
|
||||||
f := q.handshakeCryptoData[0]
|
f := q.handshakeCryptoData[0]
|
||||||
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
|
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
|
||||||
|
@ -120,7 +120,7 @@ func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v pro
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame {
|
func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
|
||||||
if len(q.appData) == 0 {
|
if len(q.appData) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ type sendStreamI interface {
|
||||||
SendStream
|
SendStream
|
||||||
handleStopSendingFrame(*wire.StopSendingFrame)
|
handleStopSendingFrame(*wire.StopSendingFrame)
|
||||||
hasData() bool
|
hasData() bool
|
||||||
popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (frame ackhandler.StreamFrame, ok, hasMore bool)
|
popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (frame ackhandler.StreamFrame, ok, hasMore bool)
|
||||||
closeForShutdown(error)
|
closeForShutdown(error)
|
||||||
updateSendWindow(protocol.ByteCount)
|
updateSendWindow(protocol.ByteCount)
|
||||||
}
|
}
|
||||||
|
@ -198,7 +198,7 @@ func (s *sendStream) canBufferStreamFrame() bool {
|
||||||
|
|
||||||
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
|
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
|
||||||
// maxBytes is the maximum length this frame (including frame header) will have.
|
// maxBytes is the maximum length this frame (including frame header) will have.
|
||||||
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (af ackhandler.StreamFrame, ok, hasMore bool) {
|
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) {
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
|
f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
|
||||||
if f != nil {
|
if f != nil {
|
||||||
|
@ -215,7 +215,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers
|
||||||
}, true, hasMoreData
|
}, true, hasMoreData
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) {
|
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more data to send */) {
|
||||||
if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
|
if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
@ -269,7 +269,7 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun
|
||||||
return f, hasMoreData
|
return f, hasMoreData
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool) {
|
func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) {
|
||||||
if s.nextFrame != nil {
|
if s.nextFrame != nil {
|
||||||
nextFrame := s.nextFrame
|
nextFrame := s.nextFrame
|
||||||
s.nextFrame = nil
|
s.nextFrame = nil
|
||||||
|
@ -304,7 +304,7 @@ func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount,
|
||||||
return f, hasMoreData
|
return f, hasMoreData
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) bool {
|
func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool {
|
||||||
maxDataLen := f.MaxDataLen(maxBytes, v)
|
maxDataLen := f.MaxDataLen(maxBytes, v)
|
||||||
if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
|
if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
|
||||||
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
|
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
|
||||||
|
@ -314,7 +314,7 @@ func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxByte
|
||||||
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
|
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more retransmissions */) {
|
func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) {
|
||||||
f := s.retransmissionQueue[0]
|
f := s.retransmissionQueue[0]
|
||||||
newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v)
|
newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v)
|
||||||
if needsSplit {
|
if needsSplit {
|
||||||
|
@ -404,11 +404,13 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
|
func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
|
||||||
|
updated := s.flowController.UpdateSendWindow(limit)
|
||||||
|
if !updated { // duplicate or reordered MAX_STREAM_DATA frame
|
||||||
|
return
|
||||||
|
}
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
|
hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
|
||||||
s.mutex.Unlock()
|
s.mutex.Unlock()
|
||||||
|
|
||||||
s.flowController.UpdateSendWindow(limit)
|
|
||||||
if hasStreamData {
|
if hasStreamData {
|
||||||
s.sender.onHasStreamData(s.streamID)
|
s.sender.onHasStreamData(s.streamID)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/handshake"
|
"github.com/quic-go/quic-go/internal/handshake"
|
||||||
|
@ -24,15 +23,14 @@ var ErrServerClosed = errors.New("quic: server closed")
|
||||||
// packetHandler handles packets
|
// packetHandler handles packets
|
||||||
type packetHandler interface {
|
type packetHandler interface {
|
||||||
handlePacket(receivedPacket)
|
handlePacket(receivedPacket)
|
||||||
shutdown()
|
|
||||||
destroy(error)
|
destroy(error)
|
||||||
getPerspective() protocol.Perspective
|
closeWithTransportError(qerr.TransportErrorCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
type packetHandlerManager interface {
|
type packetHandlerManager interface {
|
||||||
Get(protocol.ConnectionID) (packetHandler, bool)
|
Get(protocol.ConnectionID) (packetHandler, bool)
|
||||||
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
|
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
|
||||||
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
|
AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool
|
||||||
Close(error)
|
Close(error)
|
||||||
connRunner
|
connRunner
|
||||||
}
|
}
|
||||||
|
@ -41,11 +39,9 @@ type quicConn interface {
|
||||||
EarlyConnection
|
EarlyConnection
|
||||||
earlyConnReady() <-chan struct{}
|
earlyConnReady() <-chan struct{}
|
||||||
handlePacket(receivedPacket)
|
handlePacket(receivedPacket)
|
||||||
GetVersion() protocol.VersionNumber
|
|
||||||
getPerspective() protocol.Perspective
|
|
||||||
run() error
|
run() error
|
||||||
destroy(error)
|
destroy(error)
|
||||||
shutdown()
|
closeWithTransportError(TransportErrorCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
type zeroRTTQueue struct {
|
type zeroRTTQueue struct {
|
||||||
|
@ -98,10 +94,10 @@ type baseServer struct {
|
||||||
*logging.ConnectionTracer,
|
*logging.ConnectionTracer,
|
||||||
uint64,
|
uint64,
|
||||||
utils.Logger,
|
utils.Logger,
|
||||||
protocol.VersionNumber,
|
protocol.Version,
|
||||||
) quicConn
|
) quicConn
|
||||||
|
|
||||||
closeOnce sync.Once
|
closeMx sync.Mutex
|
||||||
errorChan chan struct{} // is closed when the server is closed
|
errorChan chan struct{} // is closed when the server is closed
|
||||||
closeErr error
|
closeErr error
|
||||||
running chan struct{} // closed as soon as run() returns
|
running chan struct{} // closed as soon as run() returns
|
||||||
|
@ -111,8 +107,9 @@ type baseServer struct {
|
||||||
connectionRefusedQueue chan rejectedPacket
|
connectionRefusedQueue chan rejectedPacket
|
||||||
retryQueue chan rejectedPacket
|
retryQueue chan rejectedPacket
|
||||||
|
|
||||||
connQueue chan quicConn
|
verifySourceAddress func(net.Addr) bool
|
||||||
connQueueLen int32 // to be used as an atomic
|
|
||||||
|
connQueue chan quicConn
|
||||||
|
|
||||||
tracer *logging.Tracer
|
tracer *logging.Tracer
|
||||||
|
|
||||||
|
@ -240,6 +237,7 @@ func newServer(
|
||||||
onClose func(),
|
onClose func(),
|
||||||
tokenGeneratorKey TokenGeneratorKey,
|
tokenGeneratorKey TokenGeneratorKey,
|
||||||
maxTokenAge time.Duration,
|
maxTokenAge time.Duration,
|
||||||
|
verifySourceAddress func(net.Addr) bool,
|
||||||
disableVersionNegotiation bool,
|
disableVersionNegotiation bool,
|
||||||
acceptEarly bool,
|
acceptEarly bool,
|
||||||
) *baseServer {
|
) *baseServer {
|
||||||
|
@ -249,9 +247,10 @@ func newServer(
|
||||||
config: config,
|
config: config,
|
||||||
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
|
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
|
||||||
maxTokenAge: maxTokenAge,
|
maxTokenAge: maxTokenAge,
|
||||||
|
verifySourceAddress: verifySourceAddress,
|
||||||
connIDGenerator: connIDGenerator,
|
connIDGenerator: connIDGenerator,
|
||||||
connHandler: connHandler,
|
connHandler: connHandler,
|
||||||
connQueue: make(chan quicConn),
|
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
|
||||||
errorChan: make(chan struct{}),
|
errorChan: make(chan struct{}),
|
||||||
running: make(chan struct{}),
|
running: make(chan struct{}),
|
||||||
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
|
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
|
||||||
|
@ -322,7 +321,6 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
case conn := <-s.connQueue:
|
case conn := <-s.connQueue:
|
||||||
atomic.AddInt32(&s.connQueueLen, -1)
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
case <-s.errorChan:
|
case <-s.errorChan:
|
||||||
return nil, s.closeErr
|
return nil, s.closeErr
|
||||||
|
@ -335,15 +333,19 @@ func (s *baseServer) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *baseServer) close(e error, notifyOnClose bool) {
|
func (s *baseServer) close(e error, notifyOnClose bool) {
|
||||||
s.closeOnce.Do(func() {
|
s.closeMx.Lock()
|
||||||
s.closeErr = e
|
if s.closeErr != nil {
|
||||||
close(s.errorChan)
|
s.closeMx.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.closeErr = e
|
||||||
|
close(s.errorChan)
|
||||||
|
<-s.running
|
||||||
|
s.closeMx.Unlock()
|
||||||
|
|
||||||
<-s.running
|
if notifyOnClose {
|
||||||
if notifyOnClose {
|
s.onClose()
|
||||||
s.onClose()
|
}
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Addr returns the server's network address
|
// Addr returns the server's network address
|
||||||
|
@ -542,10 +544,10 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
|
||||||
|
|
||||||
func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
|
func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
|
||||||
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||||
p.buffer.Release()
|
|
||||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||||
}
|
}
|
||||||
|
p.buffer.Release()
|
||||||
return errors.New("too short connection ID")
|
return errors.New("too short connection ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -558,8 +560,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
token *handshake.Token
|
token *handshake.Token
|
||||||
retrySrcConnID *protocol.ConnectionID
|
retrySrcConnID *protocol.ConnectionID
|
||||||
|
clientAddrVerified bool
|
||||||
)
|
)
|
||||||
origDestConnID := hdr.DestConnectionID
|
origDestConnID := hdr.DestConnectionID
|
||||||
if len(hdr.Token) > 0 {
|
if len(hdr.Token) > 0 {
|
||||||
|
@ -572,28 +575,30 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
token = tok
|
token = tok
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if token != nil {
|
||||||
clientAddrIsValid := s.validateToken(token, p.remoteAddr)
|
clientAddrVerified = s.validateToken(token, p.remoteAddr)
|
||||||
if token != nil && !clientAddrIsValid {
|
if !clientAddrVerified {
|
||||||
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
||||||
// We just ignore them, and act as if there was no token on this packet at all.
|
// We just ignore them, and act as if there was no token on this packet at all.
|
||||||
// This also means we might send a Retry later.
|
// This also means we might send a Retry later.
|
||||||
if !token.IsRetryToken {
|
if !token.IsRetryToken {
|
||||||
token = nil
|
token = nil
|
||||||
} else {
|
} else {
|
||||||
// For Retry tokens, we send an INVALID_ERROR if
|
// For Retry tokens, we send an INVALID_ERROR if
|
||||||
// * the token is too old, or
|
// * the token is too old, or
|
||||||
// * the token is invalid, in case of a retry token.
|
// * the token is invalid, in case of a retry token.
|
||||||
select {
|
select {
|
||||||
case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||||
default:
|
default:
|
||||||
// drop packet if we can't send out the INVALID_TOKEN packets fast enough
|
// drop packet if we can't send out the INVALID_TOKEN packets fast enough
|
||||||
p.buffer.Release()
|
p.buffer.Release()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
|
|
||||||
|
if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress(p.remoteAddr) {
|
||||||
// Retry invalidates all 0-RTT packets sent.
|
// Retry invalidates all 0-RTT packets sent.
|
||||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||||
select {
|
select {
|
||||||
|
@ -605,121 +610,116 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize {
|
config := s.config
|
||||||
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
|
if s.config.GetConfigForClient != nil {
|
||||||
select {
|
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{
|
||||||
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
RemoteAddr: p.remoteAddr,
|
||||||
default:
|
AddrVerified: clientAddrVerified,
|
||||||
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
|
})
|
||||||
p.buffer.Release()
|
if err != nil {
|
||||||
|
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
|
||||||
|
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||||
|
select {
|
||||||
|
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||||
|
default:
|
||||||
|
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
|
||||||
|
p.buffer.Release()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return nil
|
config = populateConfig(conf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var conn quicConn
|
||||||
|
tracingID := nextConnTracingID()
|
||||||
|
var tracer *logging.ConnectionTracer
|
||||||
|
if config.Tracer != nil {
|
||||||
|
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
||||||
|
connID := hdr.DestConnectionID
|
||||||
|
if origDestConnID.Len() > 0 {
|
||||||
|
connID = origDestConnID
|
||||||
|
}
|
||||||
|
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
|
||||||
|
}
|
||||||
connID, err := s.connIDGenerator.GenerateConnectionID()
|
connID, err := s.connIDGenerator.GenerateConnectionID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.logger.Debugf("Changing connection ID to %s.", connID)
|
s.logger.Debugf("Changing connection ID to %s.", connID)
|
||||||
var conn quicConn
|
conn = s.newConn(
|
||||||
tracingID := nextConnTracingID()
|
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
|
||||||
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) {
|
s.connHandler,
|
||||||
config := s.config
|
origDestConnID,
|
||||||
if s.config.GetConfigForClient != nil {
|
retrySrcConnID,
|
||||||
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr})
|
hdr.DestConnectionID,
|
||||||
if err != nil {
|
hdr.SrcConnectionID,
|
||||||
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
|
connID,
|
||||||
return nil, false
|
s.connIDGenerator,
|
||||||
}
|
s.connHandler.GetStatelessResetToken(connID),
|
||||||
config = populateConfig(conf)
|
config,
|
||||||
}
|
s.tlsConf,
|
||||||
var tracer *logging.ConnectionTracer
|
s.tokenGenerator,
|
||||||
if config.Tracer != nil {
|
clientAddrVerified,
|
||||||
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
tracer,
|
||||||
connID := hdr.DestConnectionID
|
tracingID,
|
||||||
if origDestConnID.Len() > 0 {
|
s.logger,
|
||||||
connID = origDestConnID
|
hdr.Version,
|
||||||
}
|
)
|
||||||
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
|
conn.handlePacket(p)
|
||||||
}
|
// Adding the connection will fail if the client's chosen Destination Connection ID is already in use.
|
||||||
conn = s.newConn(
|
// This is very unlikely: Even if an attacker chooses a connection ID that's already in use,
|
||||||
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
|
// under normal circumstances the packet would just be routed to that connection.
|
||||||
s.connHandler,
|
// The only time this collision will occur if we receive the two Initial packets at the same time.
|
||||||
origDestConnID,
|
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
|
||||||
retrySrcConnID,
|
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||||
hdr.DestConnectionID,
|
conn.closeWithTransportError(qerr.ConnectionRefused)
|
||||||
hdr.SrcConnectionID,
|
|
||||||
connID,
|
|
||||||
s.connIDGenerator,
|
|
||||||
s.connHandler.GetStatelessResetToken(connID),
|
|
||||||
config,
|
|
||||||
s.tlsConf,
|
|
||||||
s.tokenGenerator,
|
|
||||||
clientAddrIsValid,
|
|
||||||
tracer,
|
|
||||||
tracingID,
|
|
||||||
s.logger,
|
|
||||||
hdr.Version,
|
|
||||||
)
|
|
||||||
conn.handlePacket(p)
|
|
||||||
|
|
||||||
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
|
|
||||||
for _, p := range q.packets {
|
|
||||||
conn.handlePacket(p)
|
|
||||||
}
|
|
||||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn, true
|
|
||||||
}); !added {
|
|
||||||
select {
|
|
||||||
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
|
||||||
default:
|
|
||||||
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
|
|
||||||
p.buffer.Release()
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
// Pass queued 0-RTT to the newly established connection.
|
||||||
|
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
|
||||||
|
for _, p := range q.packets {
|
||||||
|
conn.handlePacket(p)
|
||||||
|
}
|
||||||
|
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||||
|
}
|
||||||
|
|
||||||
go conn.run()
|
go conn.run()
|
||||||
go s.handleNewConn(conn)
|
go func() {
|
||||||
if conn == nil {
|
if completed := s.handleNewConn(conn); !completed {
|
||||||
p.buffer.Release()
|
return
|
||||||
return nil
|
}
|
||||||
}
|
|
||||||
|
select {
|
||||||
|
case s.connQueue <- conn:
|
||||||
|
default:
|
||||||
|
conn.closeWithTransportError(ConnectionRefused)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *baseServer) handleNewConn(conn quicConn) {
|
func (s *baseServer) handleNewConn(conn quicConn) bool {
|
||||||
connCtx := conn.Context()
|
|
||||||
if s.acceptEarlyConns {
|
if s.acceptEarlyConns {
|
||||||
// wait until the early connection is ready, the handshake fails, or the server is closed
|
// wait until the early connection is ready, the handshake fails, or the server is closed
|
||||||
select {
|
select {
|
||||||
case <-s.errorChan:
|
case <-s.errorChan:
|
||||||
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
|
conn.closeWithTransportError(ConnectionRefused)
|
||||||
return
|
return false
|
||||||
|
case <-conn.Context().Done():
|
||||||
|
return false
|
||||||
case <-conn.earlyConnReady():
|
case <-conn.earlyConnReady():
|
||||||
case <-connCtx.Done():
|
return true
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// wait until the handshake is complete (or fails)
|
|
||||||
select {
|
|
||||||
case <-s.errorChan:
|
|
||||||
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
|
|
||||||
return
|
|
||||||
case <-conn.HandshakeComplete():
|
|
||||||
case <-connCtx.Done():
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// wait until the handshake completes, fails, or the server is closed
|
||||||
atomic.AddInt32(&s.connQueueLen, 1)
|
|
||||||
select {
|
select {
|
||||||
case s.connQueue <- conn:
|
case <-s.errorChan:
|
||||||
// blocks until the connection is accepted
|
conn.closeWithTransportError(ConnectionRefused)
|
||||||
case <-connCtx.Done():
|
return false
|
||||||
atomic.AddInt32(&s.connQueueLen, -1)
|
case <-conn.Context().Done():
|
||||||
// don't pass connections that were already closed to Accept()
|
return false
|
||||||
|
case <-conn.HandshakeComplete():
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -60,7 +60,7 @@ type streamI interface {
|
||||||
// for sending
|
// for sending
|
||||||
hasData() bool
|
hasData() bool
|
||||||
handleStopSendingFrame(*wire.StopSendingFrame)
|
handleStopSendingFrame(*wire.StopSendingFrame)
|
||||||
popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool)
|
popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool)
|
||||||
updateSendWindow(protocol.ByteCount)
|
updateSendWindow(protocol.ByteCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -221,7 +221,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
|
||||||
// unsigned int ipi6_ifindex; /* send/recv interface index */
|
// unsigned int ipi6_ifindex; /* send/recv interface index */
|
||||||
// };
|
// };
|
||||||
if len(body) == 20 {
|
if len(body) == 20 {
|
||||||
p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16]))
|
p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16])).Unmap()
|
||||||
p.info.ifIndex = binary.LittleEndian.Uint32(body[16:])
|
p.info.ifIndex = binary.LittleEndian.Uint32(body[16:])
|
||||||
} else {
|
} else {
|
||||||
invalidCmsgOnceV6.Do(func() {
|
invalidCmsgOnceV6.Do(func() {
|
||||||
|
|
|
@ -41,7 +41,8 @@ type Transport struct {
|
||||||
Conn net.PacketConn
|
Conn net.PacketConn
|
||||||
|
|
||||||
// The length of the connection ID in bytes.
|
// The length of the connection ID in bytes.
|
||||||
// It can be 0, or any value between 4 and 18.
|
// It can be any value between 1 and 20.
|
||||||
|
// Due to the increased risk of collisions, it is not recommended to use connection IDs shorter than 4 bytes.
|
||||||
// If unset, a 4 byte connection ID will be used.
|
// If unset, a 4 byte connection ID will be used.
|
||||||
ConnectionIDLength int
|
ConnectionIDLength int
|
||||||
|
|
||||||
|
@ -77,7 +78,19 @@ type Transport struct {
|
||||||
// It has no effect for clients.
|
// It has no effect for clients.
|
||||||
DisableVersionNegotiationPackets bool
|
DisableVersionNegotiationPackets bool
|
||||||
|
|
||||||
|
// VerifySourceAddress decides if a connection attempt originating from unvalidated source
|
||||||
|
// addresses first needs to go through source address validation using QUIC's Retry mechanism,
|
||||||
|
// as described in RFC 9000 section 8.1.2.
|
||||||
|
// Note that the address passed to this callback is unvalidated, and might be spoofed in case
|
||||||
|
// of an attack.
|
||||||
|
// Validating the source address adds one additional network roundtrip to the handshake,
|
||||||
|
// and should therefore only be used if a suspiciously high number of incoming connection is recorded.
|
||||||
|
// For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable
|
||||||
|
// implementation of this callback (negating its return value).
|
||||||
|
VerifySourceAddress func(net.Addr) bool
|
||||||
|
|
||||||
// A Tracer traces events that don't belong to a single QUIC connection.
|
// A Tracer traces events that don't belong to a single QUIC connection.
|
||||||
|
// Tracer.Close is called when the transport is closed.
|
||||||
Tracer *logging.Tracer
|
Tracer *logging.Tracer
|
||||||
|
|
||||||
handlerMap packetHandlerManager
|
handlerMap packetHandlerManager
|
||||||
|
@ -147,7 +160,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
|
||||||
if t.server != nil {
|
if t.server != nil {
|
||||||
return nil, errListenerAlreadySet
|
return nil, errListenerAlreadySet
|
||||||
}
|
}
|
||||||
conf = populateServerConfig(conf)
|
conf = populateConfig(conf)
|
||||||
if err := t.init(false); err != nil {
|
if err := t.init(false); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -161,6 +174,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
|
||||||
t.closeServer,
|
t.closeServer,
|
||||||
*t.TokenGeneratorKey,
|
*t.TokenGeneratorKey,
|
||||||
t.MaxTokenAge,
|
t.MaxTokenAge,
|
||||||
|
t.VerifySourceAddress,
|
||||||
t.DisableVersionNegotiationPackets,
|
t.DisableVersionNegotiationPackets,
|
||||||
allow0RTT,
|
allow0RTT,
|
||||||
)
|
)
|
||||||
|
@ -323,6 +337,9 @@ func (t *Transport) close(e error) {
|
||||||
if t.server != nil {
|
if t.server != nil {
|
||||||
t.server.close(e, false)
|
t.server.close(e, false)
|
||||||
}
|
}
|
||||||
|
if t.Tracer != nil && t.Tracer.Close != nil {
|
||||||
|
t.Tracer.Close()
|
||||||
|
}
|
||||||
t.closed = true
|
t.closed = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,13 +396,21 @@ func (t *Transport) handlePacket(p receivedPacket) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
|
// If there's a connection associated with the connection ID, pass the packet there.
|
||||||
return
|
|
||||||
}
|
|
||||||
if handler, ok := t.handlerMap.Get(connID); ok {
|
if handler, ok := t.handlerMap.Get(connID); ok {
|
||||||
handler.handlePacket(p)
|
handler.handlePacket(p)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// RFC 9000 section 10.3.1 requires that the stateless reset detection logic is run for both
|
||||||
|
// packets that cannot be associated with any connections, and for packets that can't be decrypted.
|
||||||
|
// We deviate from the RFC and ignore the latter: If a packet's connection ID is associated with an
|
||||||
|
// existing connection, it is dropped there if if it can't be decrypted.
|
||||||
|
// Stateless resets use random connection IDs, and at reasonable connection ID lengths collisions are
|
||||||
|
// exceedingly rare. In the unlikely event that a stateless reset is misrouted to an existing connection,
|
||||||
|
// it is to be expected that the next stateless reset will be correctly detected.
|
||||||
|
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
|
||||||
|
return
|
||||||
|
}
|
||||||
if !wire.IsLongHeaderPacket(p.data[0]) {
|
if !wire.IsLongHeaderPacket(p.data[0]) {
|
||||||
t.maybeSendStatelessReset(p)
|
t.maybeSendStatelessReset(p)
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,37 +0,0 @@
|
||||||
# This is the official list of people who can contribute (and typically
|
|
||||||
# have contributed) code to the gomock repository.
|
|
||||||
# The AUTHORS file lists the copyright holders; this file
|
|
||||||
# lists people. For example, Google employees are listed here
|
|
||||||
# but not in AUTHORS, because Google holds the copyright.
|
|
||||||
#
|
|
||||||
# The submission process automatically checks to make sure
|
|
||||||
# that people submitting code are listed in this file (by email address).
|
|
||||||
#
|
|
||||||
# Names should be added to this file only after verifying that
|
|
||||||
# the individual or the individual's organization has agreed to
|
|
||||||
# the appropriate Contributor License Agreement, found here:
|
|
||||||
#
|
|
||||||
# http://code.google.com/legal/individual-cla-v1.0.html
|
|
||||||
# http://code.google.com/legal/corporate-cla-v1.0.html
|
|
||||||
#
|
|
||||||
# The agreement for individuals can be filled out on the web.
|
|
||||||
#
|
|
||||||
# When adding J Random Contributor's name to this file,
|
|
||||||
# either J's name or J's organization's name should be
|
|
||||||
# added to the AUTHORS file, depending on whether the
|
|
||||||
# individual or corporate CLA was used.
|
|
||||||
|
|
||||||
# Names should be added to this file like so:
|
|
||||||
# Name <email address>
|
|
||||||
#
|
|
||||||
# An entry with two email addresses specifies that the
|
|
||||||
# first address should be used in the submit logs and
|
|
||||||
# that the second address should be recognized as the
|
|
||||||
# same person when interacting with Rietveld.
|
|
||||||
|
|
||||||
# Please keep the list sorted.
|
|
||||||
|
|
||||||
Aaron Jacobs <jacobsa@google.com> <aaronjjacobs@gmail.com>
|
|
||||||
Alex Reece <awreece@gmail.com>
|
|
||||||
David Symonds <dsymonds@golang.org>
|
|
||||||
Ryan Barrett <ryanb@google.com>
|
|
|
@ -11,8 +11,10 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"go/ast"
|
"go/ast"
|
||||||
|
"go/token"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"go.uber.org/mock/mockgen/model"
|
"go.uber.org/mock/mockgen/model"
|
||||||
|
@ -98,6 +100,16 @@ func (p *fileParser) parseGenericMethod(field *ast.Field, it *namedInterface, if
|
||||||
case *ast.IndexListExpr:
|
case *ast.IndexListExpr:
|
||||||
indices = v.Indices
|
indices = v.Indices
|
||||||
typ = v.X
|
typ = v.X
|
||||||
|
case *ast.UnaryExpr:
|
||||||
|
if v.Op == token.TILDE {
|
||||||
|
return nil, errConstraintInterface
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("~T may only appear as constraint for %T", field.Type)
|
||||||
|
case *ast.BinaryExpr:
|
||||||
|
if v.Op == token.OR {
|
||||||
|
return nil, errConstraintInterface
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("A|B may only appear as constraint for %T", field.Type)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
|
return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
|
||||||
}
|
}
|
||||||
|
@ -114,3 +126,5 @@ func (p *fileParser) parseGenericMethod(field *ast.Field, it *namedInterface, if
|
||||||
|
|
||||||
return p.parseMethod(nf, it, iface, pkg, tps)
|
return p.parseMethod(nf, it, iface, pkg, tps)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var errConstraintInterface = errors.New("interface contains constraints")
|
||||||
|
|
|
@ -31,6 +31,7 @@ import (
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -314,8 +315,14 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
|
||||||
}
|
}
|
||||||
g.p("//")
|
g.p("//")
|
||||||
g.p("// Generated by this command:")
|
g.p("// Generated by this command:")
|
||||||
|
g.p("//")
|
||||||
// only log the name of the executable, not the full path
|
// only log the name of the executable, not the full path
|
||||||
g.p("// %v", strings.Join(append([]string{filepath.Base(os.Args[0])}, os.Args[1:]...), " "))
|
name := filepath.Base(os.Args[0])
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = strings.TrimSuffix(name, ".exe")
|
||||||
|
}
|
||||||
|
g.p("//\t%v", strings.Join(append([]string{name}, os.Args[1:]...), " "))
|
||||||
|
g.p("//")
|
||||||
|
|
||||||
// Get all required imports, and generate unique names for them all.
|
// Get all required imports, and generate unique names for them all.
|
||||||
im := pkg.Imports()
|
im := pkg.Imports()
|
||||||
|
@ -371,7 +378,7 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
|
||||||
}
|
}
|
||||||
|
|
||||||
i := 0
|
i := 0
|
||||||
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() {
|
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() || pkgName == "any" {
|
||||||
pkgName = base + strconv.Itoa(i)
|
pkgName = base + strconv.Itoa(i)
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
@ -386,6 +393,10 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
|
||||||
}
|
}
|
||||||
|
|
||||||
if *writePkgComment {
|
if *writePkgComment {
|
||||||
|
// Ensure there's an empty line before the package to follow the recommendations:
|
||||||
|
// https://github.com/golang/go/wiki/CodeReviewComments#package-comments
|
||||||
|
g.p("")
|
||||||
|
|
||||||
g.p("// Package %v is a generated GoMock package.", outputPkgName)
|
g.p("// Package %v is a generated GoMock package.", outputPkgName)
|
||||||
}
|
}
|
||||||
g.p("package %v", outputPkgName)
|
g.p("package %v", outputPkgName)
|
||||||
|
@ -508,7 +519,7 @@ func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface,
|
||||||
g.p("")
|
g.p("")
|
||||||
_ = g.GenerateMockMethod(mockType, m, pkgOverride, shortTp)
|
_ = g.GenerateMockMethod(mockType, m, pkgOverride, shortTp)
|
||||||
g.p("")
|
g.p("")
|
||||||
_ = g.GenerateMockRecorderMethod(intf, mockType, m, shortTp, typed)
|
_ = g.GenerateMockRecorderMethod(intf, m, shortTp, typed)
|
||||||
if typed {
|
if typed {
|
||||||
g.p("")
|
g.p("")
|
||||||
_ = g.GenerateMockReturnCallMethod(intf, m, pkgOverride, longTp, shortTp)
|
_ = g.GenerateMockReturnCallMethod(intf, m, pkgOverride, longTp, shortTp)
|
||||||
|
@ -596,7 +607,8 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType string, m *model.Method, shortTp string, typed bool) error {
|
func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, m *model.Method, shortTp string, typed bool) error {
|
||||||
|
mockType := g.mockName(intf.Name)
|
||||||
argNames := g.getArgNames(m, true)
|
argNames := g.getArgNames(m, true)
|
||||||
|
|
||||||
var argString string
|
var argString string
|
||||||
|
@ -621,7 +633,7 @@ func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType s
|
||||||
|
|
||||||
g.p("// %v indicates an expected call of %v.", m.Name, m.Name)
|
g.p("// %v indicates an expected call of %v.", m.Name, m.Name)
|
||||||
if typed {
|
if typed {
|
||||||
g.p("func (%s *%vMockRecorder%v) %v(%v) *%s%sCall%s {", idRecv, mockType, shortTp, m.Name, argString, intf.Name, m.Name, shortTp)
|
g.p("func (%s *%vMockRecorder%v) %v(%v) *%s%sCall%s {", idRecv, mockType, shortTp, m.Name, argString, mockType, m.Name, shortTp)
|
||||||
} else {
|
} else {
|
||||||
g.p("func (%s *%vMockRecorder%v) %v(%v) *gomock.Call {", idRecv, mockType, shortTp, m.Name, argString)
|
g.p("func (%s *%vMockRecorder%v) %v(%v) *gomock.Call {", idRecv, mockType, shortTp, m.Name, argString)
|
||||||
}
|
}
|
||||||
|
@ -650,7 +662,7 @@ func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType s
|
||||||
}
|
}
|
||||||
if typed {
|
if typed {
|
||||||
g.p(`call := %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
|
g.p(`call := %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
|
||||||
g.p(`return &%s%sCall%s{Call: call}`, intf.Name, m.Name, shortTp)
|
g.p(`return &%s%sCall%s{Call: call}`, mockType, m.Name, shortTp)
|
||||||
} else {
|
} else {
|
||||||
g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
|
g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
|
||||||
}
|
}
|
||||||
|
@ -661,6 +673,7 @@ func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model.Method, pkgOverride, longTp, shortTp string) error {
|
func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model.Method, pkgOverride, longTp, shortTp string) error {
|
||||||
|
mockType := g.mockName(intf.Name)
|
||||||
argNames := g.getArgNames(m, true /* in */)
|
argNames := g.getArgNames(m, true /* in */)
|
||||||
retNames := g.getArgNames(m, false /* out */)
|
retNames := g.getArgNames(m, false /* out */)
|
||||||
argTypes := g.getArgTypes(m, pkgOverride, true /* in */)
|
argTypes := g.getArgTypes(m, pkgOverride, true /* in */)
|
||||||
|
@ -683,10 +696,10 @@ func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model
|
||||||
ia := newIdentifierAllocator(argNames)
|
ia := newIdentifierAllocator(argNames)
|
||||||
idRecv := ia.allocateIdentifier("c")
|
idRecv := ia.allocateIdentifier("c")
|
||||||
|
|
||||||
recvStructName := intf.Name + m.Name
|
recvStructName := mockType + m.Name
|
||||||
|
|
||||||
g.p("// %s%sCall wrap *gomock.Call", intf.Name, m.Name)
|
g.p("// %s%sCall wrap *gomock.Call", mockType, m.Name)
|
||||||
g.p("type %s%sCall%s struct{", intf.Name, m.Name, longTp)
|
g.p("type %s%sCall%s struct{", mockType, m.Name, longTp)
|
||||||
g.in()
|
g.in()
|
||||||
g.p("*gomock.Call")
|
g.p("*gomock.Call")
|
||||||
g.out()
|
g.out()
|
||||||
|
|
|
@ -305,7 +305,7 @@ type PredeclaredType string
|
||||||
func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) }
|
func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) }
|
||||||
func (pt PredeclaredType) addImports(map[string]bool) {}
|
func (pt PredeclaredType) addImports(map[string]bool) {}
|
||||||
|
|
||||||
// TypeParametersType contains type paramters for a NamedType.
|
// TypeParametersType contains type parameters for a NamedType.
|
||||||
type TypeParametersType struct {
|
type TypeParametersType struct {
|
||||||
TypeParameters []Type
|
TypeParameters []Type
|
||||||
}
|
}
|
||||||
|
|
|
@ -232,6 +232,9 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
i, err := p.parseInterface(ni.name.String(), importPath, ni)
|
i, err := p.parseInterface(ni.name.String(), importPath, ni)
|
||||||
|
if errors.Is(err, errConstraintInterface) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -187,6 +187,7 @@ type reflectData struct {
|
||||||
// gob encoding of a model.Package to standard output.
|
// gob encoding of a model.Package to standard output.
|
||||||
// JSON doesn't work because of the model.Type interface.
|
// JSON doesn't work because of the model.Type interface.
|
||||||
var reflectProgram = template.Must(template.New("program").Parse(`
|
var reflectProgram = template.Must(template.New("program").Parse(`
|
||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -230,7 +230,7 @@ github.com/prometheus/common/model
|
||||||
github.com/prometheus/procfs
|
github.com/prometheus/procfs
|
||||||
github.com/prometheus/procfs/internal/fs
|
github.com/prometheus/procfs/internal/fs
|
||||||
github.com/prometheus/procfs/internal/util
|
github.com/prometheus/procfs/internal/util
|
||||||
# github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6
|
# github.com/quic-go/quic-go v0.42.0
|
||||||
## explicit; go 1.21
|
## explicit; go 1.21
|
||||||
github.com/quic-go/quic-go
|
github.com/quic-go/quic-go
|
||||||
github.com/quic-go/quic-go/internal/ackhandler
|
github.com/quic-go/quic-go/internal/ackhandler
|
||||||
|
@ -313,7 +313,7 @@ go.opentelemetry.io/proto/otlp/trace/v1
|
||||||
go.uber.org/automaxprocs/internal/cgroups
|
go.uber.org/automaxprocs/internal/cgroups
|
||||||
go.uber.org/automaxprocs/internal/runtime
|
go.uber.org/automaxprocs/internal/runtime
|
||||||
go.uber.org/automaxprocs/maxprocs
|
go.uber.org/automaxprocs/maxprocs
|
||||||
# go.uber.org/mock v0.3.0
|
# go.uber.org/mock v0.4.0
|
||||||
## explicit; go 1.20
|
## explicit; go 1.20
|
||||||
go.uber.org/mock/mockgen
|
go.uber.org/mock/mockgen
|
||||||
go.uber.org/mock/mockgen/model
|
go.uber.org/mock/mockgen/model
|
||||||
|
|
Loading…
Reference in New Issue