TUN-8371: Bump quic-go to v0.42.0
## Summary We discovered that we were being impacted by a bug in quic-go, that could create deadlocks and not close connections. This commit bumps quic-go to the version that contains the fix to prevent that from happening.
This commit is contained in:
parent
5e5f2f4d8c
commit
84833011ec
4
go.mod
4
go.mod
|
@ -25,7 +25,7 @@ require (
|
|||
github.com/pkg/errors v0.9.1
|
||||
github.com/prometheus/client_golang v1.13.0
|
||||
github.com/prometheus/client_model v0.2.0
|
||||
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6
|
||||
github.com/quic-go/quic-go v0.42.0
|
||||
github.com/rs/zerolog v1.20.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/urfave/cli/v2 v2.3.0
|
||||
|
@ -84,7 +84,7 @@ require (
|
|||
github.com/prometheus/procfs v0.8.0 // indirect
|
||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.21.0 // indirect
|
||||
go.uber.org/mock v0.3.0 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
|
||||
golang.org/x/mod v0.11.0 // indirect
|
||||
golang.org/x/oauth2 v0.13.0 // indirect
|
||||
|
|
10
go.sum
10
go.sum
|
@ -322,8 +322,8 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
|
|||
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
|
||||
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
|
||||
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
|
||||
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6 h1:OI4WiysowCcxLtcZMGBZildo12di3ljcMN4vWdUQpoU=
|
||||
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA=
|
||||
github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
|
||||
github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
|
@ -381,8 +381,8 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
|
|||
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
|
||||
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
|
||||
go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo=
|
||||
go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
|
@ -552,6 +552,8 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
|||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
|
|
|
@ -2,16 +2,6 @@ run:
|
|||
skip-files:
|
||||
- internal/handshake/cipher_suite.go
|
||||
linters-settings:
|
||||
depguard:
|
||||
rules:
|
||||
qtls:
|
||||
list-mode: lax
|
||||
files:
|
||||
- "!internal/qtls/**"
|
||||
- "$all"
|
||||
deny:
|
||||
- pkg: github.com/quic-go/qtls-go1-20
|
||||
desc: "importing qtls only allowed in internal/qtls"
|
||||
misspell:
|
||||
ignore-words:
|
||||
- ect
|
||||
|
@ -20,7 +10,6 @@ linters:
|
|||
disable-all: true
|
||||
enable:
|
||||
- asciicheck
|
||||
- depguard
|
||||
- exhaustive
|
||||
- exportloopref
|
||||
- goimports
|
||||
|
|
|
@ -183,26 +183,20 @@ quic-go logs a wide range of events defined in [draft-ietf-quic-qlog-quic-events
|
|||
|
||||
qlog files can be processed by a number of 3rd-party tools. [qviz](https://qvis.quictools.info/) has proven very useful for debugging all kinds of QUIC connection failures.
|
||||
|
||||
qlog is activated by setting a `Tracer` callback on the `Config`. It is called as soon as quic-go decides to starts the QUIC handshake on a new connection.
|
||||
A useful implementation of this callback could look like this:
|
||||
qlog can be activated by setting the `Tracer` callback on the `Config`. It is called as soon as quic-go decides to start the QUIC handshake on a new connection.
|
||||
`qlog.DefaultTracer` provides a tracer implementation which writes qlog files to a directory specified by the `QLOGDIR` environment variable, if set.
|
||||
The default qlog tracer can be used like this:
|
||||
```go
|
||||
quic.Config{
|
||||
Tracer: func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
||||
role := "server"
|
||||
if p == logging.PerspectiveClient {
|
||||
role = "client"
|
||||
}
|
||||
filename := fmt.Sprintf("./log_%s_%s.qlog", connID, role)
|
||||
f, err := os.Create(filename)
|
||||
// handle the error
|
||||
return qlog.NewConnectionTracer(f, p, connID)
|
||||
}
|
||||
Tracer: qlog.DefaultTracer,
|
||||
}
|
||||
```
|
||||
|
||||
This implementation of the callback creates a new qlog file in the current directory named `log_<client / server>_<QUIC connection ID>.qlog`.
|
||||
This example creates a new qlog file under `<QLOGDIR>/<Original Destination Connection ID>_<Vantage Point>.qlog`, e.g. `qlogs/2e0407da_client.qlog`.
|
||||
|
||||
|
||||
For custom qlog behavior, `qlog.NewConnectionTracer` can be used.
|
||||
|
||||
## Using HTTP/3
|
||||
|
||||
### As a server
|
||||
|
@ -232,11 +226,13 @@ http.Client{
|
|||
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) |
|
||||
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) |
|
||||
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) |
|
||||
| [frp](https://github.com/fatedier/frp) | A fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet | ![GitHub Repo stars](https://img.shields.io/github/stars/fatedier/frp?style=flat-square) |
|
||||
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) |
|
||||
| [gost](https://github.com/go-gost/gost) | A simple security tunnel written in Go | ![GitHub Repo stars](https://img.shields.io/github/stars/go-gost/gost?style=flat-square) |
|
||||
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) |
|
||||
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
|
||||
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
|
||||
| [RoadRunner](https://github.com/roadrunner-server/roadrunner) | High-performance PHP application server, process manager written in Go and powered with plugins | ![GitHub Repo stars](https://img.shields.io/github/stars/roadrunner-server/roadrunner?style=flat-square) |
|
||||
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) |
|
||||
| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) |
|
||||
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) |
|
||||
|
|
|
@ -28,7 +28,7 @@ type client struct {
|
|||
|
||||
initialPacketNumber protocol.PacketNumber
|
||||
hasNegotiatedVersion bool
|
||||
version protocol.VersionNumber
|
||||
version protocol.Version
|
||||
|
||||
handshakeChan chan struct{}
|
||||
|
||||
|
@ -232,7 +232,7 @@ func (c *client) dial(ctx context.Context) error {
|
|||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.conn.shutdown()
|
||||
c.conn.destroy(nil)
|
||||
return context.Cause(ctx)
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"math/bits"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -12,9 +11,8 @@ import (
|
|||
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
|
||||
// with an exponential backoff.
|
||||
type closedLocalConn struct {
|
||||
counter uint32
|
||||
perspective protocol.Perspective
|
||||
logger utils.Logger
|
||||
counter uint32
|
||||
logger utils.Logger
|
||||
|
||||
sendPacket func(net.Addr, packetInfo)
|
||||
}
|
||||
|
@ -22,11 +20,10 @@ type closedLocalConn struct {
|
|||
var _ packetHandler = &closedLocalConn{}
|
||||
|
||||
// newClosedLocalConn creates a new closedLocalConn and runs it.
|
||||
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
|
||||
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logger) packetHandler {
|
||||
return &closedLocalConn{
|
||||
sendPacket: sendPacket,
|
||||
perspective: pers,
|
||||
logger: logger,
|
||||
sendPacket: sendPacket,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -41,24 +38,20 @@ func (c *closedLocalConn) handlePacket(p receivedPacket) {
|
|||
c.sendPacket(p.remoteAddr, p.info)
|
||||
}
|
||||
|
||||
func (c *closedLocalConn) shutdown() {}
|
||||
func (c *closedLocalConn) destroy(error) {}
|
||||
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
|
||||
func (c *closedLocalConn) destroy(error) {}
|
||||
func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {}
|
||||
|
||||
// A closedRemoteConn is a connection that was closed remotely.
|
||||
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
|
||||
// We can just ignore those packets.
|
||||
type closedRemoteConn struct {
|
||||
perspective protocol.Perspective
|
||||
}
|
||||
type closedRemoteConn struct{}
|
||||
|
||||
var _ packetHandler = &closedRemoteConn{}
|
||||
|
||||
func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
|
||||
return &closedRemoteConn{perspective: pers}
|
||||
func newClosedRemoteConn() packetHandler {
|
||||
return &closedRemoteConn{}
|
||||
}
|
||||
|
||||
func (s *closedRemoteConn) handlePacket(receivedPacket) {}
|
||||
func (s *closedRemoteConn) shutdown() {}
|
||||
func (s *closedRemoteConn) destroy(error) {}
|
||||
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }
|
||||
func (c *closedRemoteConn) handlePacket(receivedPacket) {}
|
||||
func (c *closedRemoteConn) destroy(error) {}
|
||||
func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {}
|
||||
|
|
|
@ -5,6 +5,8 @@ coverage:
|
|||
- interop/
|
||||
- internal/handshake/cipher_suite.go
|
||||
- internal/utils/linkedlist/linkedlist.go
|
||||
- internal/testdata
|
||||
- testutils/
|
||||
- fuzzing/
|
||||
- metrics/
|
||||
status:
|
||||
|
|
|
@ -2,7 +2,6 @@ package quic
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -49,16 +48,6 @@ func validateConfig(config *Config) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateServerConfig(config *Config) *Config {
|
||||
config = populateConfig(config)
|
||||
if config.RequireAddressValidation == nil {
|
||||
config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
// populateConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateConfig(config *Config) *Config {
|
||||
|
@ -111,7 +100,6 @@ func populateConfig(config *Config) *Config {
|
|||
Versions: versions,
|
||||
HandshakeIdleTimeout: handshakeIdleTimeout,
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
RequireAddressValidation: config.RequireAddressValidation,
|
||||
KeepAlivePeriod: config.KeepAlivePeriod,
|
||||
InitialStreamReceiveWindow: initialStreamReceiveWindow,
|
||||
MaxStreamReceiveWindow: maxStreamReceiveWindow,
|
||||
|
|
|
@ -19,7 +19,7 @@ type connIDGenerator struct {
|
|||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
|
||||
removeConnectionID func(protocol.ConnectionID)
|
||||
retireConnectionID func(protocol.ConnectionID)
|
||||
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte)
|
||||
replaceWithClosed func([]protocol.ConnectionID, []byte)
|
||||
queueControlFrame func(wire.Frame)
|
||||
}
|
||||
|
||||
|
@ -30,7 +30,7 @@ func newConnIDGenerator(
|
|||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
|
||||
removeConnectionID func(protocol.ConnectionID),
|
||||
retireConnectionID func(protocol.ConnectionID),
|
||||
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
|
||||
replaceWithClosed func([]protocol.ConnectionID, []byte),
|
||||
queueControlFrame func(wire.Frame),
|
||||
generator ConnectionIDGenerator,
|
||||
) *connIDGenerator {
|
||||
|
@ -126,7 +126,7 @@ func (m *connIDGenerator) RemoveAll() {
|
|||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) {
|
||||
func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
|
||||
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
|
||||
if m.initialClientDestConnID != nil {
|
||||
connIDs = append(connIDs, *m.initialClientDestConnID)
|
||||
|
@ -134,5 +134,5 @@ func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose
|
|||
for _, connID := range m.activeSrcConnIDs {
|
||||
connIDs = append(connIDs, connID)
|
||||
}
|
||||
m.replaceWithClosed(connIDs, pers, connClose)
|
||||
m.replaceWithClosed(connIDs, connClose)
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ import (
|
|||
)
|
||||
|
||||
type unpacker interface {
|
||||
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error)
|
||||
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error)
|
||||
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
|
||||
}
|
||||
|
||||
|
@ -93,7 +93,7 @@ type connRunner interface {
|
|||
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
|
||||
Retire(protocol.ConnectionID)
|
||||
Remove(protocol.ConnectionID)
|
||||
ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte)
|
||||
ReplaceWithClosed([]protocol.ConnectionID, []byte)
|
||||
AddResetToken(protocol.StatelessResetToken, packetHandler)
|
||||
RemoveResetToken(protocol.StatelessResetToken)
|
||||
}
|
||||
|
@ -106,7 +106,7 @@ type closeError struct {
|
|||
|
||||
type errCloseForRecreating struct {
|
||||
nextPacketNumber protocol.PacketNumber
|
||||
nextVersion protocol.VersionNumber
|
||||
nextVersion protocol.Version
|
||||
}
|
||||
|
||||
func (e *errCloseForRecreating) Error() string {
|
||||
|
@ -128,7 +128,7 @@ type connection struct {
|
|||
srcConnIDLen int
|
||||
|
||||
perspective protocol.Perspective
|
||||
version protocol.VersionNumber
|
||||
version protocol.Version
|
||||
config *Config
|
||||
|
||||
conn sendConn
|
||||
|
@ -177,6 +177,7 @@ type connection struct {
|
|||
|
||||
earlyConnReadyChan chan struct{}
|
||||
sentFirstPacket bool
|
||||
droppedInitialKeys bool
|
||||
handshakeComplete bool
|
||||
handshakeConfirmed bool
|
||||
|
||||
|
@ -235,7 +236,7 @@ var newConnection = func(
|
|||
tracer *logging.ConnectionTracer,
|
||||
tracingID uint64,
|
||||
logger utils.Logger,
|
||||
v protocol.VersionNumber,
|
||||
v protocol.Version,
|
||||
) quicConn {
|
||||
s := &connection{
|
||||
conn: conn,
|
||||
|
@ -348,7 +349,7 @@ var newClientConnection = func(
|
|||
tracer *logging.ConnectionTracer,
|
||||
tracingID uint64,
|
||||
logger utils.Logger,
|
||||
v protocol.VersionNumber,
|
||||
v protocol.Version,
|
||||
) quicConn {
|
||||
s := &connection{
|
||||
conn: conn,
|
||||
|
@ -453,7 +454,7 @@ func (s *connection) preSetup() {
|
|||
s.handshakeStream = newCryptoStream()
|
||||
s.sendQueue = newSendQueue(s.conn)
|
||||
s.retransmissionQueue = newRetransmissionQueue()
|
||||
s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams)
|
||||
s.frameParser = *wire.NewFrameParser(s.config.EnableDatagrams)
|
||||
s.rttStats = &utils.RTTStats{}
|
||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||
protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
|
||||
|
@ -520,6 +521,9 @@ func (s *connection) run() error {
|
|||
|
||||
runLoop:
|
||||
for {
|
||||
if s.framer.QueuedTooManyControlFrames() {
|
||||
s.closeLocal(&qerr.TransportError{ErrorCode: InternalError})
|
||||
}
|
||||
// Close immediately if requested
|
||||
select {
|
||||
case closeErr = <-s.closeChan:
|
||||
|
@ -1148,7 +1152,7 @@ func (s *connection) handleUnpackedLongHeaderPacket(
|
|||
if !s.receivedFirstPacket {
|
||||
s.receivedFirstPacket = true
|
||||
if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil {
|
||||
var clientVersions, serverVersions []protocol.VersionNumber
|
||||
var clientVersions, serverVersions []protocol.Version
|
||||
switch s.perspective {
|
||||
case protocol.PerspectiveClient:
|
||||
clientVersions = s.config.Versions
|
||||
|
@ -1185,7 +1189,8 @@ func (s *connection) handleUnpackedLongHeaderPacket(
|
|||
}
|
||||
}
|
||||
|
||||
if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake {
|
||||
if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake &&
|
||||
!s.droppedInitialKeys {
|
||||
// On the server side, Initial keys are dropped as soon as the first Handshake packet is received.
|
||||
// See Section 4.9.1 of RFC 9001.
|
||||
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
|
||||
|
@ -1572,13 +1577,6 @@ func (s *connection) closeRemote(e error) {
|
|||
})
|
||||
}
|
||||
|
||||
// Close the connection. It sends a NO_ERROR application error.
|
||||
// It waits until the run loop has stopped before returning
|
||||
func (s *connection) shutdown() {
|
||||
s.closeLocal(nil)
|
||||
<-s.ctx.Done()
|
||||
}
|
||||
|
||||
func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error {
|
||||
s.closeLocal(&qerr.ApplicationError{
|
||||
ErrorCode: code,
|
||||
|
@ -1588,6 +1586,11 @@ func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) erro
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *connection) closeWithTransportError(code TransportErrorCode) {
|
||||
s.closeLocal(&qerr.TransportError{ErrorCode: code})
|
||||
<-s.ctx.Done()
|
||||
}
|
||||
|
||||
func (s *connection) handleCloseError(closeErr *closeError) {
|
||||
e := closeErr.err
|
||||
if e == nil {
|
||||
|
@ -1632,7 +1635,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
|
|||
|
||||
// If this is a remote close we're done here
|
||||
if closeErr.remote {
|
||||
s.connIDGenerator.ReplaceWithClosed(s.perspective, nil)
|
||||
s.connIDGenerator.ReplaceWithClosed(nil)
|
||||
return
|
||||
}
|
||||
if closeErr.immediate {
|
||||
|
@ -1649,7 +1652,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
|
|||
if err != nil {
|
||||
s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err)
|
||||
}
|
||||
s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket)
|
||||
s.connIDGenerator.ReplaceWithClosed(connClosePacket)
|
||||
}
|
||||
|
||||
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error {
|
||||
|
@ -1661,6 +1664,7 @@ func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) erro
|
|||
//nolint:exhaustive // only Initial and 0-RTT need special treatment
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
s.droppedInitialKeys = true
|
||||
s.cryptoStreamHandler.DiscardInitialKeys()
|
||||
case protocol.Encryption0RTT:
|
||||
s.streamsMap.ResetFor0RTT()
|
||||
|
@ -2077,7 +2081,8 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn prot
|
|||
largestAcked = p.ack.LargestAcked()
|
||||
}
|
||||
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false)
|
||||
if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake {
|
||||
if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake &&
|
||||
!s.droppedInitialKeys {
|
||||
// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
|
||||
// See Section 4.9.1 of RFC 9001.
|
||||
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
|
||||
|
@ -2377,11 +2382,7 @@ func (s *connection) RemoteAddr() net.Addr {
|
|||
return s.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (s *connection) getPerspective() protocol.Perspective {
|
||||
return s.perspective
|
||||
}
|
||||
|
||||
func (s *connection) GetVersion() protocol.VersionNumber {
|
||||
func (s *connection) GetVersion() protocol.Version {
|
||||
return s.version
|
||||
}
|
||||
|
||||
|
|
|
@ -15,15 +15,25 @@ type framer interface {
|
|||
HasData() bool
|
||||
|
||||
QueueControlFrame(wire.Frame)
|
||||
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
|
||||
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)
|
||||
|
||||
AddActiveStream(protocol.StreamID)
|
||||
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
|
||||
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)
|
||||
|
||||
Handle0RTTRejection() error
|
||||
|
||||
// QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length.
|
||||
// This is a hack.
|
||||
// It is easier to implement than propagating an error return value in QueueControlFrame.
|
||||
// The correct solution would be to queue frames with their respective structs.
|
||||
// See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames.
|
||||
QueuedTooManyControlFrames() bool
|
||||
}
|
||||
|
||||
const maxPathResponses = 256
|
||||
const (
|
||||
maxPathResponses = 256
|
||||
maxControlFrames = 16 << 10
|
||||
)
|
||||
|
||||
type framerI struct {
|
||||
mutex sync.Mutex
|
||||
|
@ -33,9 +43,10 @@ type framerI struct {
|
|||
activeStreams map[protocol.StreamID]struct{}
|
||||
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
|
||||
|
||||
controlFrameMutex sync.Mutex
|
||||
controlFrames []wire.Frame
|
||||
pathResponses []*wire.PathResponseFrame
|
||||
controlFrameMutex sync.Mutex
|
||||
controlFrames []wire.Frame
|
||||
pathResponses []*wire.PathResponseFrame
|
||||
queuedTooManyControlFrames bool
|
||||
}
|
||||
|
||||
var _ framer = &framerI{}
|
||||
|
@ -73,10 +84,15 @@ func (f *framerI) QueueControlFrame(frame wire.Frame) {
|
|||
f.pathResponses = append(f.pathResponses, pr)
|
||||
return
|
||||
}
|
||||
// This is a hack.
|
||||
if len(f.controlFrames) >= maxControlFrames {
|
||||
f.queuedTooManyControlFrames = true
|
||||
return
|
||||
}
|
||||
f.controlFrames = append(f.controlFrames, frame)
|
||||
}
|
||||
|
||||
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) {
|
||||
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) {
|
||||
f.controlFrameMutex.Lock()
|
||||
defer f.controlFrameMutex.Unlock()
|
||||
|
||||
|
@ -105,6 +121,10 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol
|
|||
return frames, length
|
||||
}
|
||||
|
||||
func (f *framerI) QueuedTooManyControlFrames() bool {
|
||||
return f.queuedTooManyControlFrames
|
||||
}
|
||||
|
||||
func (f *framerI) AddActiveStream(id protocol.StreamID) {
|
||||
f.mutex.Lock()
|
||||
if _, ok := f.activeStreams[id]; !ok {
|
||||
|
@ -114,7 +134,7 @@ func (f *framerI) AddActiveStream(id protocol.StreamID) {
|
|||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) {
|
||||
func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) {
|
||||
startLen := len(frames)
|
||||
var length protocol.ByteCount
|
||||
f.mutex.Lock()
|
||||
|
|
|
@ -16,8 +16,12 @@ import (
|
|||
// The StreamID is the ID of a QUIC stream.
|
||||
type StreamID = protocol.StreamID
|
||||
|
||||
// A Version is a QUIC version number.
|
||||
type Version = protocol.Version
|
||||
|
||||
// A VersionNumber is a QUIC version number.
|
||||
type VersionNumber = protocol.VersionNumber
|
||||
// Deprecated: VersionNumber was renamed to Version.
|
||||
type VersionNumber = Version
|
||||
|
||||
const (
|
||||
// Version1 is RFC 9000
|
||||
|
@ -159,6 +163,9 @@ type Connection interface {
|
|||
OpenStream() (Stream, error)
|
||||
// OpenStreamSync opens a new bidirectional QUIC stream.
|
||||
// It blocks until a new stream can be opened.
|
||||
// There is no signaling to the peer about new streams:
|
||||
// The peer can only accept the stream after data has been sent on the stream,
|
||||
// or the stream has been reset or closed.
|
||||
// If the error is non-nil, it satisfies the net.Error interface.
|
||||
// If the connection was closed due to a timeout, Timeout() will be true.
|
||||
OpenStreamSync(context.Context) (Stream, error)
|
||||
|
@ -255,7 +262,7 @@ type Config struct {
|
|||
GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
|
||||
// The QUIC versions that can be negotiated.
|
||||
// If not set, it uses all versions available.
|
||||
Versions []VersionNumber
|
||||
Versions []Version
|
||||
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
|
||||
// If we don't receive any packet from the peer within this time, the connection attempt is aborted.
|
||||
// Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted.
|
||||
|
@ -267,11 +274,6 @@ type Config struct {
|
|||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 30 seconds.
|
||||
MaxIdleTimeout time.Duration
|
||||
// RequireAddressValidation determines if a QUIC Retry packet is sent.
|
||||
// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
|
||||
// If not set, every client is forced to prove its remote address.
|
||||
RequireAddressValidation func(net.Addr) bool
|
||||
// The TokenStore stores tokens received from the server.
|
||||
// Tokens are used to skip address validation on future connection attempts.
|
||||
// The key used to store tokens is the ServerName from the tls.Config, if set
|
||||
|
@ -331,8 +333,15 @@ type Config struct {
|
|||
Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer
|
||||
}
|
||||
|
||||
// ClientHelloInfo contains information about an incoming connection attempt.
|
||||
type ClientHelloInfo struct {
|
||||
// RemoteAddr is the remote address on the Initial packet.
|
||||
// Unless AddrVerified is set, the address is not yet verified, and could be a spoofed IP address.
|
||||
RemoteAddr net.Addr
|
||||
// AddrVerified says if the remote address was verified using QUIC's Retry mechanism.
|
||||
// Note that the Retry mechanism costs one network roundtrip,
|
||||
// and is not performed unless Transport.MaxUnvalidatedHandshakes is surpassed.
|
||||
AddrVerified bool
|
||||
}
|
||||
|
||||
// ConnectionState records basic details about a QUIC connection
|
||||
|
@ -347,7 +356,7 @@ type ConnectionState struct {
|
|||
// Used0RTT says if 0-RTT resumption was used.
|
||||
Used0RTT bool
|
||||
// Version is the QUIC version of the QUIC connection.
|
||||
Version VersionNumber
|
||||
Version Version
|
||||
// GSO says if generic segmentation offload is used
|
||||
GSO bool
|
||||
}
|
||||
|
|
|
@ -20,5 +20,5 @@ func NewAckHandler(
|
|||
logger utils.Logger,
|
||||
) (SentPacketHandler, ReceivedPacketHandler) {
|
||||
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger)
|
||||
return sph, newReceivedPacketHandler(sph, rttStats, logger)
|
||||
return sph, newReceivedPacketHandler(sph, logger)
|
||||
}
|
||||
|
|
39
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
39
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
|
@ -14,23 +14,19 @@ type receivedPacketHandler struct {
|
|||
|
||||
initialPackets *receivedPacketTracker
|
||||
handshakePackets *receivedPacketTracker
|
||||
appDataPackets *receivedPacketTracker
|
||||
appDataPackets appDataReceivedPacketTracker
|
||||
|
||||
lowest1RTTPacket protocol.PacketNumber
|
||||
}
|
||||
|
||||
var _ ReceivedPacketHandler = &receivedPacketHandler{}
|
||||
|
||||
func newReceivedPacketHandler(
|
||||
sentPackets sentPacketTracker,
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) ReceivedPacketHandler {
|
||||
func newReceivedPacketHandler(sentPackets sentPacketTracker, logger utils.Logger) ReceivedPacketHandler {
|
||||
return &receivedPacketHandler{
|
||||
sentPackets: sentPackets,
|
||||
initialPackets: newReceivedPacketTracker(rttStats, logger),
|
||||
handshakePackets: newReceivedPacketTracker(rttStats, logger),
|
||||
appDataPackets: newReceivedPacketTracker(rttStats, logger),
|
||||
initialPackets: newReceivedPacketTracker(),
|
||||
handshakePackets: newReceivedPacketTracker(),
|
||||
appDataPackets: *newAppDataReceivedPacketTracker(logger),
|
||||
lowest1RTTPacket: protocol.InvalidPacketNumber,
|
||||
}
|
||||
}
|
||||
|
@ -88,41 +84,28 @@ func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
|
|||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
|
||||
var initialAlarm, handshakeAlarm time.Time
|
||||
if h.initialPackets != nil {
|
||||
initialAlarm = h.initialPackets.GetAlarmTimeout()
|
||||
}
|
||||
if h.handshakePackets != nil {
|
||||
handshakeAlarm = h.handshakePackets.GetAlarmTimeout()
|
||||
}
|
||||
oneRTTAlarm := h.appDataPackets.GetAlarmTimeout()
|
||||
return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
|
||||
return h.appDataPackets.GetAlarmTimeout()
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame {
|
||||
var ack *wire.AckFrame
|
||||
//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
if h.initialPackets != nil {
|
||||
ack = h.initialPackets.GetAckFrame(onlyIfQueued)
|
||||
return h.initialPackets.GetAckFrame()
|
||||
}
|
||||
return nil
|
||||
case protocol.EncryptionHandshake:
|
||||
if h.handshakePackets != nil {
|
||||
ack = h.handshakePackets.GetAckFrame(onlyIfQueued)
|
||||
return h.handshakePackets.GetAckFrame()
|
||||
}
|
||||
return nil
|
||||
case protocol.Encryption1RTT:
|
||||
// 0-RTT packets can't contain ACK frames
|
||||
return h.appDataPackets.GetAckFrame(onlyIfQueued)
|
||||
default:
|
||||
// 0-RTT packets can't contain ACK frames
|
||||
return nil
|
||||
}
|
||||
// For Initial and Handshake ACKs, the delay time is ignored by the receiver.
|
||||
// Set it to 0 in order to save bytes.
|
||||
if ack != nil {
|
||||
ack.DelayTime = 0
|
||||
}
|
||||
return ack
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {
|
||||
|
|
208
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go
generated
vendored
208
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go
generated
vendored
|
@ -9,40 +9,19 @@ import (
|
|||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// number of ack-eliciting packets received before sending an ack.
|
||||
const packetsBeforeAck = 2
|
||||
|
||||
// The receivedPacketTracker tracks packets for the Initial and Handshake packet number space.
|
||||
// Every received packet is acknowledged immediately.
|
||||
type receivedPacketTracker struct {
|
||||
largestObserved protocol.PacketNumber
|
||||
ignoreBelow protocol.PacketNumber
|
||||
largestObservedRcvdTime time.Time
|
||||
ect0, ect1, ecnce uint64
|
||||
ect0, ect1, ecnce uint64
|
||||
|
||||
packetHistory *receivedPacketHistory
|
||||
|
||||
maxAckDelay time.Duration
|
||||
rttStats *utils.RTTStats
|
||||
packetHistory receivedPacketHistory
|
||||
|
||||
lastAck *wire.AckFrame
|
||||
hasNewAck bool // true as soon as we received an ack-eliciting new packet
|
||||
ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets
|
||||
|
||||
ackElicitingPacketsReceivedSinceLastAck int
|
||||
ackAlarm time.Time
|
||||
lastAck *wire.AckFrame
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newReceivedPacketTracker(
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) *receivedPacketTracker {
|
||||
return &receivedPacketTracker{
|
||||
packetHistory: newReceivedPacketHistory(),
|
||||
maxAckDelay: protocol.MaxAckDelay,
|
||||
rttStats: rttStats,
|
||||
logger: logger,
|
||||
}
|
||||
func newReceivedPacketTracker() *receivedPacketTracker {
|
||||
return &receivedPacketTracker{packetHistory: *newReceivedPacketHistory()}
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
|
||||
|
@ -50,16 +29,6 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
|
|||
return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn)
|
||||
}
|
||||
|
||||
isMissing := h.isMissing(pn)
|
||||
if pn >= h.largestObserved {
|
||||
h.largestObserved = pn
|
||||
h.largestObservedRcvdTime = rcvTime
|
||||
}
|
||||
|
||||
if ackEliciting {
|
||||
h.hasNewAck = true
|
||||
h.maybeQueueACK(pn, rcvTime, ecn, isMissing)
|
||||
}
|
||||
//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE.
|
||||
switch ecn {
|
||||
case protocol.ECT0:
|
||||
|
@ -69,12 +38,99 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
|
|||
case protocol.ECNCE:
|
||||
h.ecnce++
|
||||
}
|
||||
if !ackEliciting {
|
||||
return nil
|
||||
}
|
||||
h.hasNewAck = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame {
|
||||
if !h.hasNewAck {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This function always returns the same ACK frame struct, filled with the most recent values.
|
||||
ack := h.lastAck
|
||||
if ack == nil {
|
||||
ack = &wire.AckFrame{}
|
||||
}
|
||||
ack.Reset()
|
||||
ack.ECT0 = h.ect0
|
||||
ack.ECT1 = h.ect1
|
||||
ack.ECNCE = h.ecnce
|
||||
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
|
||||
|
||||
h.lastAck = ack
|
||||
h.hasNewAck = false
|
||||
return ack
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
|
||||
return h.packetHistory.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
|
||||
// number of ack-eliciting packets received before sending an ACK
|
||||
const packetsBeforeAck = 2
|
||||
|
||||
// The appDataReceivedPacketTracker tracks packets received in the Application Data packet number space.
|
||||
// It waits until at least 2 packets were received before queueing an ACK, or until the max_ack_delay was reached.
|
||||
type appDataReceivedPacketTracker struct {
|
||||
receivedPacketTracker
|
||||
|
||||
largestObservedRcvdTime time.Time
|
||||
|
||||
largestObserved protocol.PacketNumber
|
||||
ignoreBelow protocol.PacketNumber
|
||||
|
||||
maxAckDelay time.Duration
|
||||
ackQueued bool // true if we need send a new ACK
|
||||
|
||||
ackElicitingPacketsReceivedSinceLastAck int
|
||||
ackAlarm time.Time
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newAppDataReceivedPacketTracker(logger utils.Logger) *appDataReceivedPacketTracker {
|
||||
h := &appDataReceivedPacketTracker{
|
||||
receivedPacketTracker: *newReceivedPacketTracker(),
|
||||
maxAckDelay: protocol.MaxAckDelay,
|
||||
logger: logger,
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *appDataReceivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
|
||||
if err := h.receivedPacketTracker.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil {
|
||||
return err
|
||||
}
|
||||
if pn >= h.largestObserved {
|
||||
h.largestObserved = pn
|
||||
h.largestObservedRcvdTime = rcvTime
|
||||
}
|
||||
if !ackEliciting {
|
||||
return nil
|
||||
}
|
||||
h.ackElicitingPacketsReceivedSinceLastAck++
|
||||
isMissing := h.isMissing(pn)
|
||||
if !h.ackQueued && h.shouldQueueACK(pn, ecn, isMissing) {
|
||||
h.ackQueued = true
|
||||
h.ackAlarm = time.Time{} // cancel the ack alarm
|
||||
}
|
||||
if !h.ackQueued {
|
||||
// No ACK queued, but we'll need to acknowledge the packet after max_ack_delay.
|
||||
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IgnoreBelow sets a lower limit for acknowledging packets.
|
||||
// Packets with packet numbers smaller than p will not be acked.
|
||||
func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
|
||||
func (h *appDataReceivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
|
||||
if pn <= h.ignoreBelow {
|
||||
return
|
||||
}
|
||||
|
@ -86,14 +142,14 @@ func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
|
|||
}
|
||||
|
||||
// isMissing says if a packet was reported missing in the last ACK.
|
||||
func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
|
||||
func (h *appDataReceivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
|
||||
if h.lastAck == nil || p < h.ignoreBelow {
|
||||
return false
|
||||
}
|
||||
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) hasNewMissingPackets() bool {
|
||||
func (h *appDataReceivedPacketTracker) hasNewMissingPackets() bool {
|
||||
if h.lastAck == nil {
|
||||
return false
|
||||
}
|
||||
|
@ -101,31 +157,21 @@ func (h *receivedPacketTracker) hasNewMissingPackets() bool {
|
|||
return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1
|
||||
}
|
||||
|
||||
// maybeQueueACK queues an ACK, if necessary.
|
||||
func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime time.Time, ecn protocol.ECN, wasMissing bool) {
|
||||
func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool {
|
||||
// always acknowledge the first packet
|
||||
if h.lastAck == nil {
|
||||
if !h.ackQueued {
|
||||
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
|
||||
}
|
||||
h.ackQueued = true
|
||||
return
|
||||
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
|
||||
return true
|
||||
}
|
||||
|
||||
if h.ackQueued {
|
||||
return
|
||||
}
|
||||
|
||||
h.ackElicitingPacketsReceivedSinceLastAck++
|
||||
|
||||
// Send an ACK if this packet was reported missing in an ACK sent before.
|
||||
// Ack decimation with reordering relies on the timer to send an ACK, but if
|
||||
// missing packets we reported in the previous ack, send an ACK immediately.
|
||||
// missing packets we reported in the previous ACK, send an ACK immediately.
|
||||
if wasMissing {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn)
|
||||
}
|
||||
h.ackQueued = true
|
||||
return true
|
||||
}
|
||||
|
||||
// send an ACK every 2 ack-eliciting packets
|
||||
|
@ -133,68 +179,42 @@ func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime
|
|||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck)
|
||||
}
|
||||
h.ackQueued = true
|
||||
} else if h.ackAlarm.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
|
||||
}
|
||||
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
|
||||
return true
|
||||
}
|
||||
|
||||
// queue an ACK if there are new missing packets to report
|
||||
if h.hasNewMissingPackets() {
|
||||
h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.")
|
||||
h.ackQueued = true
|
||||
return true
|
||||
}
|
||||
|
||||
// queue an ACK if the packet was ECN-CE marked
|
||||
if ecn == protocol.ECNCE {
|
||||
h.logger.Debugf("\tQueuing ACK because the packet was ECN-CE marked.")
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
if h.ackQueued {
|
||||
// cancel the ack alarm
|
||||
h.ackAlarm = time.Time{}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
|
||||
if !h.hasNewAck {
|
||||
return nil
|
||||
}
|
||||
func (h *appDataReceivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
|
||||
now := time.Now()
|
||||
if onlyIfQueued {
|
||||
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
|
||||
if onlyIfQueued && !h.ackQueued {
|
||||
if h.ackAlarm.IsZero() || h.ackAlarm.After(now) {
|
||||
return nil
|
||||
}
|
||||
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
|
||||
if h.logger.Debug() && !h.ackAlarm.IsZero() {
|
||||
h.logger.Debugf("Sending ACK because the ACK timer expired.")
|
||||
}
|
||||
}
|
||||
|
||||
// This function always returns the same ACK frame struct, filled with the most recent values.
|
||||
ack := h.lastAck
|
||||
ack := h.receivedPacketTracker.GetAckFrame()
|
||||
if ack == nil {
|
||||
ack = &wire.AckFrame{}
|
||||
return nil
|
||||
}
|
||||
ack.Reset()
|
||||
ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime))
|
||||
ack.ECT0 = h.ect0
|
||||
ack.ECT1 = h.ect1
|
||||
ack.ECNCE = h.ecnce
|
||||
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
|
||||
|
||||
h.lastAck = ack
|
||||
h.ackAlarm = time.Time{}
|
||||
h.ackQueued = false
|
||||
h.hasNewAck = false
|
||||
h.ackAlarm = time.Time{}
|
||||
h.ackElicitingPacketsReceivedSinceLastAck = 0
|
||||
return ack
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }
|
||||
|
||||
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
|
||||
return h.packetHistory.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
func (h *appDataReceivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }
|
||||
|
|
|
@ -48,10 +48,12 @@ func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
|||
}
|
||||
|
||||
// UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame.
|
||||
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
|
||||
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) (updated bool) {
|
||||
if offset > c.sendWindow {
|
||||
c.sendWindow = offset
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
|
||||
|
|
|
@ -5,7 +5,7 @@ import "github.com/quic-go/quic-go/internal/protocol"
|
|||
type flowController interface {
|
||||
// for sending
|
||||
SendWindowSize() protocol.ByteCount
|
||||
UpdateSendWindow(protocol.ByteCount)
|
||||
UpdateSendWindow(protocol.ByteCount) (updated bool)
|
||||
AddBytesSent(protocol.ByteCount)
|
||||
// for receiving
|
||||
AddBytesRead(protocol.ByteCount)
|
||||
|
@ -16,12 +16,11 @@ type flowController interface {
|
|||
// A StreamFlowController is a flow controller for a QUIC stream.
|
||||
type StreamFlowController interface {
|
||||
flowController
|
||||
// for receiving
|
||||
// UpdateHighestReceived should be called when a new highest offset is received
|
||||
// UpdateHighestReceived is called when a new highest offset is received
|
||||
// final has to be to true if this is the final offset of the stream,
|
||||
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
|
||||
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
|
||||
// Abandon should be called when reading from the stream is aborted early,
|
||||
// Abandon is called when reading from the stream is aborted early,
|
||||
// and there won't be any further calls to AddBytesRead.
|
||||
Abandon()
|
||||
}
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
|
||||
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.Version) *xorNonceAEAD {
|
||||
keyLabel := hkdfLabelKeyV1
|
||||
ivLabel := hkdfLabelIVV1
|
||||
if v == protocol.Version2 {
|
||||
|
@ -20,28 +19,26 @@ func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumb
|
|||
}
|
||||
|
||||
type longHeaderSealer struct {
|
||||
aead cipher.AEAD
|
||||
aead *xorNonceAEAD
|
||||
headerProtector headerProtector
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
nonceBuf [8]byte
|
||||
}
|
||||
|
||||
var _ LongHeaderSealer = &longHeaderSealer{}
|
||||
|
||||
func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
|
||||
func newLongHeaderSealer(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderSealer {
|
||||
if aead.NonceSize() != 8 {
|
||||
panic("unexpected nonce size")
|
||||
}
|
||||
return &longHeaderSealer{
|
||||
aead: aead,
|
||||
headerProtector: headerProtector,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
||||
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
return s.aead.Seal(dst, s.nonceBuf, src, ad)
|
||||
binary.BigEndian.PutUint64(s.nonceBuf[:], uint64(pn))
|
||||
return s.aead.Seal(dst, s.nonceBuf[:], src, ad)
|
||||
}
|
||||
|
||||
func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
|
@ -53,21 +50,23 @@ func (s *longHeaderSealer) Overhead() int {
|
|||
}
|
||||
|
||||
type longHeaderOpener struct {
|
||||
aead cipher.AEAD
|
||||
aead *xorNonceAEAD
|
||||
headerProtector headerProtector
|
||||
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
// use a single array to avoid allocations
|
||||
nonceBuf [8]byte
|
||||
}
|
||||
|
||||
var _ LongHeaderOpener = &longHeaderOpener{}
|
||||
|
||||
func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
|
||||
func newLongHeaderOpener(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderOpener {
|
||||
if aead.NonceSize() != 8 {
|
||||
panic("unexpected nonce size")
|
||||
}
|
||||
return &longHeaderOpener{
|
||||
aead: aead,
|
||||
headerProtector: headerProtector,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -76,10 +75,8 @@ func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wire
|
|||
}
|
||||
|
||||
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
|
||||
binary.BigEndian.PutUint64(o.nonceBuf[:], uint64(pn))
|
||||
dec, err := o.aead.Open(dst, o.nonceBuf[:], src, ad)
|
||||
if err == nil {
|
||||
o.highestRcvdPN = max(o.highestRcvdPN, pn)
|
||||
} else {
|
||||
|
|
|
@ -18,7 +18,7 @@ type cipherSuite struct {
|
|||
ID uint16
|
||||
Hash crypto.Hash
|
||||
KeyLen int
|
||||
AEAD func(key, nonceMask []byte) cipher.AEAD
|
||||
AEAD func(key, nonceMask []byte) *xorNonceAEAD
|
||||
}
|
||||
|
||||
func (s cipherSuite) IVLen() int { return aeadNonceLength }
|
||||
|
@ -36,7 +36,7 @@ func getCipherSuite(id uint16) *cipherSuite {
|
|||
}
|
||||
}
|
||||
|
||||
func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
|
||||
func aeadAESGCMTLS13(key, nonceMask []byte) *xorNonceAEAD {
|
||||
if len(nonceMask) != aeadNonceLength {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
|
|||
return ret
|
||||
}
|
||||
|
||||
func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD {
|
||||
func aeadChaCha20Poly1305(key, nonceMask []byte) *xorNonceAEAD {
|
||||
if len(nonceMask) != aeadNonceLength {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
|
@ -33,7 +32,7 @@ type cryptoSetup struct {
|
|||
|
||||
events []Event
|
||||
|
||||
version protocol.VersionNumber
|
||||
version protocol.Version
|
||||
|
||||
ourParams *wire.TransportParameters
|
||||
peerParams *wire.TransportParameters
|
||||
|
@ -48,8 +47,6 @@ type cryptoSetup struct {
|
|||
|
||||
perspective protocol.Perspective
|
||||
|
||||
mutex sync.Mutex // protects all members below
|
||||
|
||||
handshakeCompleteTime time.Time
|
||||
|
||||
zeroRTTOpener LongHeaderOpener // only set for the server
|
||||
|
@ -79,7 +76,7 @@ func NewCryptoSetupClient(
|
|||
rttStats *utils.RTTStats,
|
||||
tracer *logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
version protocol.VersionNumber,
|
||||
version protocol.Version,
|
||||
) CryptoSetup {
|
||||
cs := newCryptoSetup(
|
||||
connID,
|
||||
|
@ -114,7 +111,7 @@ func NewCryptoSetupServer(
|
|||
rttStats *utils.RTTStats,
|
||||
tracer *logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
version protocol.VersionNumber,
|
||||
version protocol.Version,
|
||||
) CryptoSetup {
|
||||
cs := newCryptoSetup(
|
||||
connID,
|
||||
|
@ -172,7 +169,7 @@ func newCryptoSetup(
|
|||
tracer *logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
version protocol.VersionNumber,
|
||||
version protocol.Version,
|
||||
) *cryptoSetup {
|
||||
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
|
||||
if tracer != nil && tracer.UpdatedKeyFromTLS != nil {
|
||||
|
@ -269,10 +266,10 @@ func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) {
|
|||
case tls.QUICNoEvent:
|
||||
return true, nil
|
||||
case tls.QUICSetReadSecret:
|
||||
h.SetReadKey(ev.Level, ev.Suite, ev.Data)
|
||||
h.setReadKey(ev.Level, ev.Suite, ev.Data)
|
||||
return false, nil
|
||||
case tls.QUICSetWriteSecret:
|
||||
h.SetWriteKey(ev.Level, ev.Suite, ev.Data)
|
||||
h.setWriteKey(ev.Level, ev.Suite, ev.Data)
|
||||
return false, nil
|
||||
case tls.QUICTransportParameters:
|
||||
return false, h.handleTransportParameters(ev.Data)
|
||||
|
@ -434,19 +431,16 @@ func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bo
|
|||
func (h *cryptoSetup) rejected0RTT() {
|
||||
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
|
||||
|
||||
h.mutex.Lock()
|
||||
had0RTTKeys := h.zeroRTTSealer != nil
|
||||
h.zeroRTTSealer = nil
|
||||
h.mutex.Unlock()
|
||||
|
||||
if had0RTTKeys {
|
||||
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||
func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||
suite := getCipherSuite(suiteID)
|
||||
h.mutex.Lock()
|
||||
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
||||
switch el {
|
||||
case tls.QUICEncryptionLevelEarly:
|
||||
|
@ -478,16 +472,14 @@ func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra
|
|||
default:
|
||||
panic("unexpected read encryption level")
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
|
||||
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||
func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||
suite := getCipherSuite(suiteID)
|
||||
h.mutex.Lock()
|
||||
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
||||
switch el {
|
||||
case tls.QUICEncryptionLevelEarly:
|
||||
|
@ -498,7 +490,6 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
|
|||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
)
|
||||
h.mutex.Unlock()
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
|
@ -533,7 +524,6 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
|
|||
default:
|
||||
panic("unexpected write encryption level")
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
|
||||
}
|
||||
|
@ -555,11 +545,9 @@ func (h *cryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) {
|
|||
}
|
||||
|
||||
func (h *cryptoSetup) DiscardInitialKeys() {
|
||||
h.mutex.Lock()
|
||||
dropped := h.initialOpener != nil
|
||||
h.initialOpener = nil
|
||||
h.initialSealer = nil
|
||||
h.mutex.Unlock()
|
||||
if dropped {
|
||||
h.logger.Debugf("Dropping Initial keys.")
|
||||
}
|
||||
|
@ -574,22 +562,17 @@ func (h *cryptoSetup) SetHandshakeConfirmed() {
|
|||
h.aead.SetHandshakeConfirmed()
|
||||
// drop Handshake keys
|
||||
var dropped bool
|
||||
h.mutex.Lock()
|
||||
if h.handshakeOpener != nil {
|
||||
h.handshakeOpener = nil
|
||||
h.handshakeSealer = nil
|
||||
dropped = true
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
if dropped {
|
||||
h.logger.Debugf("Dropping Handshake keys.")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.initialSealer == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
|
@ -597,9 +580,6 @@ func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
|
|||
}
|
||||
|
||||
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.zeroRTTSealer == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
|
@ -607,9 +587,6 @@ func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
|
|||
}
|
||||
|
||||
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.handshakeSealer == nil {
|
||||
if h.initialSealer == nil {
|
||||
return nil, ErrKeysDropped
|
||||
|
@ -620,9 +597,6 @@ func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
|
|||
}
|
||||
|
||||
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if !h.has1RTTSealer {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
|
@ -630,9 +604,6 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
|
|||
}
|
||||
|
||||
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.initialOpener == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
|
@ -640,9 +611,6 @@ func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
|
|||
}
|
||||
|
||||
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.zeroRTTOpener == nil {
|
||||
if h.initialOpener != nil {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
|
@ -654,9 +622,6 @@ func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
|
|||
}
|
||||
|
||||
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.handshakeOpener == nil {
|
||||
if h.initialOpener != nil {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
|
@ -668,9 +633,6 @@ func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
|
|||
}
|
||||
|
||||
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
|
||||
h.zeroRTTOpener = nil
|
||||
h.logger.Debugf("Dropping 0-RTT keys.")
|
||||
|
|
|
@ -17,14 +17,14 @@ type headerProtector interface {
|
|||
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||
}
|
||||
|
||||
func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string {
|
||||
func hkdfHeaderProtectionLabel(v protocol.Version) string {
|
||||
if v == protocol.Version2 {
|
||||
return "quicv2 hp"
|
||||
}
|
||||
return "quic hp"
|
||||
}
|
||||
|
||||
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
|
||||
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.Version) headerProtector {
|
||||
hkdfLabel := hkdfHeaderProtectionLabel(v)
|
||||
switch suite.ID {
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
|
||||
|
@ -37,7 +37,7 @@ func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader b
|
|||
}
|
||||
|
||||
type aesHeaderProtector struct {
|
||||
mask []byte
|
||||
mask [16]byte // AES always has a 16 byte block size
|
||||
block cipher.Block
|
||||
isLongHeader bool
|
||||
}
|
||||
|
@ -52,7 +52,6 @@ func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeade
|
|||
}
|
||||
return &aesHeaderProtector{
|
||||
block: block,
|
||||
mask: make([]byte, block.BlockSize()),
|
||||
isLongHeader: isLongHeader,
|
||||
}
|
||||
}
|
||||
|
@ -69,7 +68,7 @@ func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []by
|
|||
if len(sample) != len(p.mask) {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
p.block.Encrypt(p.mask, sample)
|
||||
p.block.Encrypt(p.mask[:], sample)
|
||||
if p.isLongHeader {
|
||||
*firstByte ^= p.mask[0] & 0xf
|
||||
} else {
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// hkdfExpandLabel HKDF expands a label.
|
||||
// hkdfExpandLabel HKDF expands a label as defined in RFC 8446, section 7.1.
|
||||
// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
|
||||
// hkdfExpandLabel in the standard library.
|
||||
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
|
||||
|
|
|
@ -21,7 +21,7 @@ const (
|
|||
hkdfLabelIVV2 = "quicv2 iv"
|
||||
)
|
||||
|
||||
func getSalt(v protocol.VersionNumber) []byte {
|
||||
func getSalt(v protocol.Version) []byte {
|
||||
if v == protocol.Version2 {
|
||||
return quicSaltV2
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ func getSalt(v protocol.VersionNumber) []byte {
|
|||
var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256)
|
||||
|
||||
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
|
||||
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) {
|
||||
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.Version) (LongHeaderSealer, LongHeaderOpener) {
|
||||
clientSecret, serverSecret := computeSecrets(connID, v)
|
||||
var mySecret, otherSecret []byte
|
||||
if pers == protocol.PerspectiveClient {
|
||||
|
@ -51,14 +51,14 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p
|
|||
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))
|
||||
}
|
||||
|
||||
func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) {
|
||||
func computeSecrets(connID protocol.ConnectionID, v protocol.Version) (clientSecret, serverSecret []byte) {
|
||||
initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v))
|
||||
clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
|
||||
serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
|
||||
return
|
||||
}
|
||||
|
||||
func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) {
|
||||
func computeInitialKeyAndIV(secret []byte, v protocol.Version) (key, iv []byte) {
|
||||
keyLabel := hkdfLabelKeyV1
|
||||
ivLabel := hkdfLabelIVV1
|
||||
if v == protocol.Version2 {
|
||||
|
|
|
@ -40,7 +40,7 @@ var (
|
|||
)
|
||||
|
||||
// GetRetryIntegrityTag calculates the integrity tag on a Retry packet
|
||||
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte {
|
||||
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.Version) *[16]byte {
|
||||
retryMutex.Lock()
|
||||
defer retryMutex.Unlock()
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ type updatableAEAD struct {
|
|||
|
||||
tracer *logging.ConnectionTracer
|
||||
logger utils.Logger
|
||||
version protocol.VersionNumber
|
||||
version protocol.Version
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
|
@ -70,7 +70,7 @@ var (
|
|||
_ ShortHeaderSealer = &updatableAEAD{}
|
||||
)
|
||||
|
||||
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD {
|
||||
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.Version) *updatableAEAD {
|
||||
return &updatableAEAD{
|
||||
firstPacketNumber: protocol.InvalidPacketNumber,
|
||||
largestAcked: protocol.InvalidPacketNumber,
|
||||
|
@ -133,7 +133,7 @@ func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) {
|
|||
|
||||
// SetWriteKey sets the write key.
|
||||
// For the client, this function is called after SetReadKey.
|
||||
// For the server, this function is called before SetWriteKey.
|
||||
// For the server, this function is called before SetReadKey.
|
||||
func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
|
||||
a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
|
||||
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
|
||||
|
|
|
@ -17,9 +17,9 @@ func (p Perspective) Opposite() Perspective {
|
|||
func (p Perspective) String() string {
|
||||
switch p {
|
||||
case PerspectiveServer:
|
||||
return "Server"
|
||||
return "server"
|
||||
case PerspectiveClient:
|
||||
return "Client"
|
||||
return "client"
|
||||
default:
|
||||
return "invalid perspective"
|
||||
}
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
)
|
||||
|
||||
// VersionNumber is a version number as int
|
||||
type VersionNumber uint32
|
||||
// Version is a version number as int
|
||||
type Version uint32
|
||||
|
||||
// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
|
||||
const (
|
||||
|
@ -18,22 +21,22 @@ const (
|
|||
|
||||
// The version numbers, making grepping easier
|
||||
const (
|
||||
VersionUnknown VersionNumber = math.MaxUint32
|
||||
versionDraft29 VersionNumber = 0xff00001d // draft-29 used to be a widely deployed version
|
||||
Version1 VersionNumber = 0x1
|
||||
Version2 VersionNumber = 0x6b3343cf
|
||||
VersionUnknown Version = math.MaxUint32
|
||||
versionDraft29 Version = 0xff00001d // draft-29 used to be a widely deployed version
|
||||
Version1 Version = 0x1
|
||||
Version2 Version = 0x6b3343cf
|
||||
)
|
||||
|
||||
// SupportedVersions lists the versions that the server supports
|
||||
// must be in sorted descending order
|
||||
var SupportedVersions = []VersionNumber{Version1, Version2}
|
||||
var SupportedVersions = []Version{Version1, Version2}
|
||||
|
||||
// IsValidVersion says if the version is known to quic-go
|
||||
func IsValidVersion(v VersionNumber) bool {
|
||||
func IsValidVersion(v Version) bool {
|
||||
return v == Version1 || IsSupportedVersion(SupportedVersions, v)
|
||||
}
|
||||
|
||||
func (vn VersionNumber) String() string {
|
||||
func (vn Version) String() string {
|
||||
//nolint:exhaustive
|
||||
switch vn {
|
||||
case VersionUnknown:
|
||||
|
@ -52,16 +55,16 @@ func (vn VersionNumber) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
func (vn VersionNumber) isGQUIC() bool {
|
||||
func (vn Version) isGQUIC() bool {
|
||||
return vn > gquicVersion0 && vn <= maxGquicVersion
|
||||
}
|
||||
|
||||
func (vn VersionNumber) toGQUICVersion() int {
|
||||
func (vn Version) toGQUICVersion() int {
|
||||
return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
|
||||
}
|
||||
|
||||
// IsSupportedVersion returns true if the server supports this version
|
||||
func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
|
||||
func IsSupportedVersion(supported []Version, v Version) bool {
|
||||
for _, t := range supported {
|
||||
if t == v {
|
||||
return true
|
||||
|
@ -74,7 +77,7 @@ func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
|
|||
// ours is a slice of versions that we support, sorted by our preference (descending)
|
||||
// theirs is a slice of versions offered by the peer. The order does not matter.
|
||||
// The bool returned indicates if a matching version was found.
|
||||
func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) {
|
||||
func ChooseSupportedVersion(ours, theirs []Version) (Version, bool) {
|
||||
for _, ourVer := range ours {
|
||||
for _, theirVer := range theirs {
|
||||
if ourVer == theirVer {
|
||||
|
@ -85,19 +88,25 @@ func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool)
|
|||
return 0, false
|
||||
}
|
||||
|
||||
// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a)
|
||||
func generateReservedVersion() VersionNumber {
|
||||
b := make([]byte, 4)
|
||||
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
|
||||
return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa)
|
||||
var (
|
||||
versionNegotiationMx sync.Mutex
|
||||
versionNegotiationRand = rand.New(rand.NewSource(uint64(time.Now().UnixNano())))
|
||||
)
|
||||
|
||||
// generateReservedVersion generates a reserved version (v & 0x0f0f0f0f == 0x0a0a0a0a)
|
||||
func generateReservedVersion() Version {
|
||||
var b [4]byte
|
||||
_, _ = versionNegotiationRand.Read(b[:]) // ignore the error here. Failure to read random data doesn't break anything
|
||||
return Version((binary.BigEndian.Uint32(b[:]) | 0x0a0a0a0a) & 0xfafafafa)
|
||||
}
|
||||
|
||||
// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position
|
||||
func GetGreasedVersions(supported []VersionNumber) []VersionNumber {
|
||||
b := make([]byte, 1)
|
||||
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
|
||||
randPos := int(b[0]) % (len(supported) + 1)
|
||||
greased := make([]VersionNumber, len(supported)+1)
|
||||
// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position.
|
||||
// It doesn't modify the supported slice.
|
||||
func GetGreasedVersions(supported []Version) []Version {
|
||||
versionNegotiationMx.Lock()
|
||||
defer versionNegotiationMx.Unlock()
|
||||
randPos := rand.Intn(len(supported) + 1)
|
||||
greased := make([]Version, len(supported)+1)
|
||||
copy(greased, supported[:randPos])
|
||||
greased[randPos] = generateReservedVersion()
|
||||
copy(greased[randPos+1:], supported[randPos:])
|
||||
|
|
|
@ -101,8 +101,8 @@ func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.Err
|
|||
|
||||
// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version.
|
||||
type VersionNegotiationError struct {
|
||||
Ours []protocol.VersionNumber
|
||||
Theirs []protocol.VersionNumber
|
||||
Ours []protocol.Version
|
||||
Theirs []protocol.Version
|
||||
}
|
||||
|
||||
func (e *VersionNegotiationError) Error() string {
|
||||
|
|
|
@ -1,23 +1,11 @@
|
|||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type cipherSuiteTLS13 struct {
|
||||
ID uint16
|
||||
KeyLen int
|
||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
||||
Hash crypto.Hash
|
||||
}
|
||||
|
||||
//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID
|
||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
|
||||
|
||||
//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13
|
||||
var cipherSuitesTLS13 []unsafe.Pointer
|
||||
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
//go:build go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type clientSessionCache struct {
|
||||
mx sync.Mutex
|
||||
getData func(earlyData bool) []byte
|
||||
setData func(data []byte, earlyData bool) (allowEarlyData bool)
|
||||
wrapped tls.ClientSessionCache
|
||||
|
@ -14,7 +14,10 @@ type clientSessionCache struct {
|
|||
|
||||
var _ tls.ClientSessionCache = &clientSessionCache{}
|
||||
|
||||
func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
|
||||
func (c *clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
|
||||
c.mx.Lock()
|
||||
defer c.mx.Unlock()
|
||||
|
||||
if cs == nil {
|
||||
c.wrapped.Put(key, nil)
|
||||
return
|
||||
|
@ -34,7 +37,10 @@ func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
|
|||
c.wrapped.Put(key, newCS)
|
||||
}
|
||||
|
||||
func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
|
||||
func (c *clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
|
||||
c.mx.Lock()
|
||||
defer c.mx.Unlock()
|
||||
|
||||
cs, ok := c.wrapped.Get(key)
|
||||
if !ok || cs == nil {
|
||||
return cs, ok
|
||||
|
|
|
@ -27,18 +27,6 @@ func MinTime(a, b time.Time) time.Time {
|
|||
return a
|
||||
}
|
||||
|
||||
// MinNonZeroTime returns the earlist time that is not time.Time{}
|
||||
// If both a and b are time.Time{}, it returns time.Time{}
|
||||
func MinNonZeroTime(a, b time.Time) time.Time {
|
||||
if a.IsZero() {
|
||||
return b
|
||||
}
|
||||
if b.IsZero() {
|
||||
return a
|
||||
}
|
||||
return MinTime(a, b)
|
||||
}
|
||||
|
||||
// MaxTime returns the later time
|
||||
func MaxTime(a, b time.Time) time.Time {
|
||||
if a.After(b) {
|
||||
|
|
|
@ -22,7 +22,7 @@ type AckFrame struct {
|
|||
}
|
||||
|
||||
// parseAckFrame reads an ACK frame
|
||||
func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) error {
|
||||
func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.Version) error {
|
||||
ecn := typ == ackECNFrameType
|
||||
|
||||
la, err := quicvarint.Read(r)
|
||||
|
@ -110,7 +110,7 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen
|
|||
}
|
||||
|
||||
// Append appends an ACK frame.
|
||||
func (f *AckFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *AckFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
|
||||
if hasECN {
|
||||
b = append(b, ackECNFrameType)
|
||||
|
@ -143,7 +143,7 @@ func (f *AckFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
|||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *AckFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *AckFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||
largestAcked := f.AckRanges[0].Largest
|
||||
numRanges := f.numEncodableAckRanges()
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ type ConnectionCloseFrame struct {
|
|||
ReasonPhrase string
|
||||
}
|
||||
|
||||
func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*ConnectionCloseFrame, error) {
|
||||
func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, error) {
|
||||
f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType}
|
||||
ec, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
|
@ -53,7 +53,7 @@ func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNu
|
|||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *ConnectionCloseFrame) Length(protocol.Version) protocol.ByteCount {
|
||||
length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
|
||||
if !f.IsApplicationError {
|
||||
length += quicvarint.Len(f.FrameType) // for the frame type
|
||||
|
@ -61,7 +61,7 @@ func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount
|
|||
return length
|
||||
}
|
||||
|
||||
func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
if f.IsApplicationError {
|
||||
b = append(b, applicationCloseFrameType)
|
||||
} else {
|
||||
|
|
|
@ -14,7 +14,7 @@ type CryptoFrame struct {
|
|||
Data []byte
|
||||
}
|
||||
|
||||
func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) {
|
||||
func parseCryptoFrame(r *bytes.Reader, _ protocol.Version) (*CryptoFrame, error) {
|
||||
frame := &CryptoFrame{}
|
||||
offset, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
|
@ -38,7 +38,7 @@ func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame,
|
|||
return frame, nil
|
||||
}
|
||||
|
||||
func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
b = append(b, cryptoFrameType)
|
||||
b = quicvarint.Append(b, uint64(f.Offset))
|
||||
b = quicvarint.Append(b, uint64(len(f.Data)))
|
||||
|
@ -47,7 +47,7 @@ func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error)
|
|||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *CryptoFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||
return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data))
|
||||
}
|
||||
|
||||
|
@ -71,7 +71,7 @@ func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount
|
|||
// The frame might not be split if:
|
||||
// * the size is large enough to fit the whole frame
|
||||
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
|
||||
func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*CryptoFrame, bool /* was splitting required */) {
|
||||
func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*CryptoFrame, bool /* was splitting required */) {
|
||||
if f.Length(version) <= maxSize {
|
||||
return nil, false
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ type DataBlockedFrame struct {
|
|||
MaximumData protocol.ByteCount
|
||||
}
|
||||
|
||||
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) {
|
||||
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*DataBlockedFrame, error) {
|
||||
offset, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -20,12 +20,12 @@ func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBloc
|
|||
return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil
|
||||
}
|
||||
|
||||
func (f *DataBlockedFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
|
||||
b = append(b, dataBlockedFrameType)
|
||||
return quicvarint.Append(b, uint64(f.MaximumData)), nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *DataBlockedFrame) Length(version protocol.Version) protocol.ByteCount {
|
||||
return 1 + quicvarint.Len(uint64(f.MaximumData))
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ type DatagramFrame struct {
|
|||
Data []byte
|
||||
}
|
||||
|
||||
func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*DatagramFrame, error) {
|
||||
func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*DatagramFrame, error) {
|
||||
f := &DatagramFrame{}
|
||||
f.DataLenPresent = typ&0x1 > 0
|
||||
|
||||
|
@ -45,7 +45,7 @@ func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (
|
|||
return f, nil
|
||||
}
|
||||
|
||||
func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *DatagramFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
typ := uint8(0x30)
|
||||
if f.DataLenPresent {
|
||||
typ ^= 0b1
|
||||
|
@ -59,7 +59,7 @@ func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, erro
|
|||
}
|
||||
|
||||
// MaxDataLen returns the maximum data length
|
||||
func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount {
|
||||
headerLen := protocol.ByteCount(1)
|
||||
if f.DataLenPresent {
|
||||
// pretend that the data size will be 1 bytes
|
||||
|
@ -77,7 +77,7 @@ func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.
|
|||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *DatagramFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *DatagramFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||
length := 1 + protocol.ByteCount(len(f.Data))
|
||||
if f.DataLenPresent {
|
||||
length += quicvarint.Len(uint64(len(f.Data)))
|
||||
|
|
|
@ -32,7 +32,7 @@ type ExtendedHeader struct {
|
|||
parsedLen protocol.ByteCount
|
||||
}
|
||||
|
||||
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) {
|
||||
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.Version) (bool /* reserved bits valid */, error) {
|
||||
startLen := b.Len()
|
||||
// read the (now unencrypted) first byte
|
||||
var err error
|
||||
|
@ -51,7 +51,7 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool
|
|||
return reservedBitsValid, err
|
||||
}
|
||||
|
||||
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
|
||||
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.Version) (bool /* reserved bits valid */, error) {
|
||||
if err := h.readPacketNumber(b); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -95,7 +95,7 @@ func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
|
|||
}
|
||||
|
||||
// Append appends the Header.
|
||||
func (h *ExtendedHeader) Append(b []byte, v protocol.VersionNumber) ([]byte, error) {
|
||||
func (h *ExtendedHeader) Append(b []byte, v protocol.Version) ([]byte, error) {
|
||||
if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
|
||||
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
|
||||
}
|
||||
|
@ -162,7 +162,7 @@ func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
|
|||
}
|
||||
|
||||
// GetLength determines the length of the Header.
|
||||
func (h *ExtendedHeader) GetLength(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount {
|
||||
length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
|
||||
if h.Type == protocol.PacketTypeInitial {
|
||||
length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
|
||||
|
|
|
@ -36,7 +36,8 @@ const (
|
|||
handshakeDoneFrameType = 0x1e
|
||||
)
|
||||
|
||||
type frameParser struct {
|
||||
// The FrameParser parses QUIC frames, one by one.
|
||||
type FrameParser struct {
|
||||
r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them
|
||||
|
||||
ackDelayExponent uint8
|
||||
|
@ -47,11 +48,9 @@ type frameParser struct {
|
|||
ackFrame *AckFrame
|
||||
}
|
||||
|
||||
var _ FrameParser = &frameParser{}
|
||||
|
||||
// NewFrameParser creates a new frame parser.
|
||||
func NewFrameParser(supportsDatagrams bool) *frameParser {
|
||||
return &frameParser{
|
||||
func NewFrameParser(supportsDatagrams bool) *FrameParser {
|
||||
return &FrameParser{
|
||||
r: *bytes.NewReader(nil),
|
||||
supportsDatagrams: supportsDatagrams,
|
||||
ackFrame: &AckFrame{},
|
||||
|
@ -60,7 +59,7 @@ func NewFrameParser(supportsDatagrams bool) *frameParser {
|
|||
|
||||
// ParseNext parses the next frame.
|
||||
// It skips PADDING frames.
|
||||
func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (int, Frame, error) {
|
||||
func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) {
|
||||
startLen := len(data)
|
||||
p.r.Reset(data)
|
||||
frame, err := p.parseNext(&p.r, encLevel, v)
|
||||
|
@ -69,7 +68,7 @@ func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel,
|
|||
return n, frame, err
|
||||
}
|
||||
|
||||
func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) {
|
||||
func (p *FrameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
|
||||
for r.Len() != 0 {
|
||||
typ, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
|
@ -95,7 +94,7 @@ func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLev
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) {
|
||||
func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
|
||||
var frame Frame
|
||||
var err error
|
||||
if typ&0xf8 == 0x8 {
|
||||
|
@ -163,7 +162,7 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
|
|||
return frame, nil
|
||||
}
|
||||
|
||||
func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
|
||||
func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial, protocol.EncryptionHandshake:
|
||||
switch f.(type) {
|
||||
|
@ -186,6 +185,8 @@ func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionL
|
|||
}
|
||||
}
|
||||
|
||||
func (p *frameParser) SetAckDelayExponent(exp uint8) {
|
||||
// SetAckDelayExponent sets the acknowledgment delay exponent (sent in the transport parameters).
|
||||
// This value is used to scale the ACK Delay field in the ACK frame.
|
||||
func (p *FrameParser) SetAckDelayExponent(exp uint8) {
|
||||
p.ackDelayExponent = exp
|
||||
}
|
||||
|
|
|
@ -7,11 +7,11 @@ import (
|
|||
// A HandshakeDoneFrame is a HANDSHAKE_DONE frame
|
||||
type HandshakeDoneFrame struct{}
|
||||
|
||||
func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
return append(b, handshakeDoneFrameType), nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *HandshakeDoneFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||
return 1
|
||||
}
|
||||
|
|
|
@ -85,11 +85,11 @@ func IsLongHeaderPacket(firstByte byte) bool {
|
|||
|
||||
// ParseVersion parses the QUIC version.
|
||||
// It should only be called for Long Header packets (Short Header packets don't contain a version number).
|
||||
func ParseVersion(data []byte) (protocol.VersionNumber, error) {
|
||||
func ParseVersion(data []byte) (protocol.Version, error) {
|
||||
if len(data) < 5 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return protocol.VersionNumber(binary.BigEndian.Uint32(data[1:5])), nil
|
||||
return protocol.Version(binary.BigEndian.Uint32(data[1:5])), nil
|
||||
}
|
||||
|
||||
// IsVersionNegotiationPacket says if this is a version negotiation packet
|
||||
|
@ -109,7 +109,7 @@ func Is0RTTPacket(b []byte) bool {
|
|||
if !IsLongHeaderPacket(b[0]) {
|
||||
return false
|
||||
}
|
||||
version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5]))
|
||||
version := protocol.Version(binary.BigEndian.Uint32(b[1:5]))
|
||||
//nolint:exhaustive // We only need to test QUIC versions that we support.
|
||||
switch version {
|
||||
case protocol.Version1:
|
||||
|
@ -128,7 +128,7 @@ type Header struct {
|
|||
typeByte byte
|
||||
Type protocol.PacketType
|
||||
|
||||
Version protocol.VersionNumber
|
||||
Version protocol.Version
|
||||
SrcConnectionID protocol.ConnectionID
|
||||
DestConnectionID protocol.ConnectionID
|
||||
|
||||
|
@ -184,7 +184,7 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.Version = protocol.VersionNumber(v)
|
||||
h.Version = protocol.Version(v)
|
||||
if h.Version != 0 && h.typeByte&0x40 == 0 {
|
||||
return errors.New("not a QUIC packet")
|
||||
}
|
||||
|
@ -278,7 +278,7 @@ func (h *Header) ParsedLen() protocol.ByteCount {
|
|||
|
||||
// ParseExtended parses the version dependent part of the header.
|
||||
// The Reader has to be set such that it points to the first byte of the header.
|
||||
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) {
|
||||
extHdr := h.toExtendedHeader()
|
||||
reservedBitsValid, err := extHdr.parse(b, ver)
|
||||
if err != nil {
|
||||
|
|
|
@ -6,12 +6,6 @@ import (
|
|||
|
||||
// A Frame in QUIC
|
||||
type Frame interface {
|
||||
Append(b []byte, version protocol.VersionNumber) ([]byte, error)
|
||||
Length(version protocol.VersionNumber) protocol.ByteCount
|
||||
}
|
||||
|
||||
// A FrameParser parses QUIC frames, one by one.
|
||||
type FrameParser interface {
|
||||
ParseNext([]byte, protocol.EncryptionLevel, protocol.VersionNumber) (int, Frame, error)
|
||||
SetAckDelayExponent(uint8)
|
||||
Append(b []byte, version protocol.Version) ([]byte, error)
|
||||
Length(version protocol.Version) protocol.ByteCount
|
||||
}
|
||||
|
|
|
@ -63,7 +63,9 @@ func LogFrame(logger utils.Logger, frame Frame, sent bool) {
|
|||
logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit)
|
||||
}
|
||||
case *NewConnectionIDFrame:
|
||||
logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.ConnectionID, f.StatelessResetToken)
|
||||
logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, RetirePriorTo: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.RetirePriorTo, f.ConnectionID, f.StatelessResetToken)
|
||||
case *RetireConnectionIDFrame:
|
||||
logger.Debugf("\t%s &wire.RetireConnectionIDFrame{SequenceNumber: %d}", dir, f.SequenceNumber)
|
||||
case *NewTokenFrame:
|
||||
logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token)
|
||||
default:
|
||||
|
|
|
@ -13,7 +13,7 @@ type MaxDataFrame struct {
|
|||
}
|
||||
|
||||
// parseMaxDataFrame parses a MAX_DATA frame
|
||||
func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame, error) {
|
||||
func parseMaxDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxDataFrame, error) {
|
||||
frame := &MaxDataFrame{}
|
||||
byteOffset, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
|
@ -23,13 +23,13 @@ func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame
|
|||
return frame, nil
|
||||
}
|
||||
|
||||
func (f *MaxDataFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
b = append(b, maxDataFrameType)
|
||||
b = quicvarint.Append(b, uint64(f.MaximumData))
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *MaxDataFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *MaxDataFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||
return 1 + quicvarint.Len(uint64(f.MaximumData))
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ type MaxStreamDataFrame struct {
|
|||
MaximumStreamData protocol.ByteCount
|
||||
}
|
||||
|
||||
func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamDataFrame, error) {
|
||||
func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxStreamDataFrame, error) {
|
||||
sid, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -29,7 +29,7 @@ func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStr
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *MaxStreamDataFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
|
||||
b = append(b, maxStreamDataFrameType)
|
||||
b = quicvarint.Append(b, uint64(f.StreamID))
|
||||
b = quicvarint.Append(b, uint64(f.MaximumStreamData))
|
||||
|
@ -37,6 +37,6 @@ func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([
|
|||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *MaxStreamDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *MaxStreamDataFrame) Length(version protocol.Version) protocol.ByteCount {
|
||||
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ type MaxStreamsFrame struct {
|
|||
MaxStreamNum protocol.StreamNum
|
||||
}
|
||||
|
||||
func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*MaxStreamsFrame, error) {
|
||||
func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*MaxStreamsFrame, error) {
|
||||
f := &MaxStreamsFrame{}
|
||||
switch typ {
|
||||
case bidiMaxStreamsFrameType:
|
||||
|
@ -33,7 +33,7 @@ func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber)
|
|||
return f, nil
|
||||
}
|
||||
|
||||
func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
switch f.Type {
|
||||
case protocol.StreamTypeBidi:
|
||||
b = append(b, bidiMaxStreamsFrameType)
|
||||
|
@ -45,6 +45,6 @@ func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, er
|
|||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *MaxStreamsFrame) Length(protocol.Version) protocol.ByteCount {
|
||||
return 1 + quicvarint.Len(uint64(f.MaxStreamNum))
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ type NewConnectionIDFrame struct {
|
|||
StatelessResetToken protocol.StatelessResetToken
|
||||
}
|
||||
|
||||
func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewConnectionIDFrame, error) {
|
||||
func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*NewConnectionIDFrame, error) {
|
||||
seq, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -57,7 +57,7 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC
|
|||
return frame, nil
|
||||
}
|
||||
|
||||
func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
b = append(b, newConnectionIDFrameType)
|
||||
b = quicvarint.Append(b, f.SequenceNumber)
|
||||
b = quicvarint.Append(b, f.RetirePriorTo)
|
||||
|
@ -72,6 +72,6 @@ func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byt
|
|||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *NewConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *NewConnectionIDFrame) Length(protocol.Version) protocol.ByteCount {
|
||||
return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ type NewTokenFrame struct {
|
|||
Token []byte
|
||||
}
|
||||
|
||||
func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFrame, error) {
|
||||
func parseNewTokenFrame(r *bytes.Reader, _ protocol.Version) (*NewTokenFrame, error) {
|
||||
tokenLen, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -32,7 +32,7 @@ func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFra
|
|||
return &NewTokenFrame{Token: token}, nil
|
||||
}
|
||||
|
||||
func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
b = append(b, newTokenFrameType)
|
||||
b = quicvarint.Append(b, uint64(len(f.Token)))
|
||||
b = append(b, f.Token...)
|
||||
|
@ -40,6 +40,6 @@ func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, erro
|
|||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *NewTokenFrame) Length(protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *NewTokenFrame) Length(protocol.Version) protocol.ByteCount {
|
||||
return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token))
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ type PathChallengeFrame struct {
|
|||
Data [8]byte
|
||||
}
|
||||
|
||||
func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathChallengeFrame, error) {
|
||||
func parsePathChallengeFrame(r *bytes.Reader, _ protocol.Version) (*PathChallengeFrame, error) {
|
||||
frame := &PathChallengeFrame{}
|
||||
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
|
@ -23,13 +23,13 @@ func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathCh
|
|||
return frame, nil
|
||||
}
|
||||
|
||||
func (f *PathChallengeFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
b = append(b, pathChallengeFrameType)
|
||||
b = append(b, f.Data[:]...)
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *PathChallengeFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *PathChallengeFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||
return 1 + 8
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ type PathResponseFrame struct {
|
|||
Data [8]byte
|
||||
}
|
||||
|
||||
func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathResponseFrame, error) {
|
||||
func parsePathResponseFrame(r *bytes.Reader, _ protocol.Version) (*PathResponseFrame, error) {
|
||||
frame := &PathResponseFrame{}
|
||||
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
|
@ -23,13 +23,13 @@ func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathRes
|
|||
return frame, nil
|
||||
}
|
||||
|
||||
func (f *PathResponseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
b = append(b, pathResponseFrameType)
|
||||
b = append(b, f.Data[:]...)
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *PathResponseFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *PathResponseFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||
return 1 + 8
|
||||
}
|
||||
|
|
|
@ -7,11 +7,11 @@ import (
|
|||
// A PingFrame is a PING frame
|
||||
type PingFrame struct{}
|
||||
|
||||
func (f *PingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *PingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
return append(b, pingFrameType), nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *PingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *PingFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||
return 1
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ type ResetStreamFrame struct {
|
|||
FinalSize protocol.ByteCount
|
||||
}
|
||||
|
||||
func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStreamFrame, error) {
|
||||
func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFrame, error) {
|
||||
var streamID protocol.StreamID
|
||||
var byteOffset protocol.ByteCount
|
||||
sid, err := quicvarint.Read(r)
|
||||
|
@ -40,7 +40,7 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStr
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
b = append(b, resetStreamFrameType)
|
||||
b = quicvarint.Append(b, uint64(f.StreamID))
|
||||
b = quicvarint.Append(b, uint64(f.ErrorCode))
|
||||
|
@ -49,6 +49,6 @@ func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, e
|
|||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *ResetStreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *ResetStreamFrame) Length(version protocol.Version) protocol.ByteCount {
|
||||
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize))
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ type RetireConnectionIDFrame struct {
|
|||
SequenceNumber uint64
|
||||
}
|
||||
|
||||
func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*RetireConnectionIDFrame, error) {
|
||||
func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*RetireConnectionIDFrame, error) {
|
||||
seq, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -20,13 +20,13 @@ func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*R
|
|||
return &RetireConnectionIDFrame{SequenceNumber: seq}, nil
|
||||
}
|
||||
|
||||
func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
b = append(b, retireConnectionIDFrameType)
|
||||
b = quicvarint.Append(b, f.SequenceNumber)
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *RetireConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *RetireConnectionIDFrame) Length(protocol.Version) protocol.ByteCount {
|
||||
return 1 + quicvarint.Len(f.SequenceNumber)
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ type StopSendingFrame struct {
|
|||
}
|
||||
|
||||
// parseStopSendingFrame parses a STOP_SENDING frame
|
||||
func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) {
|
||||
func parseStopSendingFrame(r *bytes.Reader, _ protocol.Version) (*StopSendingFrame, error) {
|
||||
streamID, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -32,11 +32,11 @@ func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSend
|
|||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *StopSendingFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode))
|
||||
}
|
||||
|
||||
func (f *StopSendingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *StopSendingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
b = append(b, stopSendingFrameType)
|
||||
b = quicvarint.Append(b, uint64(f.StreamID))
|
||||
b = quicvarint.Append(b, uint64(f.ErrorCode))
|
||||
|
|
|
@ -13,7 +13,7 @@ type StreamDataBlockedFrame struct {
|
|||
MaximumStreamData protocol.ByteCount
|
||||
}
|
||||
|
||||
func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamDataBlockedFrame, error) {
|
||||
func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*StreamDataBlockedFrame, error) {
|
||||
sid, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -29,7 +29,7 @@ func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*St
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
b = append(b, 0x15)
|
||||
b = quicvarint.Append(b, uint64(f.StreamID))
|
||||
b = quicvarint.Append(b, uint64(f.MaximumStreamData))
|
||||
|
@ -37,6 +37,6 @@ func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]b
|
|||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *StreamDataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *StreamDataBlockedFrame) Length(version protocol.Version) protocol.ByteCount {
|
||||
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ type StreamFrame struct {
|
|||
fromPool bool
|
||||
}
|
||||
|
||||
func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamFrame, error) {
|
||||
func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamFrame, error) {
|
||||
hasOffset := typ&0b100 > 0
|
||||
fin := typ&0b1 > 0
|
||||
hasDataLen := typ&0b10 > 0
|
||||
|
@ -79,7 +79,7 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*S
|
|||
}
|
||||
|
||||
// Write writes a STREAM frame
|
||||
func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
if len(f.Data) == 0 && !f.Fin {
|
||||
return nil, errors.New("StreamFrame: attempting to write empty frame without FIN")
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error)
|
|||
}
|
||||
|
||||
// Length returns the total length of the STREAM frame
|
||||
func (f *StreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *StreamFrame) Length(version protocol.Version) protocol.ByteCount {
|
||||
length := 1 + quicvarint.Len(uint64(f.StreamID))
|
||||
if f.Offset != 0 {
|
||||
length += quicvarint.Len(uint64(f.Offset))
|
||||
|
@ -126,7 +126,7 @@ func (f *StreamFrame) DataLen() protocol.ByteCount {
|
|||
|
||||
// MaxDataLen returns the maximum data length
|
||||
// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data).
|
||||
func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount {
|
||||
headerLen := 1 + quicvarint.Len(uint64(f.StreamID))
|
||||
if f.Offset != 0 {
|
||||
headerLen += quicvarint.Len(uint64(f.Offset))
|
||||
|
@ -151,7 +151,7 @@ func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Ve
|
|||
// The frame might not be split if:
|
||||
// * the size is large enough to fit the whole frame
|
||||
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
|
||||
func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*StreamFrame, bool /* was splitting required */) {
|
||||
func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*StreamFrame, bool /* was splitting required */) {
|
||||
if maxSize >= f.Length(version) {
|
||||
return nil, false
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ type StreamsBlockedFrame struct {
|
|||
StreamLimit protocol.StreamNum
|
||||
}
|
||||
|
||||
func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) {
|
||||
func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, error) {
|
||||
f := &StreamsBlockedFrame{}
|
||||
switch typ {
|
||||
case bidiStreamBlockedFrameType:
|
||||
|
@ -33,7 +33,7 @@ func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNum
|
|||
return f, nil
|
||||
}
|
||||
|
||||
func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
switch f.Type {
|
||||
case protocol.StreamTypeBidi:
|
||||
b = append(b, bidiStreamBlockedFrameType)
|
||||
|
@ -45,6 +45,6 @@ func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte
|
|||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *StreamsBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
func (f *StreamsBlockedFrame) Length(_ protocol.Version) protocol.ByteCount {
|
||||
return 1 + quicvarint.Len(uint64(f.StreamLimit))
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
|
@ -51,10 +51,7 @@ const (
|
|||
|
||||
// PreferredAddress is the value encoding in the preferred_address transport parameter
|
||||
type PreferredAddress struct {
|
||||
IPv4 net.IP
|
||||
IPv4Port uint16
|
||||
IPv6 net.IP
|
||||
IPv6Port uint16
|
||||
IPv4, IPv6 netip.AddrPort
|
||||
ConnectionID protocol.ConnectionID
|
||||
StatelessResetToken protocol.StatelessResetToken
|
||||
}
|
||||
|
@ -218,26 +215,24 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
|
|||
func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error {
|
||||
remainingLen := r.Len()
|
||||
pa := &PreferredAddress{}
|
||||
ipv4 := make([]byte, 4)
|
||||
if _, err := io.ReadFull(r, ipv4); err != nil {
|
||||
var ipv4 [4]byte
|
||||
if _, err := io.ReadFull(r, ipv4[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
pa.IPv4 = net.IP(ipv4)
|
||||
port, err := utils.BigEndian.ReadUint16(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pa.IPv4Port = port
|
||||
ipv6 := make([]byte, 16)
|
||||
if _, err := io.ReadFull(r, ipv6); err != nil {
|
||||
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port)
|
||||
var ipv6 [16]byte
|
||||
if _, err := io.ReadFull(r, ipv6[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
pa.IPv6 = net.IP(ipv6)
|
||||
port, err = utils.BigEndian.ReadUint16(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pa.IPv6Port = port
|
||||
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port)
|
||||
connIDLen, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -384,13 +379,12 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
|
|||
if p.PreferredAddress != nil {
|
||||
b = quicvarint.Append(b, uint64(preferredAddressParameterID))
|
||||
b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16)
|
||||
ipv4 := p.PreferredAddress.IPv4
|
||||
b = append(b, ipv4[len(ipv4)-4:]...)
|
||||
b = append(b, []byte{0, 0}...)
|
||||
binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv4Port)
|
||||
b = append(b, p.PreferredAddress.IPv6...)
|
||||
b = append(b, []byte{0, 0}...)
|
||||
binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv6Port)
|
||||
ip4 := p.PreferredAddress.IPv4.Addr().As4()
|
||||
b = append(b, ip4[:]...)
|
||||
b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv4.Port())
|
||||
ip6 := p.PreferredAddress.IPv6.Addr().As16()
|
||||
b = append(b, ip6[:]...)
|
||||
b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv6.Port())
|
||||
b = append(b, uint8(p.PreferredAddress.ConnectionID.Len()))
|
||||
b = append(b, p.PreferredAddress.ConnectionID.Bytes()...)
|
||||
b = append(b, p.PreferredAddress.StatelessResetToken[:]...)
|
||||
|
|
|
@ -1,17 +1,15 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// ParseVersionNegotiationPacket parses a Version Negotiation packet.
|
||||
func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber, _ error) {
|
||||
func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.Version, _ error) {
|
||||
n, dest, src, err := ParseArbitraryLenConnectionIDs(b)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
|
@ -25,32 +23,31 @@ func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenCon
|
|||
//nolint:stylecheck
|
||||
return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length")
|
||||
}
|
||||
versions := make([]protocol.VersionNumber, len(b)/4)
|
||||
versions := make([]protocol.Version, len(b)/4)
|
||||
for i := 0; len(b) > 0; i++ {
|
||||
versions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(b[:4]))
|
||||
versions[i] = protocol.Version(binary.BigEndian.Uint32(b[:4]))
|
||||
b = b[4:]
|
||||
}
|
||||
return dest, src, versions, nil
|
||||
}
|
||||
|
||||
// ComposeVersionNegotiation composes a Version Negotiation
|
||||
func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.VersionNumber) []byte {
|
||||
func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.Version) []byte {
|
||||
greasedVersions := protocol.GetGreasedVersions(versions)
|
||||
expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4
|
||||
buf := bytes.NewBuffer(make([]byte, 0, expectedLen))
|
||||
r := make([]byte, 1)
|
||||
_, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here.
|
||||
buf := make([]byte, 1+4 /* type byte and version field */, expectedLen)
|
||||
_, _ = rand.Read(buf[:1]) // ignore the error here. It is not critical to have perfect random here.
|
||||
// Setting the "QUIC bit" (0x40) is not required by the RFC,
|
||||
// but it allows clients to demultiplex QUIC with a long list of other protocols.
|
||||
// See RFC 9443 and https://mailarchive.ietf.org/arch/msg/quic/oR4kxGKY6mjtPC1CZegY1ED4beg/ for details.
|
||||
buf.WriteByte(r[0] | 0xc0)
|
||||
utils.BigEndian.WriteUint32(buf, 0) // version 0
|
||||
buf.WriteByte(uint8(destConnID.Len()))
|
||||
buf.Write(destConnID.Bytes())
|
||||
buf.WriteByte(uint8(srcConnID.Len()))
|
||||
buf.Write(srcConnID.Bytes())
|
||||
buf[0] |= 0xc0
|
||||
// The next 4 bytes are left at 0 (version number).
|
||||
buf = append(buf, uint8(destConnID.Len()))
|
||||
buf = append(buf, destConnID.Bytes()...)
|
||||
buf = append(buf, uint8(srcConnID.Len()))
|
||||
buf = append(buf, srcConnID.Bytes()...)
|
||||
for _, v := range greasedVersions {
|
||||
utils.BigEndian.WriteUint32(buf, uint32(v))
|
||||
buf = binary.BigEndian.AppendUint32(buf, uint32(v))
|
||||
}
|
||||
return buf.Bytes()
|
||||
return buf
|
||||
}
|
||||
|
|
|
@ -27,9 +27,9 @@ type ConnectionTracer struct {
|
|||
UpdatedCongestionState func(CongestionState)
|
||||
UpdatedPTOCount func(value uint32)
|
||||
UpdatedKeyFromTLS func(EncryptionLevel, Perspective)
|
||||
UpdatedKey func(generation KeyPhase, remote bool)
|
||||
UpdatedKey func(keyPhase KeyPhase, remote bool)
|
||||
DroppedEncryptionLevel func(EncryptionLevel)
|
||||
DroppedKey func(generation KeyPhase)
|
||||
DroppedKey func(keyPhase KeyPhase)
|
||||
SetLossTimer func(TimerType, EncryptionLevel, time.Time)
|
||||
LossTimerExpired func(TimerType, EncryptionLevel)
|
||||
LossTimerCanceled func()
|
||||
|
|
|
@ -37,7 +37,7 @@ type (
|
|||
// The StreamType is the type of the stream (unidirectional or bidirectional).
|
||||
StreamType = protocol.StreamType
|
||||
// The VersionNumber is the QUIC version.
|
||||
VersionNumber = protocol.VersionNumber
|
||||
VersionNumber = protocol.Version
|
||||
|
||||
// The Header is the QUIC packet header, before removing header protection.
|
||||
Header = wire.Header
|
||||
|
|
|
@ -7,6 +7,8 @@ type Tracer struct {
|
|||
SentPacket func(net.Addr, *Header, ByteCount, []Frame)
|
||||
SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
|
||||
DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason)
|
||||
Debug func(name, msg string)
|
||||
Close func()
|
||||
}
|
||||
|
||||
// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers.
|
||||
|
@ -39,5 +41,19 @@ func NewMultiplexedTracer(tracers ...*Tracer) *Tracer {
|
|||
}
|
||||
}
|
||||
},
|
||||
Debug: func(name, msg string) {
|
||||
for _, t := range tracers {
|
||||
if t.Debug != nil {
|
||||
t.Debug(name, msg)
|
||||
}
|
||||
}
|
||||
},
|
||||
Close: func() {
|
||||
for _, t := range tracers {
|
||||
if t.Close != nil {
|
||||
t.Close()
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,12 +3,12 @@
|
|||
# Install Go manually, since oss-fuzz ships with an outdated Go version.
|
||||
# See https://github.com/google/oss-fuzz/pull/10643.
|
||||
export CXX="${CXX} -lresolv" # required by Go 1.20
|
||||
wget https://go.dev/dl/go1.21.5.linux-amd64.tar.gz \
|
||||
wget https://go.dev/dl/go1.22.0.linux-amd64.tar.gz \
|
||||
&& mkdir temp-go \
|
||||
&& rm -rf /root/.go/* \
|
||||
&& tar -C temp-go/ -xzf go1.21.5.linux-amd64.tar.gz \
|
||||
&& tar -C temp-go/ -xzf go1.22.0.linux-amd64.tar.gz \
|
||||
&& mv temp-go/go/* /root/.go/ \
|
||||
&& rm -rf temp-go go1.21.5.linux-amd64.tar.gz
|
||||
&& rm -rf temp-go go1.22.0.linux-amd64.tar.gz
|
||||
|
||||
(
|
||||
# fuzz qpack
|
||||
|
|
|
@ -129,7 +129,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
|
|||
return true
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
|
@ -137,12 +137,8 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
|
|||
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
|
||||
return false
|
||||
}
|
||||
conn, ok := fn()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
h.handlers[clientDestConnID] = conn
|
||||
h.handlers[newConnID] = conn
|
||||
h.handlers[clientDestConnID] = handler
|
||||
h.handlers[newConnID] = handler
|
||||
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
|
||||
return true
|
||||
}
|
||||
|
@ -168,18 +164,17 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
|
|||
// Depending on which side closed the connection, we need to:
|
||||
// * remote close: absorb delayed packets
|
||||
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
|
||||
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers protocol.Perspective, connClosePacket []byte) {
|
||||
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte) {
|
||||
var handler packetHandler
|
||||
if connClosePacket != nil {
|
||||
handler = newClosedLocalConn(
|
||||
func(addr net.Addr, info packetInfo) {
|
||||
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
|
||||
},
|
||||
pers,
|
||||
h.logger,
|
||||
)
|
||||
} else {
|
||||
handler = newClosedRemoteConn(pers)
|
||||
handler = newClosedRemoteConn()
|
||||
}
|
||||
|
||||
h.mutex.Lock()
|
||||
|
@ -191,7 +186,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
|
|||
|
||||
time.AfterFunc(h.deleteRetiredConnsAfter, func() {
|
||||
h.mutex.Lock()
|
||||
handler.shutdown()
|
||||
for _, id := range ids {
|
||||
delete(h.handlers, id)
|
||||
}
|
||||
|
|
|
@ -18,13 +18,13 @@ import (
|
|||
var errNothingToPack = errors.New("nothing to pack")
|
||||
|
||||
type packer interface {
|
||||
PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error)
|
||||
PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
|
||||
AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error)
|
||||
MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
|
||||
PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
|
||||
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
|
||||
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
|
||||
PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error)
|
||||
PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
|
||||
AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error)
|
||||
MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
|
||||
PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
|
||||
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
|
||||
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
|
||||
|
||||
SetToken([]byte)
|
||||
}
|
||||
|
@ -106,8 +106,8 @@ type sealingManager interface {
|
|||
|
||||
type frameSource interface {
|
||||
HasData() bool
|
||||
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
|
||||
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
|
||||
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)
|
||||
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)
|
||||
}
|
||||
|
||||
type ackFrameSource interface {
|
||||
|
@ -170,7 +170,7 @@ func newPacketPacker(
|
|||
}
|
||||
|
||||
// PackConnectionClose packs a packet that closes the connection with a transport error.
|
||||
func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
|
||||
func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
|
||||
var reason string
|
||||
// don't send details of crypto errors
|
||||
if !e.ErrorCode.IsCryptoError() {
|
||||
|
@ -180,7 +180,7 @@ func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize
|
|||
}
|
||||
|
||||
// PackApplicationClose packs a packet that closes the connection with an application error.
|
||||
func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
|
||||
func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
|
||||
return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v)
|
||||
}
|
||||
|
||||
|
@ -190,7 +190,7 @@ func (p *packetPacker) packConnectionClose(
|
|||
frameType uint64,
|
||||
reason string,
|
||||
maxPacketSize protocol.ByteCount,
|
||||
v protocol.VersionNumber,
|
||||
v protocol.Version,
|
||||
) (*coalescedPacket, error) {
|
||||
var sealers [4]sealer
|
||||
var hdrs [3]*wire.ExtendedHeader
|
||||
|
@ -293,7 +293,7 @@ func (p *packetPacker) packConnectionClose(
|
|||
// longHeaderPacketLength calculates the length of a serialized long header packet.
|
||||
// It takes into account that packets that have a tiny payload need to be padded,
|
||||
// such that len(payload) + packet number len >= 4 + AEAD overhead
|
||||
func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.VersionNumber) protocol.ByteCount {
|
||||
func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.Version) protocol.ByteCount {
|
||||
var paddingLen protocol.ByteCount
|
||||
pnLen := protocol.ByteCount(hdr.PacketNumberLen)
|
||||
if pl.length < 4-pnLen {
|
||||
|
@ -328,7 +328,7 @@ func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize,
|
|||
// PackCoalescedPacket packs a new packet.
|
||||
// It packs an Initial / Handshake if there is data to send in these packet number spaces.
|
||||
// It should only be called before the handshake is confirmed.
|
||||
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
|
||||
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
|
||||
var (
|
||||
initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader
|
||||
initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
|
||||
|
@ -442,7 +442,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
|
|||
|
||||
// PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space.
|
||||
// It should be called after the handshake is confirmed.
|
||||
func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) {
|
||||
func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||
buf := getPacketBuffer()
|
||||
packet, err := p.appendPacket(buf, true, maxPacketSize, v)
|
||||
return packet, buf, err
|
||||
|
@ -450,11 +450,11 @@ func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v pro
|
|||
|
||||
// AppendPacket packs a packet in the application data packet number space.
|
||||
// It should be called after the handshake is confirmed.
|
||||
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) {
|
||||
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) {
|
||||
return p.appendPacket(buf, false, maxPacketSize, v)
|
||||
}
|
||||
|
||||
func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) {
|
||||
func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) {
|
||||
sealer, err := p.cryptoSetup.Get1RTTSealer()
|
||||
if err != nil {
|
||||
return shortHeaderPacket{}, err
|
||||
|
@ -471,7 +471,7 @@ func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSi
|
|||
return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v)
|
||||
}
|
||||
|
||||
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) {
|
||||
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.Version) (*wire.ExtendedHeader, payload) {
|
||||
if onlyAck {
|
||||
if ack := p.acks.GetAckFrame(encLevel, true); ack != nil {
|
||||
return p.getLongHeader(encLevel, v), payload{
|
||||
|
@ -543,7 +543,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
|
|||
return hdr, pl
|
||||
}
|
||||
|
||||
func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) {
|
||||
func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.Version) (*wire.ExtendedHeader, payload) {
|
||||
if p.perspective != protocol.PerspectiveClient {
|
||||
return nil, payload{}
|
||||
}
|
||||
|
@ -553,12 +553,12 @@ func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize
|
|||
return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, v)
|
||||
}
|
||||
|
||||
func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
|
||||
func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
|
||||
maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead())
|
||||
return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, v)
|
||||
}
|
||||
|
||||
func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
|
||||
func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
|
||||
pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v)
|
||||
|
||||
// check if we have anything to send
|
||||
|
@ -581,7 +581,7 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount,
|
|||
return pl
|
||||
}
|
||||
|
||||
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
|
||||
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
|
||||
if onlyAck {
|
||||
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil {
|
||||
return payload{ack: ack, length: ack.Length(v)}
|
||||
|
@ -589,12 +589,11 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
|
|||
return payload{}
|
||||
}
|
||||
|
||||
pl := payload{streamFrames: make([]ackhandler.StreamFrame, 0, 1)}
|
||||
|
||||
hasData := p.framer.HasData()
|
||||
hasRetransmission := p.retransmissionQueue.HasAppData()
|
||||
|
||||
var hasAck bool
|
||||
var pl payload
|
||||
if ackAllowed {
|
||||
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil {
|
||||
pl.ack = ack
|
||||
|
@ -661,7 +660,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
|
|||
return pl
|
||||
}
|
||||
|
||||
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
|
||||
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
|
||||
if encLevel == protocol.Encryption1RTT {
|
||||
s, err := p.cryptoSetup.Get1RTTSealer()
|
||||
if err != nil {
|
||||
|
@ -727,7 +726,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m
|
|||
return packet, nil
|
||||
}
|
||||
|
||||
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) {
|
||||
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||
pl := payload{
|
||||
frames: []ackhandler.Frame{ping},
|
||||
length: ping.Frame.Length(v),
|
||||
|
@ -745,7 +744,7 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
|
|||
return packet, buffer, err
|
||||
}
|
||||
|
||||
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader {
|
||||
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.Version) *wire.ExtendedHeader {
|
||||
pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
|
||||
hdr := &wire.ExtendedHeader{
|
||||
PacketNumber: pn,
|
||||
|
@ -768,7 +767,7 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protoc
|
|||
return hdr
|
||||
}
|
||||
|
||||
func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) {
|
||||
func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.Version) (*longHeaderPacket, error) {
|
||||
var paddingLen protocol.ByteCount
|
||||
pnLen := protocol.ByteCount(header.PacketNumberLen)
|
||||
if pl.length < 4-pnLen {
|
||||
|
@ -814,7 +813,7 @@ func (p *packetPacker) appendShortHeaderPacket(
|
|||
padding, maxPacketSize protocol.ByteCount,
|
||||
sealer sealer,
|
||||
isMTUProbePacket bool,
|
||||
v protocol.VersionNumber,
|
||||
v protocol.Version,
|
||||
) (shortHeaderPacket, error) {
|
||||
var paddingLen protocol.ByteCount
|
||||
if pl.length < 4-protocol.ByteCount(pnLen) {
|
||||
|
@ -860,7 +859,7 @@ func (p *packetPacker) appendShortHeaderPacket(
|
|||
|
||||
// appendPacketPayload serializes the payload of a packet into the raw byte slice.
|
||||
// It modifies the order of payload.frames.
|
||||
func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) {
|
||||
func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.Version) ([]byte, error) {
|
||||
payloadOffset := len(raw)
|
||||
if pl.ack != nil {
|
||||
var err error
|
||||
|
|
|
@ -53,7 +53,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetU
|
|||
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
|
||||
// If any other error occurred when parsing the header, the error is of type headerParseError.
|
||||
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
|
||||
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) {
|
||||
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) {
|
||||
var encLevel protocol.EncryptionLevel
|
||||
var extHdr *wire.ExtendedHeader
|
||||
var decrypted []byte
|
||||
|
@ -125,7 +125,7 @@ func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (prot
|
|||
return pn, pnLen, kp, decrypted, nil
|
||||
}
|
||||
|
||||
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, []byte, error) {
|
||||
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, []byte, error) {
|
||||
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
|
||||
// If the reserved bits are set incorrectly, we still need to continue unpacking.
|
||||
// This avoids a timing side-channel, which otherwise might allow an attacker
|
||||
|
@ -187,7 +187,7 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int
|
|||
}
|
||||
|
||||
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
|
||||
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) {
|
||||
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
|
||||
extHdr, err := unpackLongHeader(hd, hdr, data, v)
|
||||
if err != nil && err != wire.ErrInvalidReservedBits {
|
||||
return nil, &headerParseError{err: err}
|
||||
|
@ -195,7 +195,7 @@ func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header,
|
|||
return extHdr, err
|
||||
}
|
||||
|
||||
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) {
|
||||
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
hdrLen := hdr.ParsedLen()
|
||||
|
|
|
@ -292,10 +292,6 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame)
|
|||
return newlyRcvdFinalOffset, nil
|
||||
}
|
||||
|
||||
func (s *receiveStream) CloseRemote(offset protocol.ByteCount) {
|
||||
s.handleStreamFrame(&wire.StreamFrame{Fin: true, Offset: offset})
|
||||
}
|
||||
|
||||
func (s *receiveStream) SetReadDeadline(t time.Time) error {
|
||||
s.mutex.Lock()
|
||||
s.deadline = t
|
||||
|
|
|
@ -74,7 +74,7 @@ func (q *retransmissionQueue) addAppData(f wire.Frame) {
|
|||
q.appData = append(q.appData, f)
|
||||
}
|
||||
|
||||
func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame {
|
||||
func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
|
||||
if len(q.initialCryptoData) > 0 {
|
||||
f := q.initialCryptoData[0]
|
||||
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
|
||||
|
@ -97,7 +97,7 @@ func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v proto
|
|||
return f
|
||||
}
|
||||
|
||||
func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame {
|
||||
func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
|
||||
if len(q.handshakeCryptoData) > 0 {
|
||||
f := q.handshakeCryptoData[0]
|
||||
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
|
||||
|
@ -120,7 +120,7 @@ func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v pro
|
|||
return f
|
||||
}
|
||||
|
||||
func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame {
|
||||
func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
|
||||
if len(q.appData) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ type sendStreamI interface {
|
|||
SendStream
|
||||
handleStopSendingFrame(*wire.StopSendingFrame)
|
||||
hasData() bool
|
||||
popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (frame ackhandler.StreamFrame, ok, hasMore bool)
|
||||
popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (frame ackhandler.StreamFrame, ok, hasMore bool)
|
||||
closeForShutdown(error)
|
||||
updateSendWindow(protocol.ByteCount)
|
||||
}
|
||||
|
@ -198,7 +198,7 @@ func (s *sendStream) canBufferStreamFrame() bool {
|
|||
|
||||
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
|
||||
// maxBytes is the maximum length this frame (including frame header) will have.
|
||||
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (af ackhandler.StreamFrame, ok, hasMore bool) {
|
||||
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) {
|
||||
s.mutex.Lock()
|
||||
f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
|
||||
if f != nil {
|
||||
|
@ -215,7 +215,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers
|
|||
}, true, hasMoreData
|
||||
}
|
||||
|
||||
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) {
|
||||
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more data to send */) {
|
||||
if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
@ -269,7 +269,7 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun
|
|||
return f, hasMoreData
|
||||
}
|
||||
|
||||
func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool) {
|
||||
func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) {
|
||||
if s.nextFrame != nil {
|
||||
nextFrame := s.nextFrame
|
||||
s.nextFrame = nil
|
||||
|
@ -304,7 +304,7 @@ func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount,
|
|||
return f, hasMoreData
|
||||
}
|
||||
|
||||
func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) bool {
|
||||
func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool {
|
||||
maxDataLen := f.MaxDataLen(maxBytes, v)
|
||||
if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
|
||||
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
|
||||
|
@ -314,7 +314,7 @@ func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxByte
|
|||
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
|
||||
}
|
||||
|
||||
func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more retransmissions */) {
|
||||
func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) {
|
||||
f := s.retransmissionQueue[0]
|
||||
newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v)
|
||||
if needsSplit {
|
||||
|
@ -404,11 +404,13 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool
|
|||
}
|
||||
|
||||
func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
|
||||
updated := s.flowController.UpdateSendWindow(limit)
|
||||
if !updated { // duplicate or reordered MAX_STREAM_DATA frame
|
||||
return
|
||||
}
|
||||
s.mutex.Lock()
|
||||
hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
|
||||
s.mutex.Unlock()
|
||||
|
||||
s.flowController.UpdateSendWindow(limit)
|
||||
if hasStreamData {
|
||||
s.sender.onHasStreamData(s.streamID)
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/handshake"
|
||||
|
@ -24,15 +23,14 @@ var ErrServerClosed = errors.New("quic: server closed")
|
|||
// packetHandler handles packets
|
||||
type packetHandler interface {
|
||||
handlePacket(receivedPacket)
|
||||
shutdown()
|
||||
destroy(error)
|
||||
getPerspective() protocol.Perspective
|
||||
closeWithTransportError(qerr.TransportErrorCode)
|
||||
}
|
||||
|
||||
type packetHandlerManager interface {
|
||||
Get(protocol.ConnectionID) (packetHandler, bool)
|
||||
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
|
||||
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
|
||||
AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool
|
||||
Close(error)
|
||||
connRunner
|
||||
}
|
||||
|
@ -41,11 +39,9 @@ type quicConn interface {
|
|||
EarlyConnection
|
||||
earlyConnReady() <-chan struct{}
|
||||
handlePacket(receivedPacket)
|
||||
GetVersion() protocol.VersionNumber
|
||||
getPerspective() protocol.Perspective
|
||||
run() error
|
||||
destroy(error)
|
||||
shutdown()
|
||||
closeWithTransportError(TransportErrorCode)
|
||||
}
|
||||
|
||||
type zeroRTTQueue struct {
|
||||
|
@ -98,10 +94,10 @@ type baseServer struct {
|
|||
*logging.ConnectionTracer,
|
||||
uint64,
|
||||
utils.Logger,
|
||||
protocol.VersionNumber,
|
||||
protocol.Version,
|
||||
) quicConn
|
||||
|
||||
closeOnce sync.Once
|
||||
closeMx sync.Mutex
|
||||
errorChan chan struct{} // is closed when the server is closed
|
||||
closeErr error
|
||||
running chan struct{} // closed as soon as run() returns
|
||||
|
@ -111,8 +107,9 @@ type baseServer struct {
|
|||
connectionRefusedQueue chan rejectedPacket
|
||||
retryQueue chan rejectedPacket
|
||||
|
||||
connQueue chan quicConn
|
||||
connQueueLen int32 // to be used as an atomic
|
||||
verifySourceAddress func(net.Addr) bool
|
||||
|
||||
connQueue chan quicConn
|
||||
|
||||
tracer *logging.Tracer
|
||||
|
||||
|
@ -240,6 +237,7 @@ func newServer(
|
|||
onClose func(),
|
||||
tokenGeneratorKey TokenGeneratorKey,
|
||||
maxTokenAge time.Duration,
|
||||
verifySourceAddress func(net.Addr) bool,
|
||||
disableVersionNegotiation bool,
|
||||
acceptEarly bool,
|
||||
) *baseServer {
|
||||
|
@ -249,9 +247,10 @@ func newServer(
|
|||
config: config,
|
||||
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
|
||||
maxTokenAge: maxTokenAge,
|
||||
verifySourceAddress: verifySourceAddress,
|
||||
connIDGenerator: connIDGenerator,
|
||||
connHandler: connHandler,
|
||||
connQueue: make(chan quicConn),
|
||||
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
|
||||
errorChan: make(chan struct{}),
|
||||
running: make(chan struct{}),
|
||||
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
|
||||
|
@ -322,7 +321,6 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
|
|||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case conn := <-s.connQueue:
|
||||
atomic.AddInt32(&s.connQueueLen, -1)
|
||||
return conn, nil
|
||||
case <-s.errorChan:
|
||||
return nil, s.closeErr
|
||||
|
@ -335,15 +333,19 @@ func (s *baseServer) Close() error {
|
|||
}
|
||||
|
||||
func (s *baseServer) close(e error, notifyOnClose bool) {
|
||||
s.closeOnce.Do(func() {
|
||||
s.closeErr = e
|
||||
close(s.errorChan)
|
||||
s.closeMx.Lock()
|
||||
if s.closeErr != nil {
|
||||
s.closeMx.Unlock()
|
||||
return
|
||||
}
|
||||
s.closeErr = e
|
||||
close(s.errorChan)
|
||||
<-s.running
|
||||
s.closeMx.Unlock()
|
||||
|
||||
<-s.running
|
||||
if notifyOnClose {
|
||||
s.onClose()
|
||||
}
|
||||
})
|
||||
if notifyOnClose {
|
||||
s.onClose()
|
||||
}
|
||||
}
|
||||
|
||||
// Addr returns the server's network address
|
||||
|
@ -542,10 +544,10 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
|
|||
|
||||
func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
|
||||
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||
p.buffer.Release()
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
p.buffer.Release()
|
||||
return errors.New("too short connection ID")
|
||||
}
|
||||
|
||||
|
@ -558,8 +560,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
|||
}
|
||||
|
||||
var (
|
||||
token *handshake.Token
|
||||
retrySrcConnID *protocol.ConnectionID
|
||||
token *handshake.Token
|
||||
retrySrcConnID *protocol.ConnectionID
|
||||
clientAddrVerified bool
|
||||
)
|
||||
origDestConnID := hdr.DestConnectionID
|
||||
if len(hdr.Token) > 0 {
|
||||
|
@ -572,28 +575,30 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
|||
token = tok
|
||||
}
|
||||
}
|
||||
|
||||
clientAddrIsValid := s.validateToken(token, p.remoteAddr)
|
||||
if token != nil && !clientAddrIsValid {
|
||||
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
||||
// We just ignore them, and act as if there was no token on this packet at all.
|
||||
// This also means we might send a Retry later.
|
||||
if !token.IsRetryToken {
|
||||
token = nil
|
||||
} else {
|
||||
// For Retry tokens, we send an INVALID_ERROR if
|
||||
// * the token is too old, or
|
||||
// * the token is invalid, in case of a retry token.
|
||||
select {
|
||||
case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||
default:
|
||||
// drop packet if we can't send out the INVALID_TOKEN packets fast enough
|
||||
p.buffer.Release()
|
||||
if token != nil {
|
||||
clientAddrVerified = s.validateToken(token, p.remoteAddr)
|
||||
if !clientAddrVerified {
|
||||
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
||||
// We just ignore them, and act as if there was no token on this packet at all.
|
||||
// This also means we might send a Retry later.
|
||||
if !token.IsRetryToken {
|
||||
token = nil
|
||||
} else {
|
||||
// For Retry tokens, we send an INVALID_ERROR if
|
||||
// * the token is too old, or
|
||||
// * the token is invalid, in case of a retry token.
|
||||
select {
|
||||
case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||
default:
|
||||
// drop packet if we can't send out the INVALID_TOKEN packets fast enough
|
||||
p.buffer.Release()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
|
||||
|
||||
if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress(p.remoteAddr) {
|
||||
// Retry invalidates all 0-RTT packets sent.
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
select {
|
||||
|
@ -605,121 +610,116 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
|||
return nil
|
||||
}
|
||||
|
||||
if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize {
|
||||
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
|
||||
select {
|
||||
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||
default:
|
||||
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
|
||||
p.buffer.Release()
|
||||
config := s.config
|
||||
if s.config.GetConfigForClient != nil {
|
||||
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{
|
||||
RemoteAddr: p.remoteAddr,
|
||||
AddrVerified: clientAddrVerified,
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
select {
|
||||
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||
default:
|
||||
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
|
||||
p.buffer.Release()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
config = populateConfig(conf)
|
||||
}
|
||||
|
||||
var conn quicConn
|
||||
tracingID := nextConnTracingID()
|
||||
var tracer *logging.ConnectionTracer
|
||||
if config.Tracer != nil {
|
||||
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
||||
connID := hdr.DestConnectionID
|
||||
if origDestConnID.Len() > 0 {
|
||||
connID = origDestConnID
|
||||
}
|
||||
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
|
||||
}
|
||||
connID, err := s.connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debugf("Changing connection ID to %s.", connID)
|
||||
var conn quicConn
|
||||
tracingID := nextConnTracingID()
|
||||
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) {
|
||||
config := s.config
|
||||
if s.config.GetConfigForClient != nil {
|
||||
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr})
|
||||
if err != nil {
|
||||
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
|
||||
return nil, false
|
||||
}
|
||||
config = populateConfig(conf)
|
||||
}
|
||||
var tracer *logging.ConnectionTracer
|
||||
if config.Tracer != nil {
|
||||
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
||||
connID := hdr.DestConnectionID
|
||||
if origDestConnID.Len() > 0 {
|
||||
connID = origDestConnID
|
||||
}
|
||||
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
|
||||
}
|
||||
conn = s.newConn(
|
||||
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
|
||||
s.connHandler,
|
||||
origDestConnID,
|
||||
retrySrcConnID,
|
||||
hdr.DestConnectionID,
|
||||
hdr.SrcConnectionID,
|
||||
connID,
|
||||
s.connIDGenerator,
|
||||
s.connHandler.GetStatelessResetToken(connID),
|
||||
config,
|
||||
s.tlsConf,
|
||||
s.tokenGenerator,
|
||||
clientAddrIsValid,
|
||||
tracer,
|
||||
tracingID,
|
||||
s.logger,
|
||||
hdr.Version,
|
||||
)
|
||||
conn.handlePacket(p)
|
||||
|
||||
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
|
||||
for _, p := range q.packets {
|
||||
conn.handlePacket(p)
|
||||
}
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
}
|
||||
|
||||
return conn, true
|
||||
}); !added {
|
||||
select {
|
||||
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||
default:
|
||||
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
|
||||
p.buffer.Release()
|
||||
}
|
||||
conn = s.newConn(
|
||||
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
|
||||
s.connHandler,
|
||||
origDestConnID,
|
||||
retrySrcConnID,
|
||||
hdr.DestConnectionID,
|
||||
hdr.SrcConnectionID,
|
||||
connID,
|
||||
s.connIDGenerator,
|
||||
s.connHandler.GetStatelessResetToken(connID),
|
||||
config,
|
||||
s.tlsConf,
|
||||
s.tokenGenerator,
|
||||
clientAddrVerified,
|
||||
tracer,
|
||||
tracingID,
|
||||
s.logger,
|
||||
hdr.Version,
|
||||
)
|
||||
conn.handlePacket(p)
|
||||
// Adding the connection will fail if the client's chosen Destination Connection ID is already in use.
|
||||
// This is very unlikely: Even if an attacker chooses a connection ID that's already in use,
|
||||
// under normal circumstances the packet would just be routed to that connection.
|
||||
// The only time this collision will occur if we receive the two Initial packets at the same time.
|
||||
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
conn.closeWithTransportError(qerr.ConnectionRefused)
|
||||
return nil
|
||||
}
|
||||
// Pass queued 0-RTT to the newly established connection.
|
||||
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
|
||||
for _, p := range q.packets {
|
||||
conn.handlePacket(p)
|
||||
}
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
}
|
||||
|
||||
go conn.run()
|
||||
go s.handleNewConn(conn)
|
||||
if conn == nil {
|
||||
p.buffer.Release()
|
||||
return nil
|
||||
}
|
||||
go func() {
|
||||
if completed := s.handleNewConn(conn); !completed {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case s.connQueue <- conn:
|
||||
default:
|
||||
conn.closeWithTransportError(ConnectionRefused)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *baseServer) handleNewConn(conn quicConn) {
|
||||
connCtx := conn.Context()
|
||||
func (s *baseServer) handleNewConn(conn quicConn) bool {
|
||||
if s.acceptEarlyConns {
|
||||
// wait until the early connection is ready, the handshake fails, or the server is closed
|
||||
select {
|
||||
case <-s.errorChan:
|
||||
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
|
||||
return
|
||||
conn.closeWithTransportError(ConnectionRefused)
|
||||
return false
|
||||
case <-conn.Context().Done():
|
||||
return false
|
||||
case <-conn.earlyConnReady():
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// wait until the handshake is complete (or fails)
|
||||
select {
|
||||
case <-s.errorChan:
|
||||
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
|
||||
return
|
||||
case <-conn.HandshakeComplete():
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
atomic.AddInt32(&s.connQueueLen, 1)
|
||||
// wait until the handshake completes, fails, or the server is closed
|
||||
select {
|
||||
case s.connQueue <- conn:
|
||||
// blocks until the connection is accepted
|
||||
case <-connCtx.Done():
|
||||
atomic.AddInt32(&s.connQueueLen, -1)
|
||||
// don't pass connections that were already closed to Accept()
|
||||
case <-s.errorChan:
|
||||
conn.closeWithTransportError(ConnectionRefused)
|
||||
return false
|
||||
case <-conn.Context().Done():
|
||||
return false
|
||||
case <-conn.HandshakeComplete():
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -60,7 +60,7 @@ type streamI interface {
|
|||
// for sending
|
||||
hasData() bool
|
||||
handleStopSendingFrame(*wire.StopSendingFrame)
|
||||
popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool)
|
||||
popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool)
|
||||
updateSendWindow(protocol.ByteCount)
|
||||
}
|
||||
|
||||
|
|
|
@ -221,7 +221,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
|
|||
// unsigned int ipi6_ifindex; /* send/recv interface index */
|
||||
// };
|
||||
if len(body) == 20 {
|
||||
p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16]))
|
||||
p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16])).Unmap()
|
||||
p.info.ifIndex = binary.LittleEndian.Uint32(body[16:])
|
||||
} else {
|
||||
invalidCmsgOnceV6.Do(func() {
|
||||
|
|
|
@ -41,7 +41,8 @@ type Transport struct {
|
|||
Conn net.PacketConn
|
||||
|
||||
// The length of the connection ID in bytes.
|
||||
// It can be 0, or any value between 4 and 18.
|
||||
// It can be any value between 1 and 20.
|
||||
// Due to the increased risk of collisions, it is not recommended to use connection IDs shorter than 4 bytes.
|
||||
// If unset, a 4 byte connection ID will be used.
|
||||
ConnectionIDLength int
|
||||
|
||||
|
@ -77,7 +78,19 @@ type Transport struct {
|
|||
// It has no effect for clients.
|
||||
DisableVersionNegotiationPackets bool
|
||||
|
||||
// VerifySourceAddress decides if a connection attempt originating from unvalidated source
|
||||
// addresses first needs to go through source address validation using QUIC's Retry mechanism,
|
||||
// as described in RFC 9000 section 8.1.2.
|
||||
// Note that the address passed to this callback is unvalidated, and might be spoofed in case
|
||||
// of an attack.
|
||||
// Validating the source address adds one additional network roundtrip to the handshake,
|
||||
// and should therefore only be used if a suspiciously high number of incoming connection is recorded.
|
||||
// For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable
|
||||
// implementation of this callback (negating its return value).
|
||||
VerifySourceAddress func(net.Addr) bool
|
||||
|
||||
// A Tracer traces events that don't belong to a single QUIC connection.
|
||||
// Tracer.Close is called when the transport is closed.
|
||||
Tracer *logging.Tracer
|
||||
|
||||
handlerMap packetHandlerManager
|
||||
|
@ -147,7 +160,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
|
|||
if t.server != nil {
|
||||
return nil, errListenerAlreadySet
|
||||
}
|
||||
conf = populateServerConfig(conf)
|
||||
conf = populateConfig(conf)
|
||||
if err := t.init(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -161,6 +174,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
|
|||
t.closeServer,
|
||||
*t.TokenGeneratorKey,
|
||||
t.MaxTokenAge,
|
||||
t.VerifySourceAddress,
|
||||
t.DisableVersionNegotiationPackets,
|
||||
allow0RTT,
|
||||
)
|
||||
|
@ -323,6 +337,9 @@ func (t *Transport) close(e error) {
|
|||
if t.server != nil {
|
||||
t.server.close(e, false)
|
||||
}
|
||||
if t.Tracer != nil && t.Tracer.Close != nil {
|
||||
t.Tracer.Close()
|
||||
}
|
||||
t.closed = true
|
||||
}
|
||||
|
||||
|
@ -379,13 +396,21 @@ func (t *Transport) handlePacket(p receivedPacket) {
|
|||
return
|
||||
}
|
||||
|
||||
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
|
||||
return
|
||||
}
|
||||
// If there's a connection associated with the connection ID, pass the packet there.
|
||||
if handler, ok := t.handlerMap.Get(connID); ok {
|
||||
handler.handlePacket(p)
|
||||
return
|
||||
}
|
||||
// RFC 9000 section 10.3.1 requires that the stateless reset detection logic is run for both
|
||||
// packets that cannot be associated with any connections, and for packets that can't be decrypted.
|
||||
// We deviate from the RFC and ignore the latter: If a packet's connection ID is associated with an
|
||||
// existing connection, it is dropped there if if it can't be decrypted.
|
||||
// Stateless resets use random connection IDs, and at reasonable connection ID lengths collisions are
|
||||
// exceedingly rare. In the unlikely event that a stateless reset is misrouted to an existing connection,
|
||||
// it is to be expected that the next stateless reset will be correctly detected.
|
||||
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
|
||||
return
|
||||
}
|
||||
if !wire.IsLongHeaderPacket(p.data[0]) {
|
||||
t.maybeSendStatelessReset(p)
|
||||
return
|
||||
|
|
|
@ -1,37 +0,0 @@
|
|||
# This is the official list of people who can contribute (and typically
|
||||
# have contributed) code to the gomock repository.
|
||||
# The AUTHORS file lists the copyright holders; this file
|
||||
# lists people. For example, Google employees are listed here
|
||||
# but not in AUTHORS, because Google holds the copyright.
|
||||
#
|
||||
# The submission process automatically checks to make sure
|
||||
# that people submitting code are listed in this file (by email address).
|
||||
#
|
||||
# Names should be added to this file only after verifying that
|
||||
# the individual or the individual's organization has agreed to
|
||||
# the appropriate Contributor License Agreement, found here:
|
||||
#
|
||||
# http://code.google.com/legal/individual-cla-v1.0.html
|
||||
# http://code.google.com/legal/corporate-cla-v1.0.html
|
||||
#
|
||||
# The agreement for individuals can be filled out on the web.
|
||||
#
|
||||
# When adding J Random Contributor's name to this file,
|
||||
# either J's name or J's organization's name should be
|
||||
# added to the AUTHORS file, depending on whether the
|
||||
# individual or corporate CLA was used.
|
||||
|
||||
# Names should be added to this file like so:
|
||||
# Name <email address>
|
||||
#
|
||||
# An entry with two email addresses specifies that the
|
||||
# first address should be used in the submit logs and
|
||||
# that the second address should be recognized as the
|
||||
# same person when interacting with Rietveld.
|
||||
|
||||
# Please keep the list sorted.
|
||||
|
||||
Aaron Jacobs <jacobsa@google.com> <aaronjjacobs@gmail.com>
|
||||
Alex Reece <awreece@gmail.com>
|
||||
David Symonds <dsymonds@golang.org>
|
||||
Ryan Barrett <ryanb@google.com>
|
|
@ -11,8 +11,10 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/mock/mockgen/model"
|
||||
|
@ -98,6 +100,16 @@ func (p *fileParser) parseGenericMethod(field *ast.Field, it *namedInterface, if
|
|||
case *ast.IndexListExpr:
|
||||
indices = v.Indices
|
||||
typ = v.X
|
||||
case *ast.UnaryExpr:
|
||||
if v.Op == token.TILDE {
|
||||
return nil, errConstraintInterface
|
||||
}
|
||||
return nil, fmt.Errorf("~T may only appear as constraint for %T", field.Type)
|
||||
case *ast.BinaryExpr:
|
||||
if v.Op == token.OR {
|
||||
return nil, errConstraintInterface
|
||||
}
|
||||
return nil, fmt.Errorf("A|B may only appear as constraint for %T", field.Type)
|
||||
default:
|
||||
return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
|
||||
}
|
||||
|
@ -114,3 +126,5 @@ func (p *fileParser) parseGenericMethod(field *ast.Field, it *namedInterface, if
|
|||
|
||||
return p.parseMethod(nf, it, iface, pkg, tps)
|
||||
}
|
||||
|
||||
var errConstraintInterface = errors.New("interface contains constraints")
|
||||
|
|
|
@ -31,6 +31,7 @@ import (
|
|||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -314,8 +315,14 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
|
|||
}
|
||||
g.p("//")
|
||||
g.p("// Generated by this command:")
|
||||
g.p("//")
|
||||
// only log the name of the executable, not the full path
|
||||
g.p("// %v", strings.Join(append([]string{filepath.Base(os.Args[0])}, os.Args[1:]...), " "))
|
||||
name := filepath.Base(os.Args[0])
|
||||
if runtime.GOOS == "windows" {
|
||||
name = strings.TrimSuffix(name, ".exe")
|
||||
}
|
||||
g.p("//\t%v", strings.Join(append([]string{name}, os.Args[1:]...), " "))
|
||||
g.p("//")
|
||||
|
||||
// Get all required imports, and generate unique names for them all.
|
||||
im := pkg.Imports()
|
||||
|
@ -371,7 +378,7 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
|
|||
}
|
||||
|
||||
i := 0
|
||||
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() {
|
||||
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() || pkgName == "any" {
|
||||
pkgName = base + strconv.Itoa(i)
|
||||
i++
|
||||
}
|
||||
|
@ -386,6 +393,10 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
|
|||
}
|
||||
|
||||
if *writePkgComment {
|
||||
// Ensure there's an empty line before the package to follow the recommendations:
|
||||
// https://github.com/golang/go/wiki/CodeReviewComments#package-comments
|
||||
g.p("")
|
||||
|
||||
g.p("// Package %v is a generated GoMock package.", outputPkgName)
|
||||
}
|
||||
g.p("package %v", outputPkgName)
|
||||
|
@ -508,7 +519,7 @@ func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface,
|
|||
g.p("")
|
||||
_ = g.GenerateMockMethod(mockType, m, pkgOverride, shortTp)
|
||||
g.p("")
|
||||
_ = g.GenerateMockRecorderMethod(intf, mockType, m, shortTp, typed)
|
||||
_ = g.GenerateMockRecorderMethod(intf, m, shortTp, typed)
|
||||
if typed {
|
||||
g.p("")
|
||||
_ = g.GenerateMockReturnCallMethod(intf, m, pkgOverride, longTp, shortTp)
|
||||
|
@ -596,7 +607,8 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
|
|||
return nil
|
||||
}
|
||||
|
||||
func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType string, m *model.Method, shortTp string, typed bool) error {
|
||||
func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, m *model.Method, shortTp string, typed bool) error {
|
||||
mockType := g.mockName(intf.Name)
|
||||
argNames := g.getArgNames(m, true)
|
||||
|
||||
var argString string
|
||||
|
@ -621,7 +633,7 @@ func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType s
|
|||
|
||||
g.p("// %v indicates an expected call of %v.", m.Name, m.Name)
|
||||
if typed {
|
||||
g.p("func (%s *%vMockRecorder%v) %v(%v) *%s%sCall%s {", idRecv, mockType, shortTp, m.Name, argString, intf.Name, m.Name, shortTp)
|
||||
g.p("func (%s *%vMockRecorder%v) %v(%v) *%s%sCall%s {", idRecv, mockType, shortTp, m.Name, argString, mockType, m.Name, shortTp)
|
||||
} else {
|
||||
g.p("func (%s *%vMockRecorder%v) %v(%v) *gomock.Call {", idRecv, mockType, shortTp, m.Name, argString)
|
||||
}
|
||||
|
@ -650,7 +662,7 @@ func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType s
|
|||
}
|
||||
if typed {
|
||||
g.p(`call := %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
|
||||
g.p(`return &%s%sCall%s{Call: call}`, intf.Name, m.Name, shortTp)
|
||||
g.p(`return &%s%sCall%s{Call: call}`, mockType, m.Name, shortTp)
|
||||
} else {
|
||||
g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
|
||||
}
|
||||
|
@ -661,6 +673,7 @@ func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType s
|
|||
}
|
||||
|
||||
func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model.Method, pkgOverride, longTp, shortTp string) error {
|
||||
mockType := g.mockName(intf.Name)
|
||||
argNames := g.getArgNames(m, true /* in */)
|
||||
retNames := g.getArgNames(m, false /* out */)
|
||||
argTypes := g.getArgTypes(m, pkgOverride, true /* in */)
|
||||
|
@ -683,10 +696,10 @@ func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model
|
|||
ia := newIdentifierAllocator(argNames)
|
||||
idRecv := ia.allocateIdentifier("c")
|
||||
|
||||
recvStructName := intf.Name + m.Name
|
||||
recvStructName := mockType + m.Name
|
||||
|
||||
g.p("// %s%sCall wrap *gomock.Call", intf.Name, m.Name)
|
||||
g.p("type %s%sCall%s struct{", intf.Name, m.Name, longTp)
|
||||
g.p("// %s%sCall wrap *gomock.Call", mockType, m.Name)
|
||||
g.p("type %s%sCall%s struct{", mockType, m.Name, longTp)
|
||||
g.in()
|
||||
g.p("*gomock.Call")
|
||||
g.out()
|
||||
|
|
|
@ -305,7 +305,7 @@ type PredeclaredType string
|
|||
func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) }
|
||||
func (pt PredeclaredType) addImports(map[string]bool) {}
|
||||
|
||||
// TypeParametersType contains type paramters for a NamedType.
|
||||
// TypeParametersType contains type parameters for a NamedType.
|
||||
type TypeParametersType struct {
|
||||
TypeParameters []Type
|
||||
}
|
||||
|
|
|
@ -232,6 +232,9 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag
|
|||
continue
|
||||
}
|
||||
i, err := p.parseInterface(ni.name.String(), importPath, ni)
|
||||
if errors.Is(err, errConstraintInterface) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -187,6 +187,7 @@ type reflectData struct {
|
|||
// gob encoding of a model.Package to standard output.
|
||||
// JSON doesn't work because of the model.Type interface.
|
||||
var reflectProgram = template.Must(template.New("program").Parse(`
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
package main
|
||||
|
||||
import (
|
||||
|
|
|
@ -230,7 +230,7 @@ github.com/prometheus/common/model
|
|||
github.com/prometheus/procfs
|
||||
github.com/prometheus/procfs/internal/fs
|
||||
github.com/prometheus/procfs/internal/util
|
||||
# github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6
|
||||
# github.com/quic-go/quic-go v0.42.0
|
||||
## explicit; go 1.21
|
||||
github.com/quic-go/quic-go
|
||||
github.com/quic-go/quic-go/internal/ackhandler
|
||||
|
@ -313,7 +313,7 @@ go.opentelemetry.io/proto/otlp/trace/v1
|
|||
go.uber.org/automaxprocs/internal/cgroups
|
||||
go.uber.org/automaxprocs/internal/runtime
|
||||
go.uber.org/automaxprocs/maxprocs
|
||||
# go.uber.org/mock v0.3.0
|
||||
# go.uber.org/mock v0.4.0
|
||||
## explicit; go 1.20
|
||||
go.uber.org/mock/mockgen
|
||||
go.uber.org/mock/mockgen/model
|
||||
|
|
Loading…
Reference in New Issue